Summary
The FutureState module currently outputs predicted future visual features ([B, 1440, 7, 7] × 4) during the forward pass, but these predictions are never used — there is no loss function that compares them against actual future observations. This issue proposes implementing the Feature Reconstruction Loss to activate FutureState as an auxiliary training signal and to enable introspection of model uncertainty.
Current State
- FutureState.forward() predicts features at t+1.6s, t+3.2s, t+4.8s, t+6.4s
- FrozenBackbone exists in the codebase but is unused
- The predicted future features are returned from AutoE2E.forward() but discarded
Proposed: Feature Reconstruction Loss
How it works
During training with sequential frame data, we compare what the model predicted the future would look like against what the future actually looked like:
Frame t (current):
    image_t → Backbone → Fusion → FutureState → predicted_features[t+1.6s]
Frame t+16 (actual 1.6s later, at 10Hz):
    image_{t+16} → FrozenBackbone → Fusion → target_features[t+1.6s]
feature_reconstruction_loss = MSE(predicted_features, target_features.detach())
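The comparison above could be sketched in PyTorch roughly as follows. This is a minimal sketch, not the project's implementation: the function signature, the use of plain lists of tensors, and averaging the loss equally over all four horizons are assumptions (the diagram only shows the t+1.6s case), and random tensors stand in for the real module outputs.

```python
import torch
import torch.nn.functional as F


def feature_reconstruction_loss(predicted, targets):
    """Mean MSE over horizons between predicted and target feature maps.

    predicted: list of tensors, one per horizon, each [B, C, H, W]
    targets:   list of tensors with matching shapes (from the frozen path)
    """
    if len(predicted) != len(targets):
        raise ValueError("predicted and targets must cover the same horizons")
    losses = [
        # detach() keeps gradients from flowing into the target branch
        F.mse_loss(pred, tgt.detach())
        for pred, tgt in zip(predicted, targets)
    ]
    # Assumption: all horizons are weighted equally
    return torch.stack(losses).mean()


# Toy tensors with the shapes quoted in the issue: 4 horizons of [B, 1440, 7, 7]
B = 2
predicted = [torch.randn(B, 1440, 7, 7, requires_grad=True) for _ in range(4)]
targets = [torch.randn(B, 1440, 7, 7) for _ in range(4)]
loss = feature_reconstruction_loss(predicted, targets)
loss.backward()  # gradients reach the predictions only
```

Per-horizon weights (for example, down-weighting t+6.4s) would be a natural extension, but nothing in the issue specifies them.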
Why predict in feature space, not pixel space
Predicting future images (pixels) forces the model to reconstruct irrelevant details like exact texture, shadow positions, and compression artifacts. JEPA (Joint Embedding Predictive Architecture) [1] proposes predicting in a learned latent space instead, where only semantically meaningful changes are captured. MILE [2] demonstrated this principle specifically for autonomous driving — their latent world model outperformed pixel-prediction baselines.
Why use a FrozenBackbone for targets
If both the predictor and the target encoder are trainable, the network can trivially minimize loss by collapsing all representations to a constant (representation collapse). Using a frozen or slowly-updating target encoder prevents this. This asymmetric design is established in BYOL [3] and JEPA [1].
Two update strategies for the frozen encoder:
- Keep at pretrained initialization (simplest)
- Exponential Moving Average from trainable Backbone:
for p_frozen, p_train in zip(frozen.parameters(), trainable.parameters()):
    p_frozen.data = τ * p_frozen.data + (1 - τ) * p_train.data  # τ ≈ 0.996
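A runnable variant of the EMA option might look like the sketch below. It assumes the frozen and trainable encoders have identical architectures so their parameters line up under `zip`; `nn.Linear` is only a stand-in for the real Backbone and FrozenBackbone classes, and the function name `ema_update` is hypothetical.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(frozen: nn.Module, trainable: nn.Module, tau: float = 0.996) -> None:
    """In place: frozen = tau * frozen + (1 - tau) * trainable."""
    for p_frozen, p_train in zip(frozen.parameters(), trainable.parameters()):
        p_frozen.mul_(tau).add_(p_train, alpha=1.0 - tau)


# Stand-in modules; the target starts as a copy of the trainable encoder
trainable = nn.Linear(4, 4)
frozen = copy.deepcopy(trainable)
for p in frozen.parameters():
    p.requires_grad_(False)  # the optimizer never updates the target encoder

# In a training loop this would run once after each optimizer.step()
ema_update(frozen, trainable, tau=0.996)
```

Only parameters are averaged here. Buffers such as BatchNorm running statistics are not touched, so a backbone that relies on them would need those handled separately.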
How this helps DrivingPolicy (auxiliary task)
Training only with trajectory loss provides a weak gradient signal to the Backbone and View Fusion modules — the loss is far downstream and must propagate through many layers. Feature reconstruction loss provides a direct training signal to these early modules: "extract features that contain enough information to predict how the scene will change." This richer gradient signal improves the quality of the shared representation, which in turn benefits trajectory prediction.
UniAD [4] uses a similar multi-task design where perception auxiliary losses (detection, tracking, map segmentation) improve the shared BEV representation quality, leading to better final planning performance.
Integration with total loss
total_loss = trajectory_loss + λ * feature_reconstruction_loss
The weight λ balances the two objectives. Typical range: 0.5 – 1.0. If λ is too large, the model focuses on future prediction at the expense of trajectory accuracy. If too small, the auxiliary signal is negligible.
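As a sketch of how λ might be exposed as a config-level parameter, the snippet below uses a small dataclass. The names `LossConfig` and `lambda_recon` are hypothetical, and the default of 0.5 is simply the lower end of the range quoted above, not a tuned value.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossConfig:
    # Hypothetical field name; the issue only says λ is a config-level parameter
    lambda_recon: float = 0.5

    def __post_init__(self):
        if self.lambda_recon < 0:
            raise ValueError("lambda_recon must be non-negative")


def total_loss(trajectory_loss, feature_reconstruction_loss, cfg: LossConfig):
    """total = trajectory + λ * reconstruction; works on floats or tensors."""
    return trajectory_loss + cfg.lambda_recon * feature_reconstruction_loss


cfg = LossConfig(lambda_recon=0.5)
print(total_loss(2.0, 1.0, cfg))  # 2.5
```

Setting `lambda_recon=0.0` recovers trajectory-only training, which gives a convenient baseline for ablations.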
Proposed: Introspection via Prediction Error
Uncertainty estimation at inference time
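After training, the same error can be computed at inference time, whenever sequential frames are available, and used as a proxy for model uncertainty. The target for a t+1.6s prediction only exists once that frame has been observed, so each error value becomes available 1.6 s after the prediction was made: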
prediction_error = MSE(predicted_features_t_plus_1_6, actual_features_t_plus_1_6)
A high prediction error at a given frame indicates the model encountered a situation it did not expect — the scene evolved in a way that differs from its internal world model. This can signal:
- Unusual driving scenarios (construction zones, accidents)
- Out-of-distribution situations (weather the model has not seen)
- Sudden events (pedestrian stepping onto road)
This prediction error can be monitored as a real-time safety signal: if the model's world model is consistently wrong, the system should reduce confidence in its trajectory predictions or request human intervention.
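One way to turn "consistently wrong" into a concrete check is a sliding-window mean with a threshold, sketched below. The class name, window size, and threshold are all illustrative assumptions; real values would need calibration against errors observed on held-out driving data.

```python
from collections import deque


class PredictionErrorMonitor:
    """Flags sustained high prediction error over a sliding window."""

    def __init__(self, window: int = 10, threshold: float = 1.0):
        # window and threshold are illustrative; real values need calibration
        self.errors = deque(maxlen=window)
        self.threshold = threshold

    def update(self, prediction_error: float) -> bool:
        """Record one error; return True if the windowed mean exceeds the threshold."""
        self.errors.append(prediction_error)
        window_full = len(self.errors) == self.errors.maxlen
        mean_error = sum(self.errors) / len(self.errors)
        return window_full and mean_error > self.threshold


monitor = PredictionErrorMonitor(window=3, threshold=1.0)
flags = [monitor.update(e) for e in [0.2, 0.3, 0.2, 2.0, 2.8, 3.0]]
print(flags)  # [False, False, False, False, True, True]
```

Averaging over a window means a single noisy frame does not trigger the flag, at the cost of reacting a few frames later to a genuine change.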
GAIA-1 [5] shows that generative world models trained on driving data learn internal representations of scene dynamics. Using the prediction error of such a world model as an anomaly signal, as proposed here, builds on that observation.
Implementation Scope
| Item | Notes |
|---|---|
| Wire FrozenBackbone output as target for FutureState loss | Requires sequential frame pairs in DataLoader |
| Add feature_reconstruction_loss computation | MSE between predicted and actual features |
| Add λ hyperparameter | Config-level parameter |
| Add EMA update option for FrozenBackbone | Post optimizer.step() hook |
| Add prediction_error logging at inference | For introspection/monitoring |
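The first item depends on pairing each frame with its future targets. A minimal sketch of the index arithmetic is shown below, assuming a contiguous clip recorded at a fixed frame rate; the helper names are hypothetical, and the DataLoader itself remains out of scope.

```python
def future_frame_offsets(horizons_s, frame_rate_hz):
    """Convert prediction horizons in seconds to frame offsets."""
    return [round(h * frame_rate_hz) for h in horizons_s]


def valid_pairs(num_frames, offsets):
    """Yield (t, [t + k for each offset]) for every t whose targets all exist."""
    last_start = num_frames - max(offsets)
    for t in range(max(last_start, 0)):
        yield t, [t + k for k in offsets]


offsets = future_frame_offsets([1.6, 3.2, 4.8, 6.4], frame_rate_hz=10)
print(offsets)  # [16, 32, 48, 64]
pairs = list(valid_pairs(num_frames=100, offsets=offsets))
print(len(pairs), pairs[0])  # 36 (0, [16, 32, 48, 64])
```

With a 100-frame clip at 10Hz, only the first 36 frames have all four targets available, so short clips contribute few training samples for the longest horizon.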
Out of scope (separate issues)
- Training loop infrastructure (optimizer, scheduler, DataLoader)
- Trajectory loss implementation
- Dataset selection
- BEV segmentation auxiliary loss
References
[1] LeCun, Y. "A Path Towards Autonomous Machine Intelligence." 2022. https://openreview.net/pdf?id=BZ5a1r-kVsf
[2] Hu, A., et al. "Model-Based Imitation Learning for Urban Driving (MILE)." NeurIPS 2022. https://arxiv.org/abs/2210.07729
[3] Grill, J.-B., et al. "Bootstrap Your Own Latent (BYOL)." NeurIPS 2020. https://arxiv.org/abs/2006.07733
[4] Hu, Y., et al. "Planning-oriented Autonomous Driving (UniAD)." CVPR 2023. https://arxiv.org/abs/2212.10156
[5] Hu, A., et al. "GAIA-1: A Generative World Model for Autonomous Driving." 2023. https://arxiv.org/abs/2309.17080