
[feat] SID: add SidRqkmeans model (FAISS-trained residual K-Means)#539

Merged
tiankongdeguiji merged 48 commits into alibaba:master from WhiteSwan1:sid-2-rqkmeans
Jun 11, 2026

Conversation

@WhiteSwan1

Collaborator

Second of three PRs splitting the Semantic-ID models onto the shared base from #538. Adds the concrete RQ-KMeans backend on top of ResidualQuantizer / BaseSidModel; RQ-VAE follows in PR3.

  • tzrec/modules/sid/kmeans.py: KMeansLayer centroid container + recon_diagnostics.
  • tzrec/modules/sid/residual_kmeans_quantizer.py: ResidualKMeansQuantizer (FAISS-trained, FX-traceable forward, non-uniform per-layer codebooks).
  • tzrec/models/sid_rqkmeans.py: SidRqkmeans(BaseSidModel) - gradient-free; reservoir-samples embeddings during the train loop and fits FAISS once in on_train_end.
  • tzrec/models/model.py: BaseModel.on_train_end() no-op lifecycle hook.
  • tzrec/main.py: invoke on_train_end after the train loop and force the tail checkpoint so post-hook state is persisted.
  • protos: SidRqkmeans message + ModelConfig registration (601; 600 is reserved for SidRqvae in PR3).
  • tests: kmeans_test, ResidualKMeansQuantizerTest, sid_rqkmeans_test.
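
For readers new to the technique, the residual K-Means idea behind ResidualKMeansQuantizer can be sketched in a few lines. This is a minimal illustration, assuming plain NumPy Lloyd iterations in place of the FAISS fit; the function names (kmeans_fit, assign, fit_residual_kmeans) are hypothetical and are not the PR's API. Each layer fits its own codebook on whatever the previous layers failed to reconstruct, and the per-layer codebook sizes may differ.

```python
import numpy as np


def assign(x, centroids):
    """Index of the nearest centroid (squared Euclidean) for every row."""
    d2 = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=-1)


def kmeans_fit(x, k, n_iter=20, seed=0):
    """Plain Lloyd's K-Means; a stand-in for the FAISS fit, not the PR's code."""
    rng = np.random.default_rng(seed)
    # Initialise the centroids from k distinct input rows.
    centroids = x[rng.choice(len(x), size=k, replace=False)].copy()
    for _ in range(n_iter):
        codes = assign(x, centroids)
        for j in range(k):
            members = x[codes == j]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[j] = members.mean(axis=0)
    return centroids


def fit_residual_kmeans(x, codebook_sizes, seed=0):
    """Fit one codebook per layer on the residual left by earlier layers."""
    residual = x.astype(np.float64).copy()
    codebooks, codes = [], []
    for layer, k in enumerate(codebook_sizes):
        centroids = kmeans_fit(residual, k, seed=seed + layer)
        layer_codes = assign(residual, centroids)
        residual = residual - centroids[layer_codes]  # pass the error on
        codebooks.append(centroids)
        codes.append(layer_codes)
    return codebooks, np.stack(codes, axis=1), residual


rng = np.random.default_rng(0)
x = rng.normal(size=(512, 8))
# Non-uniform per-layer codebook sizes, as the PR description allows.
codebooks, codes, residual = fit_residual_kmeans(x, codebook_sizes=[16, 8, 4])
recon = sum(cb[codes[:, i]] for i, cb in enumerate(codebooks))
```

Here each row of codes is one semantic ID (one index per layer), and recon is the sum of the selected centroids, so x - recon equals the final residual.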

WhiteSwan1 and others added 7 commits June 5, 2026 07:10
Second of three PRs splitting the Semantic-ID models onto the shared
base from alibaba#538. Adds the concrete RQ-KMeans backend on top of
ResidualQuantizer / BaseSidModel; RQ-VAE follows in PR3.

- tzrec/modules/sid/kmeans.py: KMeansLayer centroid container +
  recon_diagnostics.
- tzrec/modules/sid/residual_kmeans_quantizer.py:
  ResidualKMeansQuantizer (FAISS-trained, FX-traceable forward,
  non-uniform per-layer codebooks).
- tzrec/models/sid_rqkmeans.py: SidRqkmeans(BaseSidModel) - gradient-free;
  reservoir-samples embeddings during the train loop and fits FAISS once
  in on_train_end.
- tzrec/models/model.py: BaseModel.on_train_end() no-op lifecycle hook.
- tzrec/main.py: invoke on_train_end after the train loop and force the
  tail checkpoint so post-hook state is persisted.
- protos: SidRqkmeans message + ModelConfig registration (601; 600 is
  reserved for SidRqvae in PR3).
- tests: kmeans_test, ResidualKMeansQuantizerTest, sid_rqkmeans_test.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Remove the `last_ckpt_step == i_step -> -1` override (and its stale
comment) in the train loop's end-of-loop hook. The normal checkpoint
cadence already persists the post-hook state.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- on_train_end() now returns is_ckpt_after_train; the tail save fires on
  `last_ckpt_step != i_step or is_ckpt_after_train`, so the fitted FAISS
  codebook is always persisted even when the last periodic checkpoint
  landed on the final step (main.py, model.py, sid_rqkmeans.py). (alibaba#1)
- DDP on_train_end: wrap the rank0 FAISS fit in try/except and broadcast a
  fit-status flag so a rank0-only failure (or an empty reservoir) makes all
  ranks raise together instead of deadlocking on the centroid broadcast;
  correct the empty-reservoir docstring. (alibaba#2, alibaba#3)
- KMeansLayer: cache is_initialized as a plain Python bool to drop a
  per-layer per-batch GPU->CPU .item() sync on the eval/predict path,
  kept in lockstep with the _is_initialized buffer. (alibaba#6)
- _reservoir_add: copy only the kept rows to host instead of the whole
  batch every training step (keep float64 for n_seen exactness). (alibaba#7)
- train_offline: per-layer fit-loss log now reports cumulative
  reconstruction of the original input (correct under normalize_residuals);
  align the module normalize_residuals default to True to match the proto. (alibaba#8, alibaba#10)
- Drop dead faiss_residual_kmeans (RQ-VAE-only, lands in PR3) and its test;
  tidy _coerce_proto_numbers into a comprehension.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
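
The tail-save rule described in the commit message above can be read as a small predicate. A minimal sketch, assuming hypothetical class and function names; only on_train_end, last_ckpt_step, i_step and is_ckpt_after_train come from the PR text, and the actual train loop in tzrec/main.py is not shown in this thread.

```python
class BaseModel:
    def on_train_end(self) -> bool:
        """No-op hook; False means no post-loop state needs saving."""
        return False


class FitAfterTrainModel(BaseModel):
    """Hypothetical model whose real state only exists after the hook runs."""

    def __init__(self) -> None:
        self.codebook_fitted = False

    def on_train_end(self) -> bool:
        self.codebook_fitted = True  # stand-in for the one-shot FAISS fit
        return True  # ask the loop for a checkpoint after training


def needs_tail_save(last_ckpt_step: int, i_step: int, is_ckpt_after_train: bool) -> bool:
    """Save if the final step was never checkpointed, or the hook asks for it."""
    return last_ckpt_step != i_step or is_ckpt_after_train


model = FitAfterTrainModel()
flag = model.on_train_end()
# A periodic checkpoint already landed on the final step (100), yet the
# post-hook codebook still has to be written.
save_needed = needs_tail_save(last_ckpt_step=100, i_step=100, is_ckpt_after_train=flag)
```

Later rounds in this thread replace this mechanism (the force parameter is dropped and SID configs set save_checkpoints_steps = save_checkpoints_epochs = 0), so the sketch reflects this commit only.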
Flip the default for RQ-KMeans residual normalization to False, in both
the SidRqkmeans proto field and the ResidualKMeansQuantizer constructor
(kept consistent to avoid the proto/module mismatch). This matches
OpenOneRec's residual k-means, which fits raw residuals with no per-layer
L2 normalization. Configs that set normalize_residuals explicitly are
unaffected.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- KMeansLayer: add mark_initialized_() so the buffer + cached-bool init
  flag is owned by the layer; the DDP broadcast in SidRqkmeans uses it
  instead of poking the private fields.
- SidRqkmeans: extract the reservoir-cap setup into _init_reservoir().
- residual_kmeans_quantizer: import faiss at module level (it's a pinned
  requirement) instead of a lazy in-function import; narrow
  train_offline(inputs) to torch.Tensor (all callers pass tensors) and
  drop the dead numpy branch.
- Tighten the verbose comments/docstrings across the SID files.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add ResidualKMeansQuantizer.default_fit_sample_size() (max(K) *
max_points_per_centroid) so the FAISS default lives in the FAISS-owning
class; SidRqkmeans._init_reservoir asks the quantizer instead of reading
faiss_kwargs and hardcoding 256. Behavior-identical.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
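
The sizing rule in this commit is simple enough to write down. A sketch under stated assumptions: the free functions below are hypothetical (the PR puts default_fit_sample_size on ResidualKMeansQuantizer), 256 is the value the commit message says was previously hardcoded, and the train_sample_size fallback follows the wording of a later review note in this thread.

```python
def default_fit_sample_size(codebook_sizes, max_points_per_centroid=256):
    """max(K) * max_points_per_centroid, using the largest per-layer codebook."""
    return max(codebook_sizes) * max_points_per_centroid


def reservoir_cap(train_sample_size, codebook_sizes):
    """Use train_sample_size when set (> 0), otherwise fall back to the default."""
    if train_sample_size > 0:
        return train_sample_size
    return default_fit_sample_size(codebook_sizes)


cap_default = reservoir_cap(0, [256, 128, 64])
cap_explicit = reservoir_cap(10_000, [256, 128, 64])
```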
Use logger.exception() in on_train_end's rank0 except so the underlying
error's stack trace is captured (peers raise a coordinated RuntimeError
pointing at the rank0 log); drop the now-unused `as e`.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@WhiteSwan1 WhiteSwan1 added the claude-review Let Claude Review label Jun 8, 2026
Comment-only; pushed to re-trigger CI.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@WhiteSwan1 WhiteSwan1 added claude-review Let Claude Review and removed claude-review Let Claude Review labels Jun 9, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Jun 9, 2026
Comment thread tzrec/modules/sid/residual_kmeans_quantizer.py Outdated
Comment thread tzrec/models/sid_rqkmeans.py Outdated
@github-actions

github-actions Bot commented Jun 9, 2026


Review summary — SidRqkmeans (FAISS residual K-Means)

Solid, carefully engineered PR. The Vitter Algorithm R reservoir, the int-status-flag DDP fit/broadcast (sidestepping NCCL bool quirks), the cached _initialized mirror to avoid per-batch .item() syncs, and the mid-fit-checkpoint poison guard are all thoughtful and well-documented. Two concrete issues + a few test/doc notes below.

Should-fix (inline):

  • FAISS never runs on GPU: gpu=torch.cuda.current_device() passes a device index, but faiss treats gpu as a GPU count, where 0 (the common case, incl. the rank0-only DDP fit) is falsy → silent CPU fallback. The codebook fit is typically the largest one-shot cost, so this is a real perf regression.
  • gather_object on the NCCL group — should use the existing dist_util.get_dist_object_pg() gloo helper, as odps_dataset.py does for the same pattern.

Test gaps worth closing:

  • The reservoir Phase-2 steady-state replacement (accept prob cap/(n_seen+j+1), float64 exactness, slot collisions — the subtlest code in the PR) is never validated for correctness; test_reservoir_caps_memory only checks counts/shape, which pass regardless of replacement logic.
  • normalize_residuals=True is never exercised through train_offline — note its F.normalize is a second normalize site independent of _residual_pass, so the two can silently diverge.
  • The eval-metric path (init_metric/update_metric, mse/rel_loss, the eval-only quantized/input_embedding outputs) and the inference-mode (set_is_inference) codes-only contract are untested.
  • test_post_fit_checkpoint_round_trips asserts only codes.abs().sum() > 0; comparing against the source model's codes (torch.testing.assert_close) would actually verify the round-trip rather than just "not uninitialized".

Minor (docs/consistency):

  • on_train_end Returns: doc says False on empty reservoir, but the DDP path never returns False (it raises via the fail-fast status broadcast) — worth clarifying the two paths differ.
  • train_offline docstring ("moves only FAISS's subsampled working set to the GPU") understates the chunked index.search over all N rows; and the train_sample_size proto comment says K * max_points_per_centroid but the code uses max(K) * ... for non-uniform codebooks.

Should-fix:
- train_offline: faiss reads `gpu` as a GPU *count*, not a device index, so
  `gpu=current_device()` was 0 (single-GPU / rank0) -> falsy -> silent CPU
  fallback. Pass `gpu=1` so the fit actually runs on the (rank0) GPU.

Test gaps:
- reservoir Phase-2 replacement correctness (identifiable rows: intact,
  in-range, replacement actually occurs) — beyond the count/shape checks.
- normalize_residuals=True end-to-end through train_offline (the F.normalize
  site the other tests never reached).
- eval vs inference predict contract (quantized/input_embedding vs codes-only)
  and the init_metric/update_metric/compute_metric path.
- checkpoint round-trip now asserts codes match the source model exactly
  (assert_close), not merely non-zero.

Minor docs:
- on_train_end Returns: clarify only the single-process path returns False;
  DDP raises on an empty gather.
- train_offline docstring: the post-fit index.search streams all N in chunks.
- proto train_sample_size comment: K -> max(K) for non-uniform codebooks.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@WhiteSwan1 WhiteSwan1 added the claude-review Let Claude Review label Jun 9, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Jun 9, 2026
Comment thread tzrec/modules/sid/residual_kmeans_quantizer.py Outdated
Comment thread tzrec/models/sid_rqkmeans.py Outdated
Comment thread tzrec/main.py Outdated
@github-actions

github-actions Bot commented Jun 9, 2026


Review summary

Strong, carefully-engineered PR. The hard parts are handled well: FX-traceable read-only predict, the cached _initialized bool to avoid a per-batch GPU→CPU .item() sync, Phase-2 reservoir sampling that copies only accepted rows to host, the mid-fit-checkpoint poison rejection in _load_from_state_dict, and a deadlock-free DDP on_train_end (status flag broadcast before the centroid broadcast, int64 not bool for NCCL). Test coverage is unusually thorough (reservoir Phase-2 correctness, non-uniform codebooks, normalize path, exact checkpoint round-trip, 2-rank DDP happy path). I ran five focused review passes; most areas came back clean.

Three items worth a look (posted inline):

  1. gpu=1 selects all GPUs, not rank0's device (residual_kmeans_quantizer.py). Verified against the faiss source: gpu=1 collapses to True (since 1 == True) → get_num_gpus() → index_cpu_to_all_gpus(ngpu=ALL). With all GPUs visible per rank, the rank0-only fit allocates faiss temp memory across every rank's GPU. Results stay correct, but it contradicts the comment and risks OOM on tight runs. This is in the line the latest commit just changed, and CI (faiss-cpu, 0 GPUs) can't exercise it.

  2. The DDP coordinated-failure path is untested (sid_rqkmeans.py) — the status-flag-before-broadcast logic exists specifically to avoid a deadlock; a regression would hang rather than fail, with no test to catch it.

  3. The is_ckpt_after_train tail checkpoint is untested (main.py) — it's the only thing persisting the codebook for this FAISS-only model.

Minor (not posted inline):

  • _init_reservoir docstring says the per-rank cap targets default_fit_sample_size(), but the code targets train_sample_size when it's set (>0) and only falls back to the default at 0 — worth a one-line tweak.
  • The faiss end-to-end tests (test_on_train_end_runs_faiss, non-uniform, normalize, eval-contract) don't seed RNG and assert only "codes in range / centroids non-zero"; seeding + one reconstruction-error threshold would make them deterministic and able to catch a broken residual walk.
  • A misspelled faiss_kmeans_kwargs key (e.g. n_iter) only fails at fit time — i.e. after a full training run. Validating keys at construction would fail fast.

Nice work overall — the engineering rigor here is clearly high.
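
The gpu=0 and gpu=1 pitfalls from the two review rounds come down to Python's int/bool equality. The sketch below assumes a hypothetical resolve_gpu_count helper that mimics the behavior described in item 1; it is not the faiss source, and visible_gpus stands in for whatever get_num_gpus() would report.

```python
def resolve_gpu_count(gpu, visible_gpus):
    """Hypothetical stand-in for how a count-style gpu kwarg is interpreted.

    Returns the number of GPUs the fit would use; 0 means CPU.
    """
    if gpu == True:  # noqa: E712 - equality on purpose: 1 == True in Python
        return visible_gpus  # "True" means every visible GPU
    return gpu  # any other value is taken as an explicit count


# A device index of 0 (single GPU or rank0) reads as "no GPUs".
count_for_index_zero = resolve_gpu_count(0, visible_gpus=8)
# gpu=1 compares equal to True, so it expands to all visible GPUs.
count_for_one = resolve_gpu_count(1, visible_gpus=8)
# Only counts of 2 or more are taken literally.
count_for_two = resolve_gpu_count(2, visible_gpus=8)
```

Under this reading no value of the kwarg pins the fit to one specific device, which is consistent with the follow-up commit defaulting the fit to CPU.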

WhiteSwan1 and others added 4 commits June 9, 2026 06:02
The earlier gpu=1 "fix" was itself wrong and broke the GPU unittest_ci
(cpu_ci/h20 passed): faiss reads `gpu` as a COUNT and 1 == True collapses
to all-GPUs, so the rank0-only fit sharded over every rank's device and the
GPU faiss path (newly activated — it was a silent CPU fallback before) failed
on the tiny test data. faiss's count kwarg cannot pin to a single device, so
default the fit to CPU (a bounded one-shot; set gpu in faiss_kmeans_kwargs to
opt in explicitly).

Also:
- _init_reservoir docstring: note the cap targets train_sample_size when set,
  else default_fit_sample_size().
- Add test_on_train_end_ddp_rank0_failure: forces rank0's fit to raise and
  asserts every rank raises the coordinated RuntimeError, with join(timeout)
  so a reintroduced deadlock fails CI instead of hanging.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Announce CPU vs GPU + N/D at the start of train_offline so the CPU
default isn't silent (configs that don't set faiss_kmeans_kwargs.gpu
now fit on CPU). Gated by verbose (on_train_end passes verbose=True).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Merge upstream/master (1.2.17, incl. alibaba#540 DlrmHSTU fix) and bump.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@WhiteSwan1 WhiteSwan1 added the claude-review Let Claude Review label Jun 9, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Jun 9, 2026
Comment thread tzrec/modules/sid/residual_kmeans_quantizer.py
Comment thread tzrec/modules/sid/residual_kmeans_quantizer.py Outdated
@WhiteSwan1 WhiteSwan1 added the claude-review Let Claude Review label Jun 10, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Jun 10, 2026
Comment thread tzrec/main.py
Comment thread tzrec/tests/configs/sid_rqkmeans_mock.config
Comment thread tzrec/modules/sid/kmeans_quantize.py
Comment thread tzrec/modules/sid/kmeans_quantize.py Outdated
Comment thread tzrec/modules/sid/residual_kmeans_quantizer.py Outdated
Comment thread tzrec/metrics/relative_l1.py Outdated
Comment thread tzrec/protos/models/sid_model.proto Outdated
Comment thread tzrec/models/sid_rqkmeans.py
@github-actions


Code Review Summary

Reviewed by parallel subagents (code quality, performance, test coverage, documentation accuracy, security), with findings cross-checked before posting.

Overall this is carefully written code. The fail-fast guards (reservoir cap vs. largest K at construction), the mid-fit checkpoint poison detection, the strictly-typed FAISS kwargs proto, the bounded reservoir with float64 stream counters, and the exact-round-trip checkpoint tests all show real attention to edge cases. Test coverage of the prediction contracts, non-uniform codebooks, and checkpoint poisoning is genuinely strong.

Main concern (inline on tzrec/main.py)

Four of five reviewers independently converged on the same issue: the "periodic checkpointing disabled" assumption is documented in three places but enforced nowhere. save_checkpoints_steps defaults to 1000, maybe_save(final=True) is still subject to the per-step dedupe, and a periodic save landing on the final step silently discards the fitted codebook — the surviving checkpoint loads cleanly and emits all-zero SIDs. This is the one silent-data-loss path in the PR and worth closing before merge (details + fix options inline).

Lower-priority items (no inline comment)

  • import tzrec now hard-depends on faiss at import time: residual_kmeans_quantizer.py imports faiss at module level and auto_import eagerly loads all model modules. Since faiss is in all requirements files this works, but it makes the eleven skipTest("faiss not installed") guards in the new tests unreachable (collection would fail first), and faiss.contrib.torch_utils monkey-patches faiss process-wide as an import side effect. Moving the import inside train_offline would decouple this.
  • x_hat is exposed even with normalize_residuals=True, where the inline comment itself notes reconstruction metrics are not meaningful — mse/rel_loss then compare across scales with no warning. Consider gating or warning.
  • Reservoir subsampling uses the unseeded global RNG, so once the corpus exceeds the cap, two runs with identical config (including faiss seed) produce different codebooks/SIDs. A config-seeded torch.Generator would make SID generation reproducible.
  • train_offline's out buffer exists only for the verbose log, and the per-layer diagnostics materialize several extra (N, D) temporaries — fine at the default cap, but a large train_sample_size multiplies peak host memory ~3-4x right at the end of the run. Folding scalar diagnostics into the existing chunk loop would fix it.
  • Integration test asserts only that eval_result.txt exists — an unfitted-restore regression produces NaN metrics but still passes. Parsing the file and asserting mse/rel_loss are finite would catch that whole class (including the dedupe issue above).
  • No user-facing docs: no docs/source/models/ page or README model-table row for the new sid_rqkmeans config. Fine if planned for PR3, but the save_checkpoints=0 requirement especially needs documenting somewhere users will see it.

🤖 Generated with Claude Code

…CPU pin

- sid_integration_test: force CPU with CUDA_VISIBLE_DEVICES="-1" not ""
  (empty is treated inconsistently across CUDA runtimes; the GPU CI runner
  didn't hide devices, tripping the CPU-only guard in the train_eval child).
- BaseSidModel: validate codebook entries >=1 and input_dim >=1 at construction;
  guard feature width in _extract_feature (a (B,1) tensor would otherwise
  broadcast into a degenerate rank-1 codebook). assert -> raise.
- residual_kmeans_quantizer / kmeans_quantize: assert -> raise for the
  data-corruption guards (N>=max_k, load_centroids_ shape, CPU/shape contract)
  so they survive python -O.
- RelativeL1: float64 sum / long count to avoid float32 rounding past 2**24.
- kmeans_quantize: drop the duplicate relative_l1/recon_diagnostics helpers;
  RelativeL1 (tzrec/metrics) is the single home of the formula. Per-layer
  offline-fit log now reports MSE only.
- sid_rqkmeans: TODO documenting the periodic-checkpointing-disabled contract
  (codebook can be dropped by save dedupe otherwise).
- sid_model.proto: drop stale "(global, across all ranks)" wording.
- mock config: set save_checkpoints_steps/epochs = 0 (the documented convention).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
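
The RelativeL1 accumulator change in the commit above guards against a float32 limit that is easy to reproduce. A small sketch, assuming NumPy scalars in place of the metric's tensor state; the accumulate helper is hypothetical.

```python
import numpy as np


def accumulate(start, increments, dtype):
    """Running sum kept in the given dtype, like a metric's state variable."""
    total = dtype(start)
    for inc in increments:
        total = total + dtype(inc)
    return total


# float32 has a 24-bit significand, so above 2**24 adding 1.0 is rounded away.
f32_total = accumulate(2**24, [1.0] * 8, np.float32)
f64_total = accumulate(2**24, [1.0] * 8, np.float64)
```

The float32 total stays at 16777216 while the float64 total advances to 16777224, which is why a long-running sum or count benefits from float64 / integer state.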
@WhiteSwan1 WhiteSwan1 added the claude-review Let Claude Review label Jun 10, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Jun 10, 2026
Comment thread tzrec/models/sid_rqkmeans.py
Comment thread tzrec/modules/sid/residual_kmeans_quantizer.py Outdated
Comment thread tzrec/models/sid_rqkmeans.py
Comment thread tzrec/models/sid_model.py Outdated
Comment thread tzrec/tests/sid_integration_test.py Outdated
@github-actions


Code Review Summary

Reviewed at head a9a889c by parallel subagents (code quality, performance, test coverage, docs accuracy, multi-process safety), findings cross-checked before posting. Prior-round items were verified first: the DDP coordination concerns are resolved by removal (single-process guard in __init__, raised on all ranks before any collective), the faiss gpu kwarg is stripped, tests are now seeded with real assertions, the reservoir-cap and faiss-kwargs validation landed, and the checkpoint-dedupe risk is explicitly documented as a TODO. The fail-fast posture of this commit (raise-not-assert with the python -O rationale, typed faiss kwargs, width guards) is genuinely good work.
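
The raise-not-assert rationale is worth a concrete illustration: Python removes assert statements when run with -O, so a guard written as an assert disappears in optimized mode. A minimal sketch, assuming a hypothetical check_enough_points helper; the N >= max_k condition and the RuntimeError type come from the thread, while the message text is illustrative.

```python
def check_enough_points(n_points: int, max_k: int) -> None:
    """Guard that survives `python -O`, unlike `assert n_points >= max_k`."""
    if n_points < max_k:
        raise RuntimeError(
            f"need at least max_k={max_k} training points, got {n_points}"
        )


check_enough_points(n_points=1024, max_k=256)  # passes silently

try:
    check_enough_points(n_points=10, max_k=256)
    guard_fired = False
except RuntimeError:
    guard_fired = True
```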

Inline comments (5)

  1. CUDA-guard message recommends the workaround this commit proved unreliable — the error says CUDA_VISIBLE_DEVICES="", while the integration-test fix in the same commit switched to "-1" because "" doesn't hide devices on the GPU CI runtime. One-line fix.
  2. NaN/Inf embeddings defeat the poison guard (residual_kmeans_quantizer.py) — a non-finite row from external data becomes NaN centroids, and the all-zero poison check (abs().sum() == 0) passes a NaN codebook as healthy → silent degenerate SIDs. One torch.isfinite check in train_offline closes it. Also: the adjacent comment mis-states faiss (it throws for N < K; warn-only is N < K * min_points_per_centroid).
  3. Empty-vs-small reservoir inconsistency (on_train_end) — empty warns and exits 0 with an unfitted checkpoint persisted; 0 < N < max(K) hard-fails after the whole pass. The warn path is the same silent zero-SID artifact the dedupe TODO describes; suggest raising in both cases.
  4. The new fail-fast validations have no negative tests — all four BaseSidModel ValueError paths and both train_offline RuntimeError guards are uncovered; a revert passes the suite. ~4 lines each with the existing helpers.
  5. Integration test docstring overstates coverage — it claims finite-metric verification but asserts only file existence; parsing eval_result.txt for finite mse/rel_loss would also give the dedupe TODO an end-to-end regression net.
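
Item 2 can be reproduced in isolation: any comparison with NaN is False, so an "is the sum zero?" check reports a NaN codebook as fitted. A sketch under stated assumptions, using NumPy arrays in place of the PR's tensors; both helper names are hypothetical and the strict variant is only one way the suggested finiteness check could look.

```python
import numpy as np


def looks_fitted_zero_check(centroids):
    """All-zero poison check only: unfitted iff abs().sum() == 0."""
    return not (np.abs(centroids).sum() == 0)


def looks_fitted_strict(centroids):
    """Same check plus a finiteness test, so NaN/Inf codebooks are rejected."""
    return bool(np.isfinite(centroids).all()) and bool(np.abs(centroids).sum() != 0)


unfitted = np.zeros((4, 2))
healthy = np.ones((4, 2))
poisoned = np.full((4, 2), np.nan)

# The NaN sum is NaN, and NaN == 0 is False, so the weak check passes it.
weak_verdict = looks_fitted_zero_check(poisoned)
strict_verdict = looks_fitted_strict(poisoned)
```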

Minor (no inline)

  • train_offline still allocates and accumulates the (N, D) out buffer when verbose=False (the commit gated the x0 clone but not out); in the default normalize_residuals=False path the logged loss reduces to mean(x**2), making out avoidable entirely.
  • RelativeL1.update has no shape check — broadcasting silently corrupts the metric; torchmetrics convention is _check_same_shape at the top of update. The float64-accumulation fix also has no regression test.
  • BaseSidModel class docstring lists mse/unique_sid_ratio but omits rel_loss (the init_metric docstring is correct).
  • Reservoir subsampling still uses the unseeded global RNG, so identical configs (incl. faiss seed) give different codebooks — noted in a prior round; fine if accepted, but the proto seed field implies determinism it can't deliver.
  • Docs page / README model-table row still absent — assuming this is deferred to the PR3 wrap-up; the save_checkpoints_steps: 0 and CPU-only/single-process constraints especially need a user-visible home.

🤖 Generated with Claude Code

WhiteSwan1 and others added 3 commits June 10, 2026 12:22
- CPU-only guard message recommends CUDA_VISIBLE_DEVICES="-1" (not "", which
  this PR found unreliable on the GPU CI runner).
- Correct the train_offline comment: faiss throws (not warns) for N < K.
- Add negative tests for the fail-fast guards: empty/zero codebook, input_dim<1,
  feature-width mismatch, and train_offline too-few-points / wrong-dim.
- sid_integration_test: assert the post-fit eval reports finite mse/rel_loss/
  unique_sid_ratio (rel_loss < 1.0, unique_sid_ratio > 0) so a degenerate /
  unfitted codebook can't keep the test green.
- Trim verbose comments.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ever 1)

The (B, 1) broadcast footgun isn't reachable in practice, so revert
_extract_feature to the plain feature read and remove its negative test.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The end-to-end train_eval is CPU-only (SidRqkmeans refuses a visible CUDA
device). Forcing CPU on the CUDA-built GPU CI image is unreliable (the prior
CUDA_VISIBLE_DEVICES="" / "-1" workarounds both still failed in the train_eval
child). Skip when CUDA is available so the test runs on the CPU CI job (where
it passes) and skips on the GPU runner. Keep nproc=1 for the single-process
guard.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@WhiteSwan1

Collaborator Author

Thanks for both github-actions passes — went through them all. Dispositions:

Fixed:

  • proto wording + codebook≥1/input_dim≥1 validation;
  • removed the duplicated relative-L1 helper (the metric is now the single source);
  • float64/long RelativeL1 accumulators;
  • assert→raise for the -O-stripped data-corruption guards;
  • mock config save_checkpoints=0;
  • corrected the faiss-N<K comment;
  • guard message now says CUDA_VISIBLE_DEVICES="-1";
  • added negative tests for every fail-fast path and strengthened the integration test to assert finite reconstruction metrics.

Declined / deferred (with reason):

  • __init__ CPU/world_size guards stay in __init__ — v1 is intentionally CPU-only end to end (training and inference), so failing fast at every entry point is the intended contract. Because KMeans inference is so simple and only involves nearest-neighbor search, it's unlikely we'll need GPU inference down the road.
  • Checkpoint-dedupe enforcement — kept the save_checkpoints=0 convention as agreed; left a TODO to harden it in a follow-up.
  • ReservoirSampler (B,1) width guard — the embedding width is never 1 in practice, so reverted.
  • Empty-vs-small-reservoir consistency and a NaN/Inf input check — reasonable, but behavior/logic changes better suited to a follow-up.

Also fixed the CPU-only end-to-end test that was failing on the GPU CI image: it now runs on the CPU CI job and skips where CUDA is present.

WhiteSwan1 added a commit to WhiteSwan1/TorchEasyRec that referenced this pull request Jun 11, 2026
…_abstract

Brings the reviewed alibaba#539 foundation onto feat/sid_abstract (which already
carries alibaba#538 + an older RQ-VAE/RQ-Kmeans port), and syncs to upstream/master
(alibaba#540, alibaba#541, which alibaba#539 already contains).

Conflict resolutions:
- sid_rqkmeans.py(+test), residual_kmeans_quantizer.py, sid_model.py:
  take alibaba#539's canonical versions (BaseSidModel now hosts both SID models,
  with mse/rel_loss/unique_sid_ratio and the unified x_hat recon key).
- types.py: union — keep alibaba#539's QuantizeOutput, retain feat's
  QuantizeForwardMode enum + ResidualQuantizerOutput (RQ-VAE needs them).
- protos/models/sid_model.proto: union — alibaba#539's typed FaissKmeansConfig +
  clean SidRqkmeans, re-add feat's SinkhornConfig/ClipConfig/SidRqvae;
  drop the now-unused struct.proto import.
- protos/model.proto: enable `SidRqvae sid_rqvae = 600;` (the field alibaba#539
  reserved for this follow-up).
- main.py / model.py on_train_end: take alibaba#539's wording; drop feat's forced
  tail-checkpoint (SID models rely on the final=True tail save).

Transitional state: old modules/sid/kmeans.py still coexists with alibaba#539's
kmeans_quantize.py, and the RQ-VAE stack is still on the old abstraction —
both retired in the follow-up refactor commit. All SID modules import.
WhiteSwan1 added a commit to WhiteSwan1/TorchEasyRec that referenced this pull request Jun 11, 2026
… kmeans.py

Refactors the RQ-VAE stack onto the reviewed alibaba#539 SID foundation so the VQ
and K-Means backends share one per-layer interface, and keeps RQ-VAE
device-agnostic (CPU + GPU) — unlike the deliberately CPU-only SidRqkmeans.

- VectorQuantize now subclasses QuantizeLayer: implements quantize() ->
  QuantizeOutput and get_codebook_embeddings(); forward() delegates to
  quantize() so standalone vq(x) still works; the base lookup() returns the
  raw codebook vector embedding.weight[ids].
- ResidualVectorQuantizer drives the layer through the ABC
  (layer.quantize / layer.lookup / get_codebook_embeddings) instead of
  reaching into layer.embedding directly; behavior (raw-vector accumulation,
  STE-on-aggregate) is unchanged.
- SidRqvae drops its update_metric override; alibaba#539's BaseSidModel now scores
  mse + rel_loss + unique_sid_ratio off predictions["x_hat"]/["codes"].
  The train-path mse override stays (RQ-VAE has a train reconstruction).
- Retire modules/sid/kmeans.py (replaced by alibaba#539's kmeans_quantize.py):
  relocate faiss_residual_kmeans into kmeans_quantize.py (CPU fit, centroids
  returned on the input device — safe from a GPU-resident RQ-VAE) and
  _squared_euclidean_distance into vector_quantize.py (its only user); drop
  the now-orphaned KMeansLayer / recon_diagnostics. Tests for the two moved
  helpers migrate to kmeans_quantize_test.py / vector_quantize_test.py.

CPU + GPU: no hard-CUDA assumptions; the only device-sensitive path is the
optional FAISS kmeans_init, which fits on CPU and moves centroids to the
module's device (DDP: fit on rank 0, broadcast). Sinkhorn's all_reduce works
under gloo and nccl.

Verified on CPU: all SID unit tests pass (quantize_layer, vector_quantize,
kmeans_quantize, residual_quantizer, residual_kmeans_quantizer, relative_l1,
sid_rqvae, sid_rqkmeans, residual_vector_quantizer_dist). ruff check/format
clean. GPU smoke + the full sid_integration_test must run in the torchgpuv4
container (this shell's CUDA driver is too old and has a stale installed
tzrec; checkpoint_util import already fails there independent of this change).
@WhiteSwan1

Collaborator Author

Thanks @tiankongdeguiji — all addressed on latest (5f5af01):

Round 1

  1. input_embedding in predictions — removed; predict exposes only the reconstruction x_hat (eval, once fitted), and update_metric re-extracts the target from the batch.
  2. Integration test — moved to tzrec/tests/sid_integration_test.py (with a mock config); it also asserts the post-fit eval reports finite metrics.
  3. init_metric/update_metric → BaseSidModel; rel_loss is now a torchmetric: a symmetric RelativeL1 rather than MeanAbsolutePercentageError (MAPE's |t-p|/|t| is asymmetric and unbounded as target→0; the symmetric form matches OpenOneRec's calc_loss). Happy to switch to the builtin if you prefer.
  4. DDP — removed; CPU-only, single-process, with a world_size>1 fail-fast in __init__.
  5. torch.cdist(x, centroids).argmin(-1) — adopted; _squared_euclidean_distance deleted.

Round 2

  1. File renamed to residual_kmeans_quantizer.py.
  2. copy=True — removed; train_offline now explicitly consumes its input (documented), and the caller decides whether to copy.
  3. Added a QuantizeLayer base class (tzrec/modules/sid/quantize_layer.py); KMeansQuantizeLayer subclasses it and the PR3 VectorQuantizeLayer will too.
  4. Replaced the Struct + _coerce_proto_numbers with a typed FaissKmeansConfig proto message (strictly-typed faiss kwargs).
  5. Dropped the force param; SID sets save_checkpoints_steps = save_checkpoints_epochs = 0.

@tiankongdeguiji tiankongdeguiji merged commit 3d4d5a8 into alibaba:master Jun 11, 2026
7 of 8 checks passed
WhiteSwan1 added a commit to WhiteSwan1/TorchEasyRec that referenced this pull request Jun 11, 2026
…t/sid_abstract

upstream/master advanced to 3d4d5a8: alibaba#539 (SidRqkmeans) was merged as a squash
and alibaba#542 (docs) landed. feat already carries alibaba#539's content (from the earlier
pr-539 merge), so this is effectively the alibaba#542 doc update plus reconciling the
SID proto/types surface.

Conflicts (all add/add on files feat extended for RQ-VAE) resolved by taking
feat's superset:
- modules/sid/types.py: keep QuantizeForwardMode + ResidualQuantizerOutput
  (master's alibaba#539 has only QuantizeOutput, which is identical and auto-merged).
- protos/models/sid_model.proto: keep SinkhornConfig/ClipConfig/SidRqvae on top
  of the shared FaissKmeansConfig + SidRqkmeans.
- protos/model.proto: keep `SidRqvae sid_rqvae = 600;` (master still has it
  reserved as a comment).