[Feature] Enable non-timm encoders with explicit embed_dim #426

Open
kitewatermelon wants to merge 8 commits into galilai-group:main from kitewatermelon:main

@kitewatermelon

Contributor

Description

Fixes two failure modes that prevented custom (non-timm) backbones from being used with SSL methods:

1. Missing .embed_dim attribute
Many encoders (e.g. MONAI ViT) do not expose the .embed_dim attribute that timm models provide automatically. Passing such a module previously raised an AttributeError at construction time.

2. Hardcoded 2D dummy forward pass
Several CLS-token methods inferred embed_dim by running backbone(torch.zeros(1, 3, 224, 224)). This silently breaks any encoder that expects 3D input (e.g. volumetric MRI with shape (1, 1, 96, 96, 96)).
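A framework-free sketch of why a fixed 2D probe is fragile. ToyVolumetricEncoder and probe_embed_dim are hypothetical stand-ins, not code from this PR or the library. The toy encoder works on shape tuples instead of tensors and only checks what a 3D patch embedding would require of its input. Here the mismatch surfaces as an exception; how a real encoder reacts depends on its layers.

```python
class ToyVolumetricEncoder:
    """Hypothetical stand-in for a 3D backbone; operates on shapes, not tensors."""

    def __init__(self, in_channels=1, embed_dim=512):
        self.in_channels = in_channels
        # Deliberately not named `embed_dim`, mimicking a non-timm encoder.
        self.out_features = embed_dim

    def __call__(self, input_shape):
        # A volumetric encoder expects (batch, channels, depth, height, width).
        if len(input_shape) != 5:
            raise ValueError(f"expected 5D input, got {len(input_shape)}D: {input_shape}")
        if input_shape[1] != self.in_channels:
            raise ValueError(f"expected {self.in_channels} channel(s), got {input_shape[1]}")
        # Output is one embedding vector per batch element.
        return (input_shape[0], self.out_features)


def probe_embed_dim(encoder, dummy_shape=(1, 3, 224, 224)):
    """Infer the output width from a dummy forward pass (the fragile approach)."""
    return encoder(dummy_shape)[-1]


encoder = ToyVolumetricEncoder()

# The hardcoded 2D RGB shape does not match what the 3D encoder expects.
try:
    probe_embed_dim(encoder)
    probe_failed = False
except ValueError:
    probe_failed = True

# The probe only works when the caller already knows the right input shape.
probed = probe_embed_dim(encoder, dummy_shape=(1, 1, 96, 96, 96))
```

Passing embed_dim explicitly removes the need to guess an input shape at construction time.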

Fix: add an explicit embed_dim parameter to 13 CLS-token methods.

# timm string → embed_dim inferred automatically (no change for existing users)
model = SimCLR("vit_base_patch16_224", projector_dims=(2048, 256))

# Custom nn.Module → embed_dim provided explicitly
model = SimCLR(my_monai_vit, projector_dims=(2048, 256), embed_dim=768)
model = SimCLR(my_3d_vit,    projector_dims=(2048, 256), embed_dim=512)
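The commit notes describe the resolution as: use the backbone's embed_dim attribute when present, otherwise require the explicit argument, and raise a clear error when neither is available. A minimal sketch of that logic follows. resolve_embed_dim, TimmLikeViT, and CustomEncoder are hypothetical names for illustration, and letting the explicit argument take precedence over the attribute is an assumption, not something this PR states.

```python
def resolve_embed_dim(backbone, embed_dim=None):
    """Hypothetical helper: choose the embedding width for a backbone."""
    # Assumption: an explicitly passed value wins over the attribute.
    if embed_dim is not None:
        return int(embed_dim)

    # timm models (and some custom modules) expose `.embed_dim` directly.
    inferred = getattr(backbone, "embed_dim", None)
    if inferred is None:
        raise ValueError(
            f"{type(backbone).__name__} has no `embed_dim` attribute; "
            "pass embed_dim=... explicitly."
        )
    return int(inferred)


class TimmLikeViT:
    """Stand-in for a backbone that exposes its width, as timm models do."""
    embed_dim = 768


class CustomEncoder:
    """Stand-in for a custom encoder without an `embed_dim` attribute."""


from_attribute = resolve_embed_dim(TimmLikeViT())
from_argument = resolve_embed_dim(CustomEncoder(), embed_dim=512)
```

Because no forward pass is involved, the same path works for 2D and 3D encoders alike.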

Also: encoder_name API unification.
MaskedEncoder, IJEPA, MAE, and LeJEPA previously accepted the encoder argument as model_or_model_name. It is now named encoder_name to match all other methods in the library.

Affected methods: BarlowTwins, BYOL, DINO, MoCov2, MoCov3, NNCLR, PIRL, SimCLR, SimSiam, SwAV, TiCO, VICReg, WMSE, MaskedEncoder, IJEPA, MAE, LeJEPA

Note: Although model_or_model_name may be a more descriptive name, encoder_name was chosen to match the existing convention already used across the majority of the codebase.

Together, encoder_name and embed_dim make the methods API more predictable and consistent for users bringing custom backbones.

Checklist

  • I have read the Contributing document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR to the RELEASES.rst file.

- Rename encoder_name → model_or_model_name to reflect dual str/Module input (consistency with MAE and IJEPA)
- Infer embed_dim automatically for timm str; require explicit embed_dim for custom encoder
- Add embed_dim and model_or_model_name params to docstring
…methods

- Add optional `embed_dim` parameter to all CLS-token SSL methods (SimCLR,
  BYOL, DINO, BarlowTwins, SwAV, MoCov2, MoCov3, SimSiam, WMSE, VICReg,
  TiCO, NNCLR, PIRL): inferred automatically from backbone.embed_dim for
  timm strings, or from the module attribute if present, with a clear error
  when neither is available — enables custom encoders without .embed_dim
- Replace all BYOL-style dummy forward passes (torch.zeros(1,3,224,224))
  with direct backbone.embed_dim access, removing 2D input shape hardcoding
- Rename model_or_model_name → encoder_name in IJEPA, MAE, MaskedEncoder,
  and all related tests for consistency with the rest of the library
