[Feature] Enable non-timm encoders with explicit embed_dim#426
Open
kitewatermelon wants to merge 8 commits into
Open
[Feature] Enable non-timm encoders with explicit embed_dim#426kitewatermelon wants to merge 8 commits into
kitewatermelon wants to merge 8 commits into
Conversation
- Rename encoder_name → model_or_model_name to reflect dual str/Module input (consistence with mae and ijepa) - Infer embed_dim automatically for timm str; require explicit embed_dim for custom encoder - Add embed_dim and model_or_model_name params to docstring
…methods - Add optional `embed_dim` parameter to all CLS-token SSL methods (SimCLR, BYOL, DINO, BarlowTwins, SwAV, MoCov2, MoCov3, SimSiam, WMSE, VICReg, TiCO, NNCLR, PIRL): inferred automatically from backbone.embed_dim for timm strings, or from the module attribute if present, with a clear error when neither is available — enables custom encoders without .embed_dim - Replace all BYOL-style dummy forward passes (torch.zeros(1,3,224,224)) with direct backbone.embed_dim access, removing 2D input shape hardcoding - Rename model_or_model_name → encoder_name in IJEPA, MAE, MaskedEncoder, and all related tests for consistency with the rest of the library
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Fixes two failure modes that prevented custom (non-timm) backbones from being used with SSL methods:
1. Missing
.embed_dimattributeMany encoders (e.g. MONAI ViT) do not expose a
.embed_dimattribute that timm models provide automatically. Passing such a module previously raised anAttributeErrorat construction time.2. Hardcoded 2D dummy forward pass
Several CLS-token methods inferred
embed_dimby runningbackbone(torch.zeros(1, 3, 224, 224)). This silently breaks any 3D encoder (e.g. volumetric MRI with shape(1, 1, 96, 96, 96)).Fix — explicit
embed_dimparameter across 13 CLS-token methods:Also —
encoder_nameAPI unification:MaskedEncoder,IJEPA,MAE, andLeJEPApreviously accepted the encoder argument asmodel_or_model_name. Renamed toencoder_nameto match all other methods in the library.Affected methods:
BarlowTwins,BYOL,DINO,MoCov2,MoCov3,NNCLR,PIRL,SimCLR,SimSiam,SwAV,TiCO,VICReg,WMSE,MaskedEncoder,IJEPA,MAE,LeJEPANote: Although
model_or_model_namemay be a more descriptive name,encoder_namewas chosen to match the existing convention already used across the majority of the codebase.Together,
encoder_nameandembed_dimmake themethodsAPI more predictable and consistent for users bringing custom backbones.Checklist