
Proposal: Feature Reconstruction Loss for FutureState (Auxiliary Self-Supervised Task) #13

Description

@riita10069

Summary

The FutureState module currently outputs predicted future visual features ([B, 1440, 7, 7] × 4) during the forward pass, but these predictions are never used — there is no loss function that compares them against actual future observations. This issue proposes implementing the Feature Reconstruction Loss to activate FutureState as an auxiliary training signal and to enable introspection of model uncertainty.

Current State

  • FutureState.forward() predicts features at t+1.6s, t+3.2s, t+4.8s, t+6.4s
  • FrozenBackbone exists in the codebase but is unused
  • The predicted future features are returned from AutoE2E.forward() but discarded

Proposed: Feature Reconstruction Loss

How it works

During training with sequential frame data, we compare what the model predicted the future would look like against what the future actually looked like:

Frame t (current):
  image_t → Backbone → Fusion → FutureState → predicted_features[t+1.6s]

Frame t+16 (actual 1.6s later, at 10Hz):
  image_{t+16} → FrozenBackbone → Fusion → target_features[t+1.6s]

feature_reconstruction_loss = MSE(predicted_features, target_features.detach())
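A minimal PyTorch sketch of this loss over all four FutureState horizons. It assumes FutureState's output can be split into four [B, 1440, 7, 7] tensors and that the targets come from frames t+16, t+32, t+48 and t+64. The small `nn.Conv2d` modules are hypothetical stand-ins for Backbone, FrozenBackbone, Fusion and FutureState, and averaging over horizons is an illustrative choice. Running the target branch under `torch.no_grad()` has the same effect as `.detach()`: no gradient reaches Fusion or FrozenBackbone through the targets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

C, H, W = 1440, 7, 7  # feature shape from FutureState: [B, 1440, 7, 7]
HORIZONS = 4          # t+1.6s, t+3.2s, t+4.8s, t+6.4s


def feature_reconstruction_loss(predicted, future_images, frozen_backbone, fusion):
    """Mean over horizons of MSE(predicted_features, target_features).

    predicted:     list of HORIZONS tensors, each [B, C, H, W]
    future_images: list of HORIZONS image batches for frames t+16, t+32, t+48, t+64
    """
    losses = []
    for pred, img in zip(predicted, future_images):
        with torch.no_grad():  # target branch: equivalent to target.detach()
            target = fusion(frozen_backbone(img))
        losses.append(F.mse_loss(pred, target))
    return torch.stack(losses).mean()


# Hypothetical stand-ins for the real modules (only the shapes matter here).
backbone = nn.Conv2d(3, C, kernel_size=1)                               # Backbone
frozen_backbone = nn.Conv2d(3, C, kernel_size=1).requires_grad_(False)  # FrozenBackbone
fusion = nn.Identity()                                                  # Fusion
future_state = nn.Conv2d(C, C * HORIZONS, kernel_size=1)                # FutureState

B = 2
image_t = torch.randn(B, 3, H, W)
future_images = [torch.randn(B, 3, H, W) for _ in range(HORIZONS)]

features_t = fusion(backbone(image_t))
predicted = list(future_state(features_t).chunk(HORIZONS, dim=1))  # 4 x [B, C, H, W]
loss = feature_reconstruction_loss(predicted, future_images, frozen_backbone, fusion)
loss.backward()
```

Only the prediction branch carries gradients, so this loss trains Backbone, Fusion and FutureState while FrozenBackbone stays fixed.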

Why predict in feature space, not pixel space

Predicting future images (pixels) forces the model to reconstruct irrelevant details such as exact textures, shadow positions, and compression artifacts. JEPA (Joint Embedding Predictive Architecture) [1] proposes predicting in a learned latent space instead, so that the model only needs to capture semantically meaningful changes. MILE [2] applies a related idea to autonomous driving: it learns a compact latent world model jointly with a driving policy and reports improved driving performance over prior methods in the CARLA simulator.

Why use a FrozenBackbone for targets

If gradients flow through both the predictor and the target encoder, the network can trivially minimize the loss by mapping every input to the same constant representation (representation collapse). Stopping gradients into the target branch and keeping its encoder frozen or slowly updating makes this shortcut much harder to reach. This asymmetric design follows BYOL [3], and avoiding collapse is also a central concern in JEPA [1].

Two update strategies for the frozen encoder:

  1. Keep at pretrained initialization (simplest)
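  2. Exponential moving average (EMA) of the trainable Backbone, applied after each optimizer.step() (requires identical architectures):
    tau = 0.996
    with torch.no_grad():
        for p_frozen, p_train in zip(frozen.parameters(), trainable.parameters()):
            p_frozen.mul_(tau).add_(p_train, alpha=1 - tau)  # τ·p_frozen + (1 − τ)·p_train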
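Strategy 2 can be packaged as a small function called after every `optimizer.step()`. This is a sketch that assumes FrozenBackbone and Backbone have identical architectures, so their parameters line up under `zip`. `ema_update` is a hypothetical helper name, and copying buffers (such as BatchNorm running statistics) instead of averaging them is one common choice, not the only one. With τ = 0.996 the target roughly averages over the last 1/(1 − τ) = 250 optimizer steps.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(frozen: nn.Module, trainable: nn.Module, tau: float = 0.996) -> None:
    """In place: p_frozen = tau * p_frozen + (1 - tau) * p_train."""
    for p_frozen, p_train in zip(frozen.parameters(), trainable.parameters()):
        p_frozen.mul_(tau).add_(p_train, alpha=1.0 - tau)
    # Buffers (e.g. BatchNorm running stats) are copied, not averaged.
    for b_frozen, b_train in zip(frozen.buffers(), trainable.buffers()):
        b_frozen.copy_(b_train)


backbone = nn.Linear(4, 4)                              # stand-in for Backbone
frozen = copy.deepcopy(backbone).requires_grad_(False)  # stand-in for FrozenBackbone
initial = frozen.weight.clone()

with torch.no_grad():
    backbone.weight.add_(1.0)  # pretend an optimizer step moved every weight by +1

ema_update(frozen, backbone, tau=0.996)
# Each frozen weight moved (1 - tau) * 1.0 = 0.004 toward the trainable weight.
```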

How this helps DrivingPolicy (auxiliary task)

Training with trajectory loss alone gives the Backbone and View Fusion modules only a weak, indirect gradient signal: the loss is computed far downstream and must propagate through many layers. Feature reconstruction loss adds a more direct training signal to these early modules: "extract features that contain enough information to predict how the scene will change." This additional signal is expected to improve the quality of the shared representation and, in turn, trajectory prediction.

UniAD [4] uses a similar multi-task design where perception auxiliary losses (detection, tracking, map segmentation) improve the shared BEV representation quality, leading to better final planning performance.

Integration with total loss

total_loss = trajectory_loss + λ * feature_reconstruction_loss

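The weight λ balances the two objectives. A value between 0.5 and 1.0 is a reasonable starting point, but the right setting depends on the relative scales of the two losses (an MSE over [B, 1440, 7, 7] features is not on the same scale as a trajectory error) and should be tuned. If λ is too large, the model focuses on future prediction at the expense of trajectory accuracy; if it is too small, the auxiliary signal is negligible.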
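Putting the pieces together, one training step might look like the sketch below: the two losses are combined with λ, a single backward pass updates the trainable modules, and the optional EMA update runs after `optimizer.step()`. Everything here is hypothetical: `LossConfig` and its fields stand in for the proposed config-level parameters, small `nn.Linear` layers stand in for Backbone, FrozenBackbone, FutureState and the trajectory head, and `F.l1_loss` is only a placeholder for the trajectory loss, which is out of scope for this issue.

```python
import copy
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class LossConfig:
    feature_loss_weight: float = 0.5  # λ
    ema_tau: Optional[float] = 0.996  # None: keep FrozenBackbone at its pretrained weights


def train_step(backbone, frozen, future_state, trajectory_head, optimizer, batch, cfg):
    image_t, image_future, trajectory_gt = batch
    features = backbone(image_t)

    # Placeholder trajectory loss (the real one is tracked in a separate issue).
    trajectory_loss = F.l1_loss(trajectory_head(features), trajectory_gt)

    with torch.no_grad():  # targets from the frozen encoder carry no gradient
        target = frozen(image_future)
    feature_loss = F.mse_loss(future_state(features), target)

    total_loss = trajectory_loss + cfg.feature_loss_weight * feature_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if cfg.ema_tau is not None:  # the EMA update must follow optimizer.step()
        with torch.no_grad():
            for p_f, p_t in zip(frozen.parameters(), backbone.parameters()):
                p_f.mul_(cfg.ema_tau).add_(p_t, alpha=1.0 - cfg.ema_tau)

    return {
        "trajectory": trajectory_loss.item(),
        "feature": feature_loss.item(),
        "total": total_loss.item(),
    }


torch.manual_seed(0)
backbone = nn.Linear(8, 16)
frozen = copy.deepcopy(backbone).requires_grad_(False)
future_state = nn.Linear(16, 16)
trajectory_head = nn.Linear(16, 2)
params = [*backbone.parameters(), *future_state.parameters(), *trajectory_head.parameters()]
optimizer = torch.optim.SGD(params, lr=0.1)

cfg = LossConfig()
frozen_before = frozen.weight.clone()
batch = (torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 2))
stats = train_step(backbone, frozen, future_state, trajectory_head, optimizer, batch, cfg)
```

Logging the two components separately, not just the total, makes it easier to see whether λ lets one objective dominate the other.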

Proposed: Introspection via Prediction Error

Uncertainty estimation at inference time

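After training, the same error can be computed at inference time as a measure of model uncertainty. Because the target is a future frame, the error for a prediction made at frame t only becomes available once frame t+16 arrives, i.e. with a 1.6 s delay (longer for the later horizons):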

prediction_error = MSE(predicted_features_t_plus_1_6, actual_features_t_plus_1_6)
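Because the target frame arrives 1.6 s after the prediction, an inference-time monitor has to buffer predictions until their targets exist. The sketch below assumes frames arrive with increasing indices at 10 Hz and that, for each frame, the caller passes the FutureState prediction for t+1.6s together with the current frame's FrozenBackbone + Fusion features. `DelayedPredictionError` is a hypothetical helper, and predictions whose target frame was dropped are simply discarded.

```python
from collections import deque

import torch
import torch.nn.functional as F

DELAY_FRAMES = 16  # first FutureState horizon: 1.6 s at 10 Hz


class DelayedPredictionError:
    """Scores the prediction made at frame t once frame t + delay is observed."""

    def __init__(self, delay_frames: int = DELAY_FRAMES):
        self.delay = delay_frames
        self.pending = deque()  # (frame_index, predicted features), oldest first

    @torch.no_grad()
    def update(self, frame_index, predicted_future, actual_current):
        """Return (scored_frame_index, mse) when a prediction can be scored, else None."""
        self.pending.append((frame_index, predicted_future))
        # Drop predictions whose target frame was skipped (e.g. a dropped frame).
        while self.pending and self.pending[0][0] + self.delay < frame_index:
            self.pending.popleft()
        if self.pending and self.pending[0][0] + self.delay == frame_index:
            t, pred = self.pending.popleft()
            return t, F.mse_loss(pred, actual_current).item()
        return None


monitor = DelayedPredictionError()
results = []
for t in range(20):
    actual = torch.full((1, 1440, 7, 7), float(t))  # stand-in features for frame t
    predicted = actual + DELAY_FRAMES               # a "perfect" model of the toy dynamics
    results.append(monitor.update(t, predicted, actual))
# Frames 0-15 return None; frame 16 scores the prediction made at frame 0, and so on.
```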

A high prediction error at a given frame indicates the model encountered a situation it did not expect — the scene evolved in a way that differs from its internal world model. This can signal:

  • Unusual driving scenarios (construction zones, accidents)
  • Out-of-distribution situations (weather the model has not seen)
  • Sudden events (pedestrian stepping onto road)

This prediction error can be monitored as a near-real-time safety signal: if the learned world model is consistently wrong, the system should reduce confidence in its trajectory predictions or request human intervention.
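One way to turn the per-frame error into a "consistently wrong" signal is a rolling baseline with a persistence requirement. This is a hypothetical policy, not part of the proposal: `WorldModelAlarm`, the window size, the k-sigma threshold and the patience are illustrative values that would need calibration on held-out driving data.

```python
from collections import deque
from statistics import mean, pstdev


class WorldModelAlarm:
    """Raises an alarm when prediction error stays abnormally high.

    An error counts as 'high' if it exceeds the rolling mean of recent normal
    errors by more than k standard deviations; the alarm fires only after
    `patience` consecutive high errors, so single noisy frames are ignored.
    """

    def __init__(self, window=100, k=3.0, patience=5, min_history=20):
        self.history = deque(maxlen=window)  # recent errors judged normal
        self.k = k
        self.patience = patience
        self.min_history = min_history
        self.consecutive_high = 0

    def update(self, error: float) -> bool:
        high = False
        if len(self.history) >= self.min_history:
            baseline, spread = mean(self.history), pstdev(self.history)
            high = error > baseline + self.k * max(spread, 1e-8)
        self.consecutive_high = self.consecutive_high + 1 if high else 0
        if not high:
            self.history.append(error)  # only normal errors update the baseline
        return self.consecutive_high >= self.patience


alarm = WorldModelAlarm(window=50, k=3.0, patience=3, min_history=10)
errors = [1.0, 1.1, 0.9, 1.0, 1.05, 0.95, 1.0, 1.1, 0.9, 1.0] + [5.0] * 4
flags = [alarm.update(e) for e in errors]
# The first two high errors are tolerated; the alarm fires on the third.
```

High errors are kept out of the baseline so that a sustained anomaly does not quietly become the new normal; the trade-off is that a genuine, lasting change in conditions keeps the alarm raised until the monitor is reset.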

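GAIA-1 [5] shows that a generative world model trained on driving video learns enough about scene dynamics to generate realistic future driving scenes. The paper focuses on generation rather than anomaly detection, so using prediction error as an uncertainty signal is an extension of that idea rather than a result it reports.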

Implementation Scope

| Item | Notes |
|---|---|
| Wire FrozenBackbone output as target for FutureState loss | Requires sequential frame pairs in DataLoader |
| Add feature_reconstruction_loss computation | MSE between predicted and actual features |
| Add λ hyperparameter | Config-level parameter |
| Add EMA update option for FrozenBackbone | Post optimizer.step() hook |
| Add prediction_error logging at inference | For introspection/monitoring |
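The first item needs the DataLoader to yield a current frame together with its future targets. A minimal sketch of the index bookkeeping, assuming 10 Hz clips stored as separate sequences so that targets never cross a clip boundary; `horizon_offsets`, `sample_indices` and `clip_lengths` are hypothetical names.

```python
FRAME_RATE_HZ = 10
HORIZONS_S = (1.6, 3.2, 4.8, 6.4)  # FutureState horizons in seconds


def horizon_offsets(frame_rate_hz=FRAME_RATE_HZ, horizons_s=HORIZONS_S):
    """Convert horizons in seconds to frame offsets (1.6 s -> 16 frames at 10 Hz)."""
    offsets = []
    for h in horizons_s:
        frames = h * frame_rate_hz
        n = round(frames)
        if abs(frames - n) > 1e-6:
            raise ValueError(f"{h} s is not a whole number of frames at {frame_rate_hz} Hz")
        offsets.append(n)
    return tuple(offsets)


def sample_indices(clip_lengths, offsets):
    """Yield (clip_id, t, target_frames) for every t whose furthest target is in the same clip."""
    furthest = max(offsets)
    for clip_id, n_frames in enumerate(clip_lengths):
        for t in range(n_frames - furthest):
            yield clip_id, t, tuple(t + o for o in offsets)


offsets = horizon_offsets()
samples = list(sample_indices([70, 50], offsets))
# Clip 0 (70 frames) gives t = 0..5; clip 1 (50 frames) is too short for t+64.
```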

Out of scope (separate issues)

  • Training loop infrastructure (optimizer, scheduler, DataLoader)
  • Trajectory loss implementation
  • Dataset selection
  • BEV segmentation auxiliary loss

References

  1. LeCun, Y. "A Path Towards Autonomous Machine Intelligence." 2022. https://openreview.net/pdf?id=BZ5a1r-kVsf
  2. Hu, A., et al. "Model-Based Imitation Learning for Urban Driving (MILE)." NeurIPS 2022. https://arxiv.org/abs/2210.07729
  3. Grill, J.B., et al. "Bootstrap Your Own Latent (BYOL)." NeurIPS 2020. https://arxiv.org/abs/2006.07733
  4. Hu, Y., et al. "Planning-oriented Autonomous Driving (UniAD)." CVPR 2023. https://arxiv.org/abs/2212.10156
  5. Hu, A., et al. "GAIA-1: A Generative World Model for Autonomous Driving." 2023. https://arxiv.org/abs/2309.17080
