feat(world-model): FutureState feature-reconstruction (JEPA) loss + optional action conditioning Feat/future state jepa#80
Merged
m-zain-khawaja merged 3 commits intoJun 22, 2026
Conversation
… and optional action conditioning Adds Model/model_components/losses/feature_reconstruction_loss.py (FeatureReconstructionLoss: JEPA-style MSE between FutureState's 4 predicted future feature maps and frozen-backbone targets at +1.6/3.2/4.8/6.4s, with optional temporal weighting), exported from losses/. Extends FutureState with an optional action_dim conditioning (project the explicit trajectory into a per-channel bias for counterfactual rollouts); default action_dim=None keeps behaviour byte-identical. Relates to autowarefoundation#13. Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
Reused variable 'per_step' was first a list[Tensor] then reassigned to a stacked Tensor, tripping mypy. Use a separate 'per_step_losses' list and keep 'per_step' as the stacked Tensor (no_implicit_optional / strict-friendly). Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
3bdc609 to
47a209a
Compare
m-zain-khawaja
approved these changes
Jun 22, 2026
m-zain-khawaja
left a comment
Member
There was a problem hiding this comment.
Approved, thanks @gcordova10
gcordova10
added a commit
to gcordova10/auto_fsd
that referenced
this pull request
Jun 23, 2026
…ureState FutureState (autowarefoundation#80) predicts future BEV features and FeatureReconstructionLoss scores them, but the target side was missing. Per @ryotayamada's 'make it useful' list (autowarefoundation#56/autowarefoundation#13), add JepaTargetEncoder: a stop-gradient target encoder supporting both modes he named ('frozen or EMA'), plus compute_jepa_loss to fold the term into a training step. Additive and optional; does not touch AutoE2E.forward or the training loop. Frequency/horizons/weighting/data-pipeline remain open design decisions (Zain, 17-06 action item). Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
This was referenced Jun 23, 2026
This was referenced Jun 29, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Implements the loss side of #13 — FutureState feature-reconstruction (JEPA) — as Priority #1 per the maintainer feedback in #56 (@riita10069).
Two additive changes, both optional/non-breaking:
losses/feature_reconstruction_loss.py—FeatureReconstructionLoss: JEPA-style MSE between FutureState's 4 predicted future feature maps and frozen-backbone targets at +1.6/3.2/4.8/6.4 s, with optional temporal weighting and a zero-sum-weights guard.future_state.py— optional action conditioning (action_dim=None⇒ byte-identical to current). Enables counterfactual "what-if" rollouts by projecting the trajectory into a per-channel bias.AutoE2E.forward()contract unchanged. 16 tests added, 157 total green.Files changed
Model/model_components/losses/feature_reconstruction_loss.py(new)Model/model_components/losses/__init__.py(export)Model/model_components/future_state.py(optional path, default unchanged)Model/tests/test_feature_reconstruction_loss.py(16 tests)Relates to #13. See tracking in #56.