From 57b1c22eecfd228154cfa65eb6f8d7cc5bb87d34 Mon Sep 17 00:00:00 2001 From: Yesung Hwang Date: Wed, 25 Feb 2026 18:28:35 +0100 Subject: [PATCH] Fix numerical instability of A matrix --- lejepa/multivariate/slicing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lejepa/multivariate/slicing.py b/lejepa/multivariate/slicing.py index 638279c..7609a83 100644 --- a/lejepa/multivariate/slicing.py +++ b/lejepa/multivariate/slicing.py @@ -139,7 +139,9 @@ def forward(self, x): proj_shape = (x.size(-1), self.num_slices) A = torch.randn(proj_shape, **dev, generator=g) - A /= A.norm(p=2, dim=0) + norms = A.norm(p=2, dim=0) + norms = torch.where(norms == 0.0, 1e-4, norms) + A /= norms self.global_step.add_(1) stats = self.univariate_test(x @ A)