
Support custom (non-timm) encoders via explicit embed_dim across all methods #425

Description

@kitewatermelon

Thank you for this amazing library! 🎉

I'd like to propose adding explicit embed_dim support across all methods in stable_pretraining.methods to better support custom (non-timm) encoders.

Problem

Currently, several methods (e.g. LeJEPA) are designed around timm models, which expose .embed_dim automatically. This causes two failure modes when using custom backbones:

Case 1 — Missing .embed_dim attribute (e.g. MONAI ViT)

Medical imaging libraries such as MONAI provide ViT implementations that do not expose a .embed_dim attribute, so the method fails at runtime even when a pre-instantiated nn.Module is passed.

Case 2 — BYOL-style dummy forward pass

An alternative approach, used in BYOL, infers embed_dim by running a dummy forward pass:

import torch
# Infer embed_dim from a dummy 2D RGB image (base_backbone is the user's encoder)
with torch.no_grad():
    embed_dim = base_backbone(torch.zeros(1, 3, 224, 224)).shape[-1]

However, this hardcodes a 2D, three-channel input shape of (1, 3, 224, 224), which is incompatible with 3D encoders (e.g., volumetric MRI with shape (1, 1, 96, 96, 96)).
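To illustrate why the probe is fragile, here is a minimal sketch in which the dummy input shape is a parameter instead of being hardcoded. `infer_embed_dim` and `Toy3DEncoder` are hypothetical names invented for this example, not part of stable_pretraining or MONAI. The probe only works when the caller already knows a valid input shape for the encoder, which is no less information than simply passing embed_dim.

```python
import torch
import torch.nn as nn


def infer_embed_dim(backbone: nn.Module, input_shape: tuple) -> int:
    """Hypothetical helper: one dummy forward pass with a caller-supplied shape."""
    was_training = backbone.training
    backbone.eval()  # avoid updating BatchNorm-style running stats during the probe
    try:
        with torch.no_grad():
            out = backbone(torch.zeros(*input_shape))
    finally:
        backbone.train(was_training)  # restore the original mode
    return out.shape[-1]


class Toy3DEncoder(nn.Module):
    """Hypothetical 3D encoder: patchify a 1-channel volume, then average the patches."""

    def __init__(self, embed_dim: int = 64):
        super().__init__()
        self.patch = nn.Conv3d(1, embed_dim, kernel_size=16, stride=16)

    def forward(self, x):
        # (B, 1, D, H, W) -> (B, embed_dim, d, h, w) -> (B, embed_dim)
        return self.patch(x).flatten(2).mean(-1)


enc = Toy3DEncoder(embed_dim=64)
print(infer_embed_dim(enc, (1, 1, 96, 96, 96)))  # 64
```

With the hardcoded (1, 3, 224, 224) shape, the same toy encoder raises an error instead of returning its embedding size.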

Proposed Solution

Make embed_dim an explicit optional argument across all methods: infer it automatically from .embed_dim when a timm model name string is passed, and require it when a custom nn.Module is provided:

# timm string → embed_dim inferred automatically
model = LeJEPA("vit_base_patch16_224")

# Custom nn.Module (MONAI, 3DViT, etc.) → embed_dim provided explicitly
model = LeJEPA(encoder_name=my_monai_vit, embed_dim=768)
model = LeJEPA(encoder_name=my_3d_vit, embed_dim=768)
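A minimal sketch of the resolution rules each method could share, assuming a single helper is called from every method's constructor. `resolve_encoder` is a hypothetical name, not an existing stable_pretraining function. The string branch uses `timm.create_model` and reads `.embed_dim` (exposed by timm ViTs), falling back to `.num_features`; timm is imported lazily so custom encoders do not require it.

```python
from typing import Any, Optional, Tuple


def resolve_encoder(encoder: Any, embed_dim: Optional[int] = None) -> Tuple[Any, int]:
    """Hypothetical helper returning (encoder_module, embed_dim).

    - str: treated as a timm model name; embed_dim is inferred from the model.
    - anything else (e.g. a MONAI or custom 3D nn.Module): an explicit embed_dim
      wins; otherwise fall back to an .embed_dim attribute if the module has one.
    """
    if isinstance(encoder, str):
        import timm  # lazy import: only needed for the timm-string path

        model = timm.create_model(encoder, pretrained=False)
        inferred = getattr(model, "embed_dim", None) or model.num_features
        if embed_dim is not None and embed_dim != inferred:
            raise ValueError(
                f"embed_dim={embed_dim} does not match {encoder!r} (inferred {inferred})"
            )
        return model, inferred

    if embed_dim is not None:
        return encoder, embed_dim
    if hasattr(encoder, "embed_dim"):
        return encoder, encoder.embed_dim
    raise ValueError(
        f"{type(encoder).__name__} has no .embed_dim attribute; pass embed_dim explicitly"
    )
```

Under this sketch, a call such as `LeJEPA(encoder_name=my_monai_vit, embed_dim=768)` would resolve to `(my_monai_vit, 768)`, while omitting embed_dim for a module without the attribute fails early with an actionable message instead of an AttributeError deep inside the method.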

This would extend the library's applicability beyond natural image domains to medical imaging and other 3D modalities.
