feat(world-model): slow World Model branch (JEPA) + AutoE2E wiring (#13) #85
gcordova10 wants to merge 3 commits into
Hi @gcordova10 - I have done some refactoring of the code. If you take a look at the new auto_e2e.py file, you will see a Reactive_E2E model that consumes egomotion history, multi-camera images and a rendered map image, and outputs a trajectory. I have added a placeholder for a World Action Model, which is the 'slow thinking' component. This model should consume only multi-view images sampled at a lower frequency (e.g. 1 Hz, or tunable); it does not need egomotion_history. It should encode the past multi-view images into a compressed feature vector called 'visual_history', which is then fed into the Reactive_E2E model. The World Action Model should also predict the future multi-view image features, as encoded by the same backbone used in the World Action Model, for the same number of future samples as we have past samples, and it should have a feature reconstruction loss computed on the multi-view features themselves, not on BEV. Here is the new architecture diagram which describes this:
Thanks @m-zain-khawaja, this is a great refactor. The explicit split into [...] So I'll re-scope this PR onto your refactor. The BEV-side JEPA in #85 was built before the refactor and now conflicts, so rather than salvage that diff I'll rebase onto [...]
The BEV [...] Concretely, the World Action Model would: take the 1 Hz multi-view stream → its own image backbone → encode the past views into the compressed [...] A few interface questions so I match your diagram exactly:
Would you prefer that I rebase #85 into the World Action Model implementation (repurposing this PR), or open a fresh PR on top of the refactor? Either way I'll reuse the already-tested target encoder and reconstruction loss so we're not starting from zero. On sequencing: I'll rebase the sibling temporal-memory PR (#87) onto the refactor right away, since it's independent. For this World Action Model PR, I'd rather lock the 5 interface points above with you first so I'm not building on assumptions, then push the implementation reusing the already-tested target encoder, reconstruction loss and compressor.
…r + L1 loss) (autowarefoundation#13)

Re-scoped from the pre-refactor BEV-JEPA PR onto the new architecture (autowarefoundation#85). After @m-zain-khawaja's refactor (Reactive_E2E + World_Action_Model_E2E placeholder), the feature-reconstruction objective moves to multi-view features in the slow model, so the BEV FutureState/train_il wiring is dropped here and this PR keeps only the refactor-independent, reusable pieces the World Action Model will consume:

- JepaTargetEncoder: frozen/EMA stop-gradient target encoder (+ compute_jepa_loss).
- FeatureCompressor: tunable forecast-space compression (full/projected/occupancy + spatial stride); a learned low-rank map keeps a semantic space at the channel-mean's memory budget (I-JEPA arXiv:2301.08243).
- FeatureReconstructionLoss: optional L1 distance (V-JEPA arXiv:2404.08471).

Additive, fully tested (35 tests), mypy/ruff clean. Default behaviour unchanged.

Note: full-suite CI is red on main itself due to autowarefoundation#88 (refactor broke shared test infra), which is unrelated to these modules.

Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
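For readers unfamiliar with the two training pieces named in this commit, here is a minimal sketch of an EMA target update and an L1 feature-reconstruction loss. It is not the code in JepaTargetEncoder or FeatureReconstructionLoss: the function names, the flat-list representation of weights and features, and the momentum value are assumptions made for illustration. Plain Python has no autograd, so the stop-gradient is only noted in a comment.

```python
from typing import List


def ema_update(target: List[float], online: List[float], momentum: float = 0.99) -> List[float]:
    """Return updated target weights: momentum * target + (1 - momentum) * online.

    momentum = 1.0 leaves the target unchanged, i.e. a fully frozen target encoder.
    The default of 0.99 is a placeholder, not a value taken from the PR.
    """
    if len(target) != len(online):
        raise ValueError("target and online weights must have the same length")
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]


def l1_feature_loss(predicted: List[float], target: List[float]) -> float:
    """Mean absolute difference between predicted and target features.

    In an autograd framework the target features would be computed without
    gradient tracking (the stop-gradient), so only the predictor is trained.
    """
    if len(predicted) != len(target):
        raise ValueError("predicted and target features must have the same length")
    if not predicted:
        raise ValueError("features must be non-empty")
    return sum(abs(p - t) for p, t in zip(predicted, target)) / len(predicted)
```

With `momentum=0.5`, `ema_update([1.0, 1.0], [0.0, 2.0], 0.5)` moves each target weight halfway toward the online weight, and `l1_feature_loss([1.0, 2.0], [0.0, 4.0])` averages the absolute errors 1.0 and 2.0.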
f486cd5 to 1f9103d
…utowarefoundation#13)

Consolidated 'World Model' contribution: the reusable JEPA building blocks plus the WorldActionModel that wires them, implementing the slow branch agreed in the 24/06 WG meeting (Zain's answers to the 5 interface questions + the miro diagram).

Building blocks:
- JepaTargetEncoder: frozen/EMA stop-gradient target encoder (+ compute_jepa_loss).
- FeatureCompressor: tunable forecast-space compression (full/projected/occupancy).
- FeatureReconstructionLoss: optional L1 distance (V-JEPA), default L2 unchanged.

WorldActionModel (world_action_model.py):
- FrameEncoder: shared backbone -> GAP -> per-frame embedding (224).
- RollingHistoryBuffer (FIFO, len 4) -> Encoded Visual History (4*224=896 = visual_history_dim) for the reactive planner.
- future_predictor -> N future embeddings; JEPA loss vs a FROZEN copy of the encoder (stop-grad), L1.
- forward: returns visual_history at inference; (visual_history, predicted, loss) in training -> lets AutoE2E return trajectory / (trajectory, future, ego_hidden).

Agreed defaults (autowarefoundation#1-5): own backbone + frozen target; N_past=N_future=4 @1Hz; per-frame embed 224; equal loss weight; L1.

Additive, opt-in. 44 tests; mypy/ruff clean.

Remaining for merge: wire into AutoE2E + the 1Hz/10Hz dataloader + the CI fix (autowarefoundation#88).

Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
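The FIFO arithmetic above (4 frames of 224 dimensions giving an 896-dimensional Encoded Visual History) can be sketched with a bounded deque. This is an illustration, not the RollingHistoryBuffer in world_action_model.py: the class name, the flat-list embeddings, and the zero-padding of empty slots before the buffer fills are all assumptions.

```python
from collections import deque
from typing import Deque, List

EMBED_DIM = 224    # per-frame embedding size stated in the commit message
HISTORY_LEN = 4    # N_past stated in the commit message


class RollingHistorySketch:
    """Hypothetical FIFO of per-frame embeddings, flattened for the reactive planner."""

    def __init__(self, length: int = HISTORY_LEN, embed_dim: int = EMBED_DIM) -> None:
        self.length = length
        self.embed_dim = embed_dim
        # deque(maxlen=...) silently drops the oldest entry once it is full.
        self._frames: Deque[List[float]] = deque(maxlen=length)

    def push(self, embedding: List[float]) -> None:
        if len(embedding) != self.embed_dim:
            raise ValueError(f"expected {self.embed_dim} values, got {len(embedding)}")
        self._frames.append(list(embedding))

    def reset(self) -> None:
        self._frames.clear()

    def as_vector(self) -> List[float]:
        """Concatenate oldest-to-newest; pad missing (oldest) slots with zeros.

        Zero-padding is an assumption; the PR does not say how a partially
        filled buffer is presented to the planner.
        """
        missing = self.length - len(self._frames)
        flat: List[float] = [0.0] * (missing * self.embed_dim)
        for frame in self._frames:
            flat.extend(frame)
        return flat
```

Pushing a fifth embedding evicts the first, so the flattened vector always has `HISTORY_LEN * EMBED_DIM` entries regardless of how many ticks have elapsed.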
…wiring (autowarefoundation#93/autowarefoundation#13)

Per Zain's commented pseudocode in auto_e2e.py (lines 67-76):

    visual_embedding, future_state_pred = self.World_Action_Model_E2E(frame, visual_history)
    visual_history.push(visual_embedding)  # circular buffer of size N (N=4)
    return trajectory / (trajectory, future_state_pred)  [infer / train]

- forward is now per-tick: returns (visual_embedding, future_state_pred); the rolling FIFO buffer is EXTERNAL (RollingHistoryBuffer.push) and forms the Encoded Visual History fed to the reactive planner.
- the JEPA loss is computed separately via WorldActionModel.jepa_loss (kept out of the model, for train_il); this matches the 'separate loss modules' action item.
- predict_future / encode_history kept as helpers.

10 tests (incl. the online buffer->reactive loop); mypy/ruff clean.

Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
…e.py lines 67-76)

Realizes the wiring @m-zain-khawaja sketched as comments in auto_e2e.py:

    visual_embedding, future_state_pred = self.World_Action_Model_E2E(frame, visual_history)
    visual_history.push(visual_embedding)  # circular buffer of size N (N=4)
    return trajectory / (trajectory, future_state_pred)  (infer / train)

- AutoE2E gains opt-in enable_world_model (default OFF -> reactive-only behaviour byte-identical). When on, it builds a WorldActionModel reusing the reactive backbone (one shared backbone; frozen JEPA target is a copy) + a RollingHistoryBuffer; forward encodes the multi-cam frame, pushes to the buffer, feeds the Encoded Visual History to Reactive_E2E, and (training) also returns future_state_pred. reset_visual_history() clears the buffer per sequence.
- FrameEncoder now handles multi-view [B,V,3,H,W] (pool over views).

14 tests incl. the AutoE2E wiring (infer/train contract, WM-off default); mypy/ruff clean.

Remaining for T2: the 1Hz/10Hz dataloader (autowarefoundation#16) for end-to-end training, and autowarefoundation#88 (CI).

Signed-off-by: GABRIELA CORDOVA <100548769@alumnos.uc3m.es>
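The infer/train contract described in this commit can be sketched with toy stand-ins. Everything here is hypothetical: `AutoE2ESketch`, `toy_reactive` and `toy_world_model` are illustrative names, frames are flat lists rather than multi-view tensors, passing `world_model=None` plays the role of `enable_world_model` being off, and the zero-padded history is an assumption. The real AutoE2E may order or shape these steps differently.

```python
from collections import deque
from typing import Callable, Deque, List, Optional, Tuple, Union

Vector = List[float]
# Hypothetical callables standing in for Reactive_E2E and the World Action Model.
ReactiveFn = Callable[[Vector, Optional[Vector]], Vector]
WorldModelFn = Callable[[Vector, Vector], Tuple[Vector, List[Vector]]]


class AutoE2ESketch:
    def __init__(self, reactive: ReactiveFn, world_model: Optional[WorldModelFn] = None,
                 history_len: int = 4, embed_dim: int = 224) -> None:
        self.reactive = reactive
        self.world_model = world_model  # None == world model disabled (the default)
        self.history_len = history_len
        self.embed_dim = embed_dim
        self._history: Deque[Vector] = deque(maxlen=history_len)

    def reset_visual_history(self) -> None:
        """Clear the buffer at the start of each sequence."""
        self._history.clear()

    def _encoded_history(self) -> Vector:
        # Assumption: empty slots are zero-padded until the buffer fills.
        missing = self.history_len - len(self._history)
        flat: Vector = [0.0] * (missing * self.embed_dim)
        for emb in self._history:
            flat.extend(emb)
        return flat

    def forward(self, frame: Vector, training: bool = False
                ) -> Union[Vector, Tuple[Vector, List[Vector]]]:
        if self.world_model is None:
            # Reactive-only path: the world model is never touched.
            return self.reactive(frame, None)
        embedding, future_state_pred = self.world_model(frame, self._encoded_history())
        self._history.append(list(embedding))
        trajectory = self.reactive(frame, self._encoded_history())
        return (trajectory, future_state_pred) if training else trajectory


def toy_reactive(frame: Vector, visual_history: Optional[Vector]) -> Vector:
    extra = 0.0 if visual_history is None else sum(visual_history)
    return [sum(frame) + extra]


def toy_world_model(frame: Vector, visual_history: Vector) -> Tuple[Vector, List[Vector]]:
    embedding = [frame[0], frame[1]]
    return embedding, [list(embedding) for _ in range(4)]
```

Keeping the buffer outside the world model, as the commit describes, means the per-tick call stays stateless and the caller decides when a sequence starts by calling `reset_visual_history()`.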
1f9103d to b9533b0
@m-zain-khawaja, I've implemented the [...]
What's in the PR now
A new [...]
Mapping to your 24/06 answers
Wiring (your [...]
Implements your #13 decisions and wires JEPA into train.py. Memory: instead of a fixed channel-mean, FeatureCompressor exposes the forecast space as a tunable knob: occupancy (= your mean→1, cheapest), projected (default, learned 256→d_slow, keeping a semantic space at the same memory budget; I-JEPA arXiv:2301.08243), full, plus spatial stride. FutureState gets configurable horizons + an optional TDV residual predictor (arXiv:2606.15956, @intisar1020). FeatureReconstructionLoss gets an L1 option (V-JEPA). Frozen/EMA target encoder, equal weighting, bf16, EMA update per step. Real-data targets need future-frame export in data_processing (still open), so that path is gated; --enable-jepa --smoke-test exercises the full two-loss step. Additive & opt-in: defaults leave AutoE2E/FutureState/loss byte-identical. 21 new tests, full suite 292 green, mypy/ruff clean. Also fixes a pre-existing train.py call that didn't pass map_input since #55.
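To make the memory trade-off between the three modes concrete, here is a small shape calculator. It is a sketch under stated assumptions, not the FeatureCompressor API: the function names are hypothetical, the default `d_slow` is a placeholder (the PR does not state its value), and spatial stride is assumed to subsample every `stride`-th position, which gives a ceiling division on height and width.

```python
from typing import Tuple


def compressed_shape(mode: str, channels: int, height: int, width: int,
                     d_slow: int = 64, stride: int = 1) -> Tuple[int, int, int]:
    """Return the (channels, height, width) of the forecast space for a mode.

    full      -> keep all input channels
    projected -> learned map from `channels` to `d_slow` (d_slow=64 is a placeholder)
    occupancy -> channel mean, i.e. a single channel
    """
    if stride < 1:
        raise ValueError("stride must be >= 1")
    out_h = -(-height // stride)  # ceiling division, as for slicing with [::stride]
    out_w = -(-width // stride)
    if mode == "full":
        out_c = channels
    elif mode == "projected":
        out_c = d_slow
    elif mode == "occupancy":
        out_c = 1
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return out_c, out_h, out_w


def num_elements(shape: Tuple[int, int, int]) -> int:
    """Number of stored values per sample for a (C, H, W) feature map."""
    c, h, w = shape
    return c * h * w
```

As one illustration of how a projected space could sit at the occupancy budget, `compressed_shape("projected", 256, 32, 32, d_slow=4, stride=2)` and `compressed_shape("occupancy", 256, 32, 32)` hold the same number of values. Whether the PR reaches "the same memory budget" this way is not stated in the description.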