Thank you for this amazing library! 🎉
I'd like to propose adding explicit embed_dim support across all methods in stable_pretraining.methods to better support custom (non-timm) encoders.
Problem
Currently, several methods (e.g. LeJEPA) are designed around timm models, which expose .embed_dim automatically. This causes two failure modes when using custom backbones:
Case 1 — Missing .embed_dim attribute (e.g. MONAI ViT)
Many popular medical imaging libraries like MONAI provide ViT implementations that do not expose a .embed_dim attribute, resulting in a runtime error even when passing a pre-instantiated nn.Module.
Case 2 — BYOL-style dummy forward pass
An alternative inference approach runs a dummy forward pass to infer embed_dim:
with torch.no_grad():
    embed_dim = base_backbone(torch.zeros(1, 3, 224, 224)).shape[-1]
However, this hardcodes a 2D input shape, making it incompatible with 3D encoders (e.g. volumetric MRI with shape (1, 1, 96, 96, 96)).
Proposed Solution
Make embed_dim an explicit argument across all methods: optional when a timm model name is passed (it is then inferred from .embed_dim), and required when a custom nn.Module is provided:
# timm string → embed_dim inferred automatically
model = LeJEPA("vit_base_patch16_224")
# Custom nn.Module (MONAI, 3DViT, etc.) → embed_dim provided explicitly
model = LeJEPA(encoder_name=my_monai_vit, embed_dim=768)
model = LeJEPA(encoder_name=my_3d_vit, embed_dim=768)
This would extend the library's applicability beyond natural image domains to medical imaging and other 3D modalities.
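To make the proposed rule concrete, here is a minimal sketch of how the resolution logic might look. The helper name `resolve_embed_dim` and the two stand-in encoder classes are hypothetical and are not part of stable_pretraining; the sketch assumes the encoder object has already been built, whether from a timm name or passed in directly.

```python
from typing import Any, Optional


def resolve_embed_dim(encoder: Any, embed_dim: Optional[int] = None) -> int:
    """Hypothetical helper: decide the embedding width for a built encoder."""
    # An explicit value always wins, so custom backbones need no introspection
    # and no dummy forward pass with a hardcoded input shape.
    if embed_dim is not None:
        if embed_dim <= 0:
            raise ValueError(f"embed_dim must be positive, got {embed_dim}")
        return embed_dim

    # Otherwise fall back to the attribute that timm-style models expose.
    inferred = getattr(encoder, "embed_dim", None)
    if inferred is None:
        raise ValueError(
            f"{type(encoder).__name__} has no .embed_dim attribute; "
            "pass embed_dim explicitly."
        )
    return int(inferred)


# Stand-ins for illustration only (assumptions, not real library classes).
class TimmLikeEncoder:
    embed_dim = 768  # mimics an encoder that exposes .embed_dim


class CustomEncoder:
    pass  # mimics a backbone without .embed_dim


print(resolve_embed_dim(TimmLikeEncoder()))               # inferred from attribute
print(resolve_embed_dim(CustomEncoder(), embed_dim=512))  # provided explicitly
```

With this rule, a backbone that lacks `.embed_dim` fails early with a clear message instead of an attribute error deep inside the method, and 3D encoders never go through a 2D dummy input.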