Merged
48 commits
995b23e
[feat] SID: add SidRqkmeans model (FAISS-trained residual K-Means)
WhiteSwan1 Jun 5, 2026
c7f3a09
[review] SID: drop forced tail-checkpoint after on_train_end
WhiteSwan1 Jun 8, 2026
61ec842
[review] SID: address code-review findings on PR #539
WhiteSwan1 Jun 8, 2026
753f3fe
[review] SID: default normalize_residuals to False
WhiteSwan1 Jun 8, 2026
52c7452
[review] SID: encapsulation, comment, and import cleanups
WhiteSwan1 Jun 8, 2026
fbd973f
[review] SID: move FAISS fit-sample sizing into the quantizer
WhiteSwan1 Jun 8, 2026
893a627
[review] SID: log rank0 FAISS-fit failure with traceback
WhiteSwan1 Jun 8, 2026
3734fc2
[review] SID: clarify the reservoir ceil-div comment
WhiteSwan1 Jun 9, 2026
795c676
[review] SID: fix FAISS gpu kwarg + close test gaps from PR review
WhiteSwan1 Jun 9, 2026
2bb5abc
[review] SID: default FAISS fit to CPU + DDP fit-failure test
WhiteSwan1 Jun 9, 2026
33acbe6
[review] SID: log the FAISS fit device (CPU/GPU)
WhiteSwan1 Jun 9, 2026
25a1e30
Merge remote-tracking branch 'upstream/master' into sid-2-rqkmeans
WhiteSwan1 Jun 9, 2026
23c552c
[chore] bump version to 1.2.18
WhiteSwan1 Jun 9, 2026
3261c2c
[review] SID: address 23c552c review (test timeout, N>=K assert, cap …
WhiteSwan1 Jun 9, 2026
e6e4d00
Merge upstream/master into sid-2-rqkmeans; bump version to 1.2.19
WhiteSwan1 Jun 9, 2026
39017ab
[review] checkpoint_util: force only overrides the dedupe
WhiteSwan1 Jun 9, 2026
5afbd5e
[review] checkpoint maybe_save: clarify final vs force docstrings
WhiteSwan1 Jun 9, 2026
415b8a3
[refactor] SidRqkmeans: single-process only; raise under DDP
WhiteSwan1 Jun 9, 2026
b27eb7b
[refactor] SidRqkmeans: move DDP guard to __init__ (fail fast)
WhiteSwan1 Jun 9, 2026
6f7ae1d
[simplify] SidRqkmeans: drop dead max(1,...) cap clamp; fold test _bu…
WhiteSwan1 Jun 9, 2026
5827d5b
[style] ruff-format the __init__ DDP guard (collapse to one line)
WhiteSwan1 Jun 9, 2026
4e2e878
[refactor] SidRqkmeans: CPU-only — raise on visible CUDA, drop device…
WhiteSwan1 Jun 9, 2026
4773e2a
[simplify] train_offline: assert host input; single-copy float32 own
WhiteSwan1 Jun 9, 2026
df83d07
[refactor] KMeansLayer.predict: use torch.cdist; drop _squared_euclid…
WhiteSwan1 Jun 10, 2026
d037db7
[refactor] SidRqkmeans: drop input_embedding from predictions
WhiteSwan1 Jun 10, 2026
88856f3
[simplify] trim SID docstrings (predict provenance; stale SidRqvae xref)
WhiteSwan1 Jun 10, 2026
2fa312b
[refactor] extract reservoir sampling into ReservoirSampler (kmeans.py)
WhiteSwan1 Jun 10, 2026
e296c8d
[refactor] ReservoirSampler: log capacity + dim on construction
WhiteSwan1 Jun 10, 2026
892a8d2
[fix] SID code-review: fail-fast cap, skip pre-fit eval, dedup MSE, d…
WhiteSwan1 Jun 10, 2026
b14304a
[simplify] SID: raise (not assert) for cap guard; name normalize_resi…
WhiteSwan1 Jun 10, 2026
eb39b5e
[style] SID: trim verbose comments
WhiteSwan1 Jun 10, 2026
8bf50aa
[refactor] SID: move init_metric/update_metric to BaseSidModel + Rela…
WhiteSwan1 Jun 10, 2026
e8a3609
[test] SID: add sid_integration_test (train -> fit -> checkpoint -> e…
WhiteSwan1 Jun 10, 2026
3dfbde0
[test] checkpoint: verify force re-save overwrites the same step
WhiteSwan1 Jun 10, 2026
d67ccd1
[review] split quantizer tests by module; clarify copy=True
WhiteSwan1 Jun 10, 2026
6a736c5
[refactor] drop CheckpointManager force param; SID uses no periodic c…
WhiteSwan1 Jun 10, 2026
5bc89d4
[refactor] typed FaissKmeansConfig proto; drop Struct + _coerce_proto…
WhiteSwan1 Jun 10, 2026
feeb4af
[refactor] add QuantizeLayer base; KMeansLayer -> KMeansQuantizeLayer
WhiteSwan1 Jun 10, 2026
a5d43b2
[refactor] unify reconstruction key to x_hat; drop _reconstruction hook
WhiteSwan1 Jun 10, 2026
c4c361a
[style] SID: trim redundant comments
WhiteSwan1 Jun 10, 2026
db7f2be
[refactor] QuantizeLayer: make lookup concrete in the base
WhiteSwan1 Jun 10, 2026
ed12cff
[refactor] QuantizeLayer: own n_clusters/n_features in the base
WhiteSwan1 Jun 10, 2026
d2697eb
[refactor] SID: extract QuantizeLayer ABC; rename kmeans -> kmeans_qu…
WhiteSwan1 Jun 10, 2026
097e9eb
[docs] checkpoint_util: tighten maybe_save `final` param docstring
WhiteSwan1 Jun 10, 2026
a9a889c
[fix] SID: review fixes + fail-fast validation; fix integration test …
WhiteSwan1 Jun 10, 2026
3b41df9
[review] SID: doc fixes, negative tests, stronger integration assertions
WhiteSwan1 Jun 10, 2026
5f5af01
[review] SID: drop _extract_feature width guard (embedding width is n…
WhiteSwan1 Jun 10, 2026
43e84ca
[fix] SID integration test: skip on CUDA, run on CPU CI
WhiteSwan1 Jun 11, 2026
6 changes: 6 additions & 0 deletions tzrec/main.py
@@ -515,6 +515,12 @@ def run_eval(step: int, epoch: int) -> None:
             if lr.by_epoch:
                 lr.step()

+    # One-shot end-of-loop hook (default no-op; e.g. SidRqkmeans fits its FAISS
+    # codebook here). SID models run with periodic checkpointing disabled
+    # (save_checkpoints_steps/epochs = 0), so the tail final=True save below is
+    # the only checkpoint and persists whatever on_train_end produced.
+    _model.on_train_end()
+
     _log_train(
         i_step,
         losses,
58 changes: 58 additions & 0 deletions tzrec/metrics/relative_l1.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2026, Alibaba Group;
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torchmetrics import Metric
+
+
+class RelativeL1(Metric):
+    """Mean symmetric relative-L1 error ``|t - p| / (max(|t|, |p|) + eps)``.
+
+    A bounded reconstruction-error metric (0 = exact, → 1 = unrelated). It is a
+    verbatim port of OpenOneRec's residual-K-Means ``calc_loss`` and is
+    deliberately **not** ``torchmetrics.MeanAbsolutePercentageError``, which uses
+    the asymmetric ``|t - p| / |t|`` denominator. Aggregation is element-wise
+    (count-weighted), so the reported value is the mean over all elements seen.
+    """
+
+    higher_is_better = False
+    is_differentiable = True
+
+    def __init__(self, epsilon: float = 1e-4, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+        # float64 sum / long count: float32 loses integer precision past 2**24
+        # (~32K rows of a 512-dim embedding) under element-wise aggregation.
+        self.add_state(
+            "sum_rel",
+            default=torch.tensor(0.0, dtype=torch.float64),
+            dist_reduce_fx="sum",
+        )
+        self.add_state(
+            "count", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum"
+        )
+
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        """Accumulate the relative-L1 error between ``preds`` and ``target``.
+
+        Args:
+            preds (Tensor): reconstruction, shape (B, D).
+            target (Tensor): ground-truth embedding, shape (B, D).
+        """
+        rel = torch.abs(target - preds) / (
+            torch.maximum(torch.abs(target), torch.abs(preds)) + self.epsilon
+        )
+        self.sum_rel += rel.sum().double()
+        self.count += rel.numel()
+
+    def compute(self) -> torch.Tensor:
+        """Mean relative-L1 over all elements (NaN before any update)."""
+        return self.sum_rel / self.count
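As a worked illustration of the formula in the `RelativeL1` docstring, here is a minimal pure-Python sketch. It is not the PR's implementation: `relative_l1` is a hypothetical helper that operates on flat lists instead of tensors. It also suggests that elements of opposite sign can push the per-element ratio above 1 (approaching 2), since `|t - p|` can then reach `|t| + |p|`.

```python
from typing import Sequence


def relative_l1(
    preds: Sequence[float], target: Sequence[float], eps: float = 1e-4
) -> float:
    """Mean of |t - p| / (max(|t|, |p|) + eps) over all elements."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    if len(preds) == 0:
        return float("nan")  # mirrors the NaN-before-update behaviour
    total = 0.0
    for p, t in zip(preds, target):
        total += abs(t - p) / (max(abs(t), abs(p)) + eps)
    return total / len(preds)


# Identical inputs give exactly zero error.
identical = relative_l1([0.5, -2.0], [0.5, -2.0])
# One side zero: each element contributes slightly less than 1.
one_sided = relative_l1([1.0, 0.0], [0.0, 2.0])
# Opposite signs: the ratio is 2 / (1 + eps), i.e. just under 2.
opposite = relative_l1([-1.0], [1.0])
```

The `one_sided` case uses the same inputs as `test_matches_formula` in the accompanying test file.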
49 changes: 49 additions & 0 deletions tzrec/metrics/relative_l1_test.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2026, Alibaba Group;
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from tzrec.metrics.relative_l1 import RelativeL1
+
+
+class RelativeL1Test(unittest.TestCase):
+    def test_zero_on_identity(self) -> None:
+        metric = RelativeL1()
+        x = torch.randn(8, 4)
+        metric.update(x, x.clone())
+        self.assertAlmostEqual(metric.compute().item(), 0.0, places=6)
+
+    def test_matches_formula(self) -> None:
+        metric = RelativeL1(epsilon=1e-4)
+        p = torch.tensor([[1.0, 0.0]])
+        t = torch.tensor([[0.0, 2.0]])
+        # |t-p|/(max(|t|,|p|)+eps): [1/(1+eps), 2/(2+eps)], mean of the two.
+        expected = (1.0 / (1.0 + 1e-4) + 2.0 / (2.0 + 1e-4)) / 2
+        metric.update(p, t)
+        self.assertAlmostEqual(metric.compute().item(), expected, places=5)
+
+    def test_count_weighted_across_updates(self) -> None:
+        """Aggregation is element-wise, not a mean of per-batch means."""
+        metric = RelativeL1()
+        metric.update(torch.zeros(1, 4), torch.ones(1, 4))  # 4 elems, rel ~1
+        metric.update(torch.ones(3, 4), torch.ones(3, 4))  # 12 elems, rel 0
+        # Element-weighted: 4 nonzero over 16 elems -> ~0.25, NOT (1+0)/2 = 0.5.
+        per = 1.0 / (1.0 + 1e-4)  # rel of a 0-vs-1 element (with epsilon)
+        self.assertAlmostEqual(metric.compute().item(), 4 * per / 16, places=6)
+
+    def test_nan_before_update(self) -> None:
+        self.assertTrue(torch.isnan(RelativeL1().compute()))
+
+
+if __name__ == "__main__":
+    unittest.main()
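The count-weighted aggregation checked by `test_count_weighted_across_updates`, and the float32 precision concern noted in the metric's `__init__` comment, can be sketched in pure Python. This is an illustration under assumptions, not the PR's code: `RunningRelativeError` and `to_float32` are hypothetical names, and the per-element errors are passed in directly rather than computed from tensors.

```python
import struct
from typing import List


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


class RunningRelativeError:
    """Hypothetical accumulator: one global sum and one global element count."""

    def __init__(self) -> None:
        self.sum_rel = 0.0
        self.count = 0

    def update(self, rel_errors: List[float]) -> None:
        self.sum_rel += sum(rel_errors)
        self.count += len(rel_errors)

    def compute(self) -> float:
        return self.sum_rel / self.count if self.count else float("nan")


acc = RunningRelativeError()
acc.update([1.0] * 4)   # small batch: 4 elements, each fully wrong
acc.update([0.0] * 12)  # larger batch: 12 elements, each exact
element_weighted = acc.compute()       # 4 / 16
mean_of_batch_means = (1.0 + 0.0) / 2  # what per-batch averaging would give

# float32 cannot represent 2**24 + 1, so a float32 element counter would
# stop advancing there; 2**24 / 512 is 32768 rows of a 512-dim embedding.
stalled = to_float32(2**24 + 1) == float(2**24)
rows_until_stall = 2**24 // 512
```

Keeping a sum and a count, rather than averaging batch means, makes the result independent of how the data is split into batches.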
9 changes: 9 additions & 0 deletions tzrec/models/model.py
@@ -150,6 +150,15 @@ def compute_train_metric(self) -> Dict[str, torch.Tensor]:
             metric_results[metric_name] = metric.compute()
         return metric_results

+    def on_train_end(self) -> None:
+        """Hook fired once after the train_eval loop exits.
+
+        Default no-op; override for one-shot end-of-loop work (e.g.
+        :class:`SidRqkmeans` fits its FAISS codebook here). The tail
+        ``final=True`` checkpoint persists whatever it produced.
+        """
+        return
+
     def sparse_parameters(
         self,
     ) -> Tuple[Iterable[nn.Parameter], Iterable[nn.Parameter]]:
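The hook pattern described in the `on_train_end` docstring (a default no-op that a subclass overrides for one-shot work, followed by a single final save) can be sketched as follows. This is a simplified stand-in, not tzrec code: `Model`, `BufferingModel`, `train`, and the list used as a checkpoint store are all hypothetical, and a plain mean stands in for the FAISS codebook fit.

```python
from typing import List, Optional


class Model:
    """Hypothetical base model with an end-of-training hook."""

    def train_step(self, x: float) -> None:
        return

    def on_train_end(self) -> None:
        """Default no-op; subclasses override for one-shot work."""
        return

    def state(self) -> Optional[float]:
        return None


class BufferingModel(Model):
    """Buffers samples during training and 'fits' once at the end."""

    def __init__(self) -> None:
        self.samples: List[float] = []
        self.centroid: Optional[float] = None

    def train_step(self, x: float) -> None:
        self.samples.append(x)

    def on_train_end(self) -> None:
        # One-shot work: a mean stands in for an offline codebook fit.
        if self.samples:
            self.centroid = sum(self.samples) / len(self.samples)

    def state(self) -> Optional[float]:
        return self.centroid


def train(model: Model, data: List[float], checkpoints: List[Optional[float]]) -> None:
    for x in data:
        model.train_step(x)
    model.on_train_end()               # hook runs after the loop exits
    checkpoints.append(model.state())  # single final save sees the fitted state


saved: List[Optional[float]] = []
train(BufferingModel(), [1.0, 2.0, 3.0], saved)
baseline: List[Optional[float]] = []
train(Model(), [1.0, 2.0, 3.0], baseline)
```

Calling the hook before the final save is what lets the only checkpoint contain the fitted state; models that do not override the hook are unaffected.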
60 changes: 52 additions & 8 deletions tzrec/models/sid_model.py
@@ -18,6 +18,7 @@

 from tzrec.datasets.utils import BASE_DATA_GROUP, Batch
 from tzrec.features.feature import BaseFeature
+from tzrec.metrics.relative_l1 import RelativeL1
 from tzrec.metrics.unique_ratio import UniqueRatio
 from tzrec.models.model import BaseModel
 from tzrec.protos.model_pb2 import ModelConfig
@@ -40,9 +41,9 @@ class BaseSidModel(BaseModel):

     Subclasses build their quantizer in ``__init__`` (after calling
     ``super().__init__``) and implement :meth:`predict` and :meth:`loss`.
-    They extend :meth:`init_metric` (via ``super()``) and implement
-    :meth:`update_metric` to populate the registered metrics
-    (:meth:`update_train_metric` defaults to a no-op).
+    :meth:`predict` exposes the reconstruction under ``predictions["x_hat"]``
+    (only when meaningful) so the shared :meth:`update_metric` can score it.
+    (:meth:`update_train_metric` defaults to a no-op.)

     Args:
         model_config (ModelConfig): an instance of ModelConfig.
@@ -69,8 +70,17 @@ def __init__(
         self._input_dim = cfg.input_dim
         self._normalize_residuals = cfg.normalize_residuals

-        assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]"
+        if not cfg.codebook:
+            raise ValueError("codebook must be set, e.g. [256, 256, 256]")
         self._n_embed_list = list(cfg.codebook)
+        # Fail fast: a zero codebook entry / input_dim==0 only errors opaquely
+        # deep inside faiss, after the whole training pass.
+        if any(k < 1 for k in self._n_embed_list):
+            raise ValueError(
+                f"every codebook entry must be >= 1, got {self._n_embed_list}"
+            )
+        if self._input_dim < 1:
+            raise ValueError(f"input_dim must be >= 1, got {self._input_dim}")
         self._n_layers = len(self._n_embed_list)

     def _extract_feature(
@@ -99,14 +109,48 @@ def init_loss(self) -> None:
     def init_metric(self) -> None:
         """Initialize the eval metrics shared by all SID models.

-        ``mse``: reconstruction error (input vs. quantized / decoded).
-        ``unique_sid_ratio``: mean per-batch unique-SID ratio (distinct rows /
-        batch size; a batch-size-sensitive diversity proxy, not global
-        coverage). Subclasses call ``super().init_metric()`` then add extras.
+        - ``mse``: reconstruction error (input vs. quantized / decoded).
+        - ``rel_loss``: symmetric relative-L1 reconstruction error
+          (:class:`~tzrec.metrics.relative_l1.RelativeL1`); meaningful only with
+          ``normalize_residuals=False`` (else the reconstruction and the input
+          live on different scales).
+        - ``unique_sid_ratio``: mean per-batch unique-SID ratio (distinct rows /
+          batch size; a batch-size-sensitive diversity proxy, not global
+          coverage).
+
+        Subclasses that add extras call ``super().init_metric()`` first.
         """
         self._metric_modules["mse"] = torchmetrics.MeanSquaredError()
+        self._metric_modules["rel_loss"] = RelativeL1()
         self._metric_modules["unique_sid_ratio"] = UniqueRatio()

+    def update_metric(
+        self,
+        predictions: Dict[str, torch.Tensor],
+        batch: Batch,
+        losses: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> None:
+        """Update eval metrics from the reconstruction + the re-extracted input.
+
+        ``predictions["x_hat"]`` is the model's reconstruction of the input
+        embedding (the centroid sum for RQ-KMeans, the decoder output for
+        RQ-VAE). Subclasses expose it only when it is meaningful, so a
+        not-yet-fitted model omits it and this logs nothing. The target
+        embedding is re-extracted from ``batch`` (it is an input, not an output).
+
+        Args:
+            predictions (dict): a dict of predicted result.
+            batch (Batch): input batch data.
+            losses (dict, optional): a dict of loss.
+        """
+        if "x_hat" not in predictions:
+            return
+        recon = predictions["x_hat"]
+        embedding = self._extract_feature(batch)
+        self._metric_modules["mse"].update(recon, embedding)
+        self._metric_modules["rel_loss"].update(recon, embedding)
+        self._metric_modules["unique_sid_ratio"].update(predictions["codes"])
+
     def update_train_metric(
         self,
         predictions: Dict[str, torch.Tensor],
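The `x_hat` guard in `update_metric` can be illustrated with a small pure-Python sketch. The names here (`update_metrics`, the `state` dict, list-valued predictions) are hypothetical simplifications, not the PR's API; the point is only that a model which omits the reconstruction key contributes nothing to the accumulated metrics.

```python
from typing import Dict, List


def update_metrics(
    predictions: Dict[str, List[float]],
    target: List[float],
    state: Dict[str, float],
) -> bool:
    """Accumulate squared error only when a reconstruction is present."""
    if "x_hat" not in predictions:
        return False  # e.g. a model whose codebook has not been fitted yet
    recon = predictions["x_hat"]
    state["sq_err"] += sum((r - t) ** 2 for r, t in zip(recon, target))
    state["count"] += len(recon)
    return True


state = {"sq_err": 0.0, "count": 0}
target = [1.0, 4.0]

# Before fitting: only codes are returned, so nothing is accumulated.
updated_before_fit = update_metrics({"codes": [0.0, 0.0]}, target, state)
state_before_fit = dict(state)

# After fitting: the reconstruction is exposed and scored against the input.
updated_after_fit = update_metrics({"x_hat": [1.0, 2.0]}, target, state)
mse = state["sq_err"] / state["count"]
```

Gating on key presence rather than on a separate "is fitted" flag keeps the metric code independent of how each subclass tracks its own fit state.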