diff --git a/lejepa/multivariate/slicing.py b/lejepa/multivariate/slicing.py index 638279c..7609a83 100644 --- a/lejepa/multivariate/slicing.py +++ b/lejepa/multivariate/slicing.py @@ -139,7 +139,9 @@ def forward(self, x): proj_shape = (x.size(-1), self.num_slices) A = torch.randn(proj_shape, **dev, generator=g) - A /= A.norm(p=2, dim=0) + norms = A.norm(p=2, dim=0) + norms = torch.where(norms == 0.0, 1e-4, norms) + A /= norms self.global_step.add_(1) stats = self.univariate_test(x @ A)