Skip to content

Fix numerical instability of A matrix#40

Open
yesung31 wants to merge 1 commit into
galilai-group:mainfrom
yesung31:main
Open

Fix numerical instability of A matrix#40
yesung31 wants to merge 1 commit into
galilai-group:mainfrom
yesung31:main

Conversation

@yesung31

Copy link
Copy Markdown

SIGReg was sometimes returning NaN. This fixes the divided by 0 issue.

Reproducible by the following script on MacBook Air M2.

import lejepa
import torch

torch.set_default_device("mps")
torch.set_default_dtype(torch.bfloat16)

rng = torch.Generator("mps").manual_seed(42)
a = torch.randn((2, 64, 1), generator=rng)

SIGReg = lejepa.multivariate.SlicingUnivariateTest(
    univariate_test=lejepa.univariate.EppsPulley(), num_slices=1024
)
SIGReg.global_step.fill_(5433)

loss = SIGReg(a)
print(SIGReg.global_step, loss)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant