
feat(world-model): slow World Model branch (JEPA) + AutoE2E wiring (#13)#85

Open
gcordova10 wants to merge 3 commits into
autowarefoundation:mainfrom
gcordova10:feat/t2-jepa-target-encoder

Conversation

@gcordova10

Contributor

Implements your #13 decisions and wires JEPA into train.py.

  • Memory: instead of a fixed channel-mean, FeatureCompressor exposes the forecast space as a tunable knob with three modes plus a spatial stride: occupancy (= your mean→1, cheapest), projected (default; a learned 256→d_slow map that keeps a semantic space at the same memory budget, I-JEPA arXiv:2301.08243), and full.
  • FutureState gets configurable horizons and an optional TDV residual predictor (arXiv:2606.15956, @intisar1020).
  • FeatureReconstructionLoss gets an L1 option (V-JEPA).
  • Training setup: frozen/EMA target encoder, equal weighting, bf16, EMA update per step.
  • Real-data targets need future-frame export in data_processing (still open), so that path is gated; --enable-jepa --smoke-test exercises the full two-loss step.

Additive and opt-in: the defaults leave AutoE2E/FutureState/loss byte-identical. 21 new tests, full suite 292 green, mypy/ruff clean. Also fixes a pre-existing train.py call that has not passed map_input since #55.
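To make the memory trade-off between the three FeatureCompressor modes concrete, here is a rough arithmetic sketch. Only the mode names and the 256 input channels come from the description above. The helper name `forecast_elements`, the 32×32 feature-map size, the `d_slow` value, and the reading of spatial stride as "keep every stride-th location along each axis" are all illustrative assumptions, not the module's actual behaviour.

```python
import math


def forecast_elements(mode: str, channels: int, h: int, w: int,
                      d_slow: int = 32, stride: int = 1) -> int:
    """Elements stored per forecast step for one sample (hypothetical helper)."""
    # Assumption: stride keeps every `stride`-th location along each spatial axis.
    h_out, w_out = math.ceil(h / stride), math.ceil(w / stride)
    kept_channels = {
        "full": channels,      # keep every backbone channel
        "projected": d_slow,   # learned channels -> d_slow projection
        "occupancy": 1,        # channel-mean collapses to a single channel
    }[mode]
    return kept_channels * h_out * w_out


for mode in ("full", "projected", "occupancy"):
    print(mode, forecast_elements(mode, channels=256, h=32, w=32))
```

In this sketch, setting `d_slow=1` gives the projected mode the same element count as occupancy, and a stride of 2 divides every mode's count by four.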

@m-zain-khawaja

Member

Hi @gcordova10 - I have done some refactoring to the code. If you take a look at the new auto_e2e.py file, you will notice that we have a Reactive_E2E model which consumes egomotion history, multi-camera images and a rendered map image to output a trajectory.

I have added a placeholder for a World Action Model, which is the 'slow thinking' component. This model should consume only multi-view images sampled at a lower frequency (e.g. 1 Hz, or tunable); we don't need egomotion_history for this. It should encode the past multi-view images into a compressed feature vector called 'visual_history', which is then fed into the Reactive_E2E model. It should also predict the future multi-view image features, as encoded by the same backbone the World Action Model uses, for as many future samples as there are past samples, and be trained with a feature reconstruction loss on the multi-view features themselves, not on BEV.

Here is the new architecture diagram which describes this:

[Figure: auto_e2e_architecture]

@gcordova10

Contributor Author

Thanks @m-zain-khawaja, this is a great refactor. The explicit split into Reactive_E2E (fast, 10 Hz, BEV) and the World_Action_Model_E2E placeholder (slow, 1 Hz) is exactly the slow/fast separation we'd been converging on, and putting the feature-reconstruction loss on multi-view features rather than BEV is a cleaner choice — it's closer to V-JEPA's image-feature prediction and keeps BEV as the reactive model's concern.

So I'll re-scope this PR onto your refactor. The BEV-side JEPA in #85 was built before the refactor and now conflicts, so rather than salvage that diff I'll rebase onto f484156 and implement the World_Action_Model_E2E placeholder, reusing the pieces I already built and tested:

  • the frozen/EMA JepaTargetEncoder (stop-gradient target encoder),
  • the FeatureReconstructionLoss (with the L1 option, V-JEPA-style),
  • the FeatureCompressor for the compressed visual_history vector.
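As a minimal illustration of the frozen/EMA target-encoder idea in the list above, the sketch below applies the standard exponential-moving-average rule to a toy parameter dictionary. The function name `ema_update`, the momentum values, and the use of plain floats in place of tensors are assumptions made for illustration; this is not the JepaTargetEncoder code.

```python
def ema_update(target: dict[str, float], online: dict[str, float],
               momentum: float = 0.996) -> None:
    """In-place EMA: target <- m * target + (1 - m) * online."""
    for name, value in online.items():
        target[name] = momentum * target[name] + (1.0 - momentum) * value


online = {"w": 1.0, "b": -2.0}
target = dict(online)      # the target starts as a copy of the online encoder

online["w"] = 2.0          # pretend an optimizer step changed the online weights
ema_update(target, online, momentum=0.9)   # EMA mode: the target drifts slowly
# A frozen target would simply skip this update and stay at its initial copy.
print(target)
```

No gradient ever flows into `target` in either mode; it only changes through this explicit update, which is what makes it a stop-gradient target.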

The BEV FutureState predictor drops out of this PR — it belongs to the reactive side.

Concretely the World Action Model would: take the 1 Hz multi-view stream → its own image backbone → encode the past views into the compressed visual_history (fed into Reactive_E2E) → predict the next N future multi-view features → reconstruction loss against the same backbone (frozen/EMA) applied to the actual future views.

A few interface questions so I match your diagram exactly:

  1. Backbone — does the World Action Model get its own image backbone (separate from the reactive one), and the target is a frozen/EMA copy of that backbone? (frozen vs EMA — the module already supports both.)
  2. Window / N — how many past frames at 1 Hz, and is N_future == N_past exactly (e.g. 4↔4)? Tunable default?
  3. Feature level — reconstruction on the per-view backbone feature maps [B, V, C, h, w] (which backbone stage?), per-view or aggregated?
  4. visual_history shape — the compressed vector fed into Reactive_E2E: [B, visual_history_dim] (896), pooled across views/time?
  5. Integration — still into train_il (feat(platform): MLOps training platform on EKS — Flyte + MLflow training loop (#14) #78) with equal loss weighting? The future multi-view frames will need exporting in data_processing (same open item as Proposal: Feature Reconstruction Loss for FutureState (Auxiliary Self-Supervised Task) #13).

Do you prefer I rebase #85 into the World Action Model implementation (and repurpose this PR), or open a fresh PR on top of the refactor? Either way I'll reuse the already-tested target-encoder + reconstruction-loss so we're not starting from zero.

On sequencing: I'll rebase the sibling temporal-memory PR (#87) onto the refactor right away since it's independent. For this World Action Model PR, I'd rather lock the 5 interface points above with you first so I'm not building on assumptions, then push the implementation reusing the already-tested target-encoder + reconstruction-loss + compressor.

Note for @intisar1020 (#84): the ego-warp residual is a BEV operation, so it now naturally sits on the reactive/BEV side rather than this multi-view World Action Model — let's keep that thread separate.

gcordova10 added a commit to gcordova10/auto_fsd that referenced this pull request Jun 23, 2026
…r + L1 loss) (autowarefoundation#13)

Re-scoped from the pre-refactor BEV-JEPA PR onto the new architecture (autowarefoundation#85).
After @m-zain-khawaja's refactor (Reactive_E2E + World_Action_Model_E2E
placeholder), the feature-reconstruction objective moves to multi-view features
in the slow model, so the BEV FutureState/train_il wiring is dropped here and
this PR keeps only the refactor-independent, reusable pieces the World Action
Model will consume:
- JepaTargetEncoder: frozen/EMA stop-gradient target encoder (+ compute_jepa_loss).
- FeatureCompressor: tunable forecast-space compression (full/projected/occupancy
  + spatial stride) — a learned low-rank map keeps a semantic space at the
  channel-mean's memory budget (I-JEPA arXiv:2301.08243).
- FeatureReconstructionLoss: optional L1 distance (V-JEPA arXiv:2404.08471).
Additive, fully tested (35 tests), mypy/ruff clean. Default behaviour unchanged.
Note: full-suite CI is red on main itself due to autowarefoundation#88 (refactor broke shared test
infra) — unrelated to these modules.

Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
@gcordova10 gcordova10 force-pushed the feat/t2-jepa-target-encoder branch 2 times, most recently from f486cd5 to 1f9103d Compare June 24, 2026 12:16
…utowarefoundation#13)

Consolidated 'World Model' contribution: the reusable JEPA building blocks plus
the WorldActionModel that wires them, implementing the slow branch agreed in the
24/06 WG meeting (Zain's answers to the 5 interface questions + the miro
diagram).

Building blocks:
- JepaTargetEncoder: frozen/EMA stop-gradient target encoder (+ compute_jepa_loss).
- FeatureCompressor: tunable forecast-space compression (full/projected/occupancy).
- FeatureReconstructionLoss: optional L1 distance (V-JEPA), default L2 unchanged.

WorldActionModel (world_action_model.py):
- FrameEncoder: shared backbone -> GAP -> per-frame embedding (224).
- RollingHistoryBuffer (FIFO, len 4) -> Encoded Visual History (4*224=896 =
  visual_history_dim) for the reactive planner.
- future_predictor -> N future embeddings; JEPA loss vs a FROZEN copy of the
  encoder (stop-grad), L1.
- forward: returns visual_history at inference; (visual_history, predicted, loss)
  in training -> lets AutoE2E return trajectory / (trajectory, future, ego_hidden).

Agreed defaults (autowarefoundation#1-5): own backbone + frozen target; N_past=N_future=4 @1Hz;
per-frame embed 224; equal loss weight; L1. Additive, opt-in.
44 tests; mypy/ruff clean. Remaining for merge: wire into AutoE2E + the 1Hz/10Hz
dataloader + the CI fix (autowarefoundation#88).

Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
…wiring (autowarefoundation#93/autowarefoundation#13)

Per Zain's commented pseudocode in auto_e2e.py (lines 67-76):
  visual_embedding, future_state_pred = self.World_Action_Model_E2E(frame, visual_history)
  visual_history.push(visual_embedding)   # circular buffer of size N (N=4)
  return trajectory  /  (trajectory, future_state_pred)  [infer / train]

- forward is now per-tick: returns (visual_embedding, future_state_pred); the
  rolling FIFO buffer is EXTERNAL (RollingHistoryBuffer.push) and forms the
  Encoded Visual History fed to the reactive planner.
- the JEPA loss is computed separately via WorldActionModel.jepa_loss (kept out
  of the model, for train_il) — matches the 'separate loss modules' action item.
- predict_future / encode_history kept as helpers.
10 tests (incl. the online buffer->reactive loop); mypy/ruff clean.

Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
…e.py lines 67-76)

Realizes the wiring @m-zain-khawaja sketched as comments in auto_e2e.py:
  visual_embedding, future_state_pred = self.World_Action_Model_E2E(frame, visual_history)
  visual_history.push(visual_embedding)   # circular buffer of size N (N=4)
  return trajectory  /  (trajectory, future_state_pred)   (infer / train)

- AutoE2E gains opt-in enable_world_model (default OFF -> reactive-only behaviour
  byte-identical). When on, it builds a WorldActionModel reusing the reactive
  backbone (one shared backbone; frozen JEPA target is a copy) + a
  RollingHistoryBuffer; forward encodes the multi-cam frame, pushes to the
  buffer, feeds the Encoded Visual History to Reactive_E2E, and (training) also
  returns future_state_pred. reset_visual_history() clears the buffer per sequence.
- FrameEncoder now handles multi-view [B,V,3,H,W] (pool over views).
14 tests incl. the AutoE2E wiring (infer/train contract, WM-off default); mypy/ruff clean.
Remaining for T2: the 1Hz/10Hz dataloader (autowarefoundation#16) for end-to-end training, and autowarefoundation#88 (CI).

Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
@gcordova10 gcordova10 force-pushed the feat/t2-jepa-target-encoder branch from 1f9103d to b9533b0 Compare June 25, 2026 11:44
@gcordova10 gcordova10 changed the title from "feat(jepa): integrate the feature-reconstruction objective into training (#13)" to "feat(world-model): slow World Model branch (JEPA) + AutoE2E wiring (#13)" Jun 25, 2026
@gcordova10

Contributor Author

@m-zain-khawaja — I've implemented the World_Action_Model_E2E placeholder and wired it into AutoE2E, following your answers in the 24/06 meeting and the commented wiring you left in auto_e2e.py. I force-pushed this branch with the full implementation (it now supersedes the old BEV-side diff). Here's exactly what changed and the one design point I'd like to confirm with you.

What's in the PR now

A new world_action_model.py (WorldActionModel, FrameEncoder, RollingHistoryBuffer) plus the reusable JEPA pieces I'd already built (JepaTargetEncoder, FeatureCompressor, L1 option in FeatureReconstructionLoss), and the wiring in auto_e2e.py. It's opt-in (enable_world_model, default off → Reactive_E2E behaviour is byte-identical), reuses the reactive backbone (one shared backbone; the JEPA target is a frozen copy of it), and the JEPA loss is kept out of the model (computed in the training loop via WorldActionModel.jepa_loss).
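Since the JEPA loss lives in the training loop rather than in the model, a rough sketch of how the two losses might be combined is shown below. The function names, the flat Python lists standing in for feature tensors, and the reading of "l2" as mean squared error are assumptions for illustration; the real FeatureReconstructionLoss and WorldActionModel.jepa_loss signatures may differ.

```python
def feature_reconstruction_loss(pred: list[float], target: list[float],
                                loss_type: str = "l2") -> float:
    """Mean elementwise distance between predicted and target features."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    if loss_type == "l1":
        return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)
    if loss_type == "l2":  # assumption: treated here as mean squared error
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    raise ValueError(f"unknown loss_type: {loss_type!r}")


def total_loss(trajectory_loss: float, jepa_loss: float,
               jepa_weight: float = 1.0) -> float:
    """Equal weighting corresponds to jepa_weight = 1.0."""
    return trajectory_loss + jepa_weight * jepa_loss


pred = [0.5, 1.0, -1.0, 2.0]
target = [0.0, 1.0, 1.0, 2.0]   # stands in for stop-gradient target features
jepa = feature_reconstruction_loss(pred, target, loss_type="l1")
print(jepa, total_loss(0.25, jepa))
```

Keeping the loss outside the model means the forward pass only has to return predictions, and the training script decides which distance and weight to apply.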

Mapping to your 24/06 answers

| Interface (my Q#) | Your answer (24/06) | Implemented |
| --- | --- | --- |
| Backbone | one shared backbone, online trainable + frozen target copy | FrameEncoder reuses the reactive backbone; JepaTargetEncoder(mode="frozen") is a frozen copy (stop-gradient, anti-collapse) |
| Window / N | N_past = N_future = 4 @ 1 Hz | history_len=4, num_future_steps=4 |
| visual_history | rolling FIFO buffer → fed to Reactive_E2E | RollingHistoryBuffer (len 4), left-padded, FIFO order |
| Loss | feature-reconstruction, L1, equal weighting, in train_il | FeatureReconstructionLoss(loss_type="l1"), weight=1.0 |
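The rolling-buffer row above can be pictured with the small sketch below, which shows how four 224-d per-frame embeddings become an 896-d visual history. The class name `HistoryBufferSketch`, the zero-valued left padding, and the use of Python lists instead of tensors are assumptions for illustration and are not taken from the actual RollingHistoryBuffer.

```python
from collections import deque


class HistoryBufferSketch:
    """FIFO of the last `history_len` per-frame embeddings (illustrative only)."""

    def __init__(self, history_len: int = 4, embed_dim: int = 224) -> None:
        self.history_len = history_len
        self.embed_dim = embed_dim
        self._frames: deque = deque(maxlen=history_len)
        self.reset()

    def reset(self) -> None:
        # Assumption: left-pad with zero embeddings so the output size is fixed
        # from the first tick of a sequence.
        self._frames.clear()
        for _ in range(self.history_len):
            self._frames.append([0.0] * self.embed_dim)

    def push(self, embedding: list[float]) -> None:
        if len(embedding) != self.embed_dim:
            raise ValueError("embedding has the wrong dimension")
        self._frames.append(list(embedding))  # the oldest frame drops off the left

    def visual_history(self) -> list[float]:
        # Concatenate oldest -> newest into one flat vector.
        return [x for frame in self._frames for x in frame]


buffer = HistoryBufferSketch()
buffer.push([1.0] * 224)
print(len(buffer.visual_history()))  # 4 * 224
```

After a single push, the first three slots are still padding and only the last 224 values carry the new embedding.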

Wiring (your auto_e2e.py, realized)

visual_embedding, future_state_pred = self.World_Action_Model_E2E(camera_tiles, visual_history)
self.visual_history_buffer.push(visual_embedding)   # circular buffer, size N=4
visual_history = self.visual_history_buffer.visual_history()
trajectory = self.Reactive_E2E(camera_tiles, map_input, visual_history, egomotion_history, ...)
return trajectory                       # inference
return trajectory, future_state_pred    # training

One point I'd like to confirm — future-prediction granularity

Here's where the docs diverge and I want your call before I lock it. The miro diagram (24/06) shows a per-frame 224-d embedding and an Encoded Visual History of 896 (4×224), so I currently predict the future as pooled 224-d embeddings. But your earlier comment on this PR ("reconstruction on the multi-view features themselves, not BEV") and TRIAL.md ("Future Visual Features Prediction: torch.Size([8, 1440, 7, 7]) × 4") describe the JEPA target as the full backbone feature maps [B, 1440, 7, 7], not a pooled vector.

So: should the future-feature reconstruction be on

  • (a) feature maps [B, 1440(=backbone_channels), 7, 7] × 4 (per TRIAL.md / your comment), or
  • (b) pooled per-frame embeddings 224 (per the miro)?

FeatureReconstructionLoss and FeatureCompressor already support feature-map targets, so switching to (a) is a small, contained change — I just want to match whichever you intend. (Relatedly, TRIAL.md lists the compressed visual_history item as [14], vs the miro's 896 — happy to align that too.)
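For a sense of scale between the two options, the arithmetic below counts the target elements per sample, using only the shapes quoted above (1440×7×7 feature maps, 224-d pooled embeddings, 4 future steps). The variable names are illustrative.

```python
from math import prod

NUM_FUTURE_STEPS = 4
feature_map_shape = (1440, 7, 7)   # option (a): one backbone feature map per step
pooled_embed_dim = 224             # option (b): one pooled embedding per step

option_a = NUM_FUTURE_STEPS * prod(feature_map_shape)
option_b = NUM_FUTURE_STEPS * pooled_embed_dim

print(option_a, option_b, option_a // option_b)
```

Option (a) therefore asks the predictor to regress 282,240 values per sample against 896 for option (b), a factor of 315.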

Coordination with #93 (@intisar1020)

@intisar1020 proposed the same World Action Model in #93 — flagging here so we pool effort rather than duplicate; happy to co-author (their temporal-attention history head + the backbone-freezing choice are clean drop-ins behind this interface).

Status

World-model + building-block tests pass locally; mypy/ruff clean. CI is red only because of #88 (the test infra from the refactor), not this code.

What remains after the feature-level decision is outside this PR: the left-most block of the architecture diagram, i.e. the Data Loader (MP4 Video at 30 FPS) that subsamples into the "Video sampled at 1 Hz (World Model + Reasoning)" and "Video sampled at 10 Hz (Reactive)" streams. That's #16 (Dataset Selection and DataLoader Implementation); in particular, it needs to export the future 1 Hz multi-view frames so the JEPA target encoder has something to compare against for end-to-end training.
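The subsampling the data loader has to perform can be sketched as plain index arithmetic. The helper names, the 10-second clip length, and the choice to let the past window include the current 1 Hz frame are assumptions for illustration; the actual #16 data loader may slice differently.

```python
def subsample_indices(num_frames: int, source_fps: int, target_hz: int) -> list[int]:
    """Indices of source frames kept when downsampling to `target_hz`."""
    if source_fps % target_hz != 0:
        raise ValueError("target_hz must divide source_fps in this sketch")
    step = source_fps // target_hz
    return list(range(0, num_frames, step))


def past_future_split(slow_indices: list[int], anchor: int,
                      n_past: int = 4, n_future: int = 4) -> tuple[list[int], list[int]]:
    """Split the 1 Hz stream around position `anchor` (current frame counts as past)."""
    if anchor - n_past + 1 < 0 or anchor + n_future >= len(slow_indices):
        raise IndexError("not enough frames around the anchor")
    past = slow_indices[anchor - n_past + 1: anchor + 1]
    future = slow_indices[anchor + 1: anchor + 1 + n_future]
    return past, future


num_frames = 300                                   # 10 s of 30 FPS video
slow = subsample_indices(num_frames, 30, 1)        # stream for the world model
fast = subsample_indices(num_frames, 30, 10)       # stream for the reactive model
past, future = past_future_split(slow, anchor=4)
print(past, future)
```

The `future` indices are the frames that would need to be exported so the target encoder has something to encode for the reconstruction loss.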

2 participants