Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Model/data_parsing/l2d/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .camera import CAMERA_NAMES, NUM_VIEWS, load_camera_frames, make_camera_params_placeholder
from .dataset import L2DDataset
from .egomotion import EGOMOTION_DIM, extract_egomotion
from .world_model_windows import build_windows, required_margins, stride_for_hz, window_offsets

__all__ = [
"L2DDataset",
Expand All @@ -10,4 +11,9 @@
"extract_egomotion",
"NUM_VIEWS",
"EGOMOTION_DIM",
# World Model 1 Hz sequential windows (#16, enables JEPA #13)
"build_windows",
"window_offsets",
"required_margins",
"stride_for_hz",
]
90 changes: 74 additions & 16 deletions Model/data_parsing/l2d/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,25 @@

dataset = L2DDataset(repo_id="yaak-ai/L2D")
sample = dataset[0]
# sample["visual_tiles"] (7, 3, 256, 256) current 10 Hz frame
# sample["egomotion_history"] (256,)
# sample["visual_history"] (896,)
# sample["trajectory_target"] (128,)
# sample["episode_index"] int
# sample["frame_index"] int

# World Model training (#16, enables the JEPA loss #13): also emit the 1 Hz
# multi-view past/future windows.
dataset = L2DDataset(repo_id="yaak-ai/L2D", include_world_model_windows=True)
sample = dataset[0]
# sample["history_frames"] (N, 7, 3, 256, 256) past @1 Hz, oldest->newest
# sample["future_frames"] (N, 7, 3, 256, 256) future @1 Hz (JEPA targets)
"""

from __future__ import annotations

import logging
import sys
from typing import TypedDict

import timm
Expand All @@ -26,26 +34,36 @@

import numpy as np

if sys.version_info >= (3, 11):
from typing import NotRequired
else: # Python 3.10 (local dev venv); CI runs 3.12
from typing_extensions import NotRequired

from .camera import CAMERA_NAMES
from .egomotion import (
MIN_FRAMES,
_FUTURE_TIMESTEPS,
_HISTORY_TIMESTEPS,
extract_egomotion,
)
from .world_model_windows import build_windows, required_margins, stride_for_hz

logger = logging.getLogger(__name__)

_VISUAL_HISTORY_DIM = 896


class L2DSample(TypedDict):
visual_tiles: torch.Tensor # (7, 3, H, W)
visual_tiles: torch.Tensor # (7, 3, H, W) — current 10 Hz frame
egomotion_history: torch.Tensor # (256,)
visual_history: torch.Tensor # (896,)
trajectory_target: torch.Tensor # (128,)
episode_index: int
frame_index: int
# Present only when include_world_model_windows=True (#16, enables JEPA #13):
# the 1 Hz multi-view past/future windows, each (N, 7, 3, H, W), oldest->newest.
history_frames: NotRequired[torch.Tensor]
future_frames: NotRequired[torch.Tensor]


class L2DDataset(Dataset):
Expand All @@ -70,6 +88,10 @@ def __init__(
episodes: list[int] | None = None,
backbone_name: str = "swinv2_tiny_window8_256",
local_files_only: bool = False,
include_world_model_windows: bool = False,
wm_num_frames: int = 4,
wm_hz: float = 1.0,
source_hz: float = 10.0,
) -> None:
try:
from lerobot.datasets.lerobot_dataset import LeRobotDataset
Expand All @@ -81,6 +103,13 @@ def __init__(
self.repo_id = repo_id
self._episodes = episodes

# World Model (#16): optionally emit the 1 Hz multi-view past/future
# windows that the JEPA loss (#13) needs. stride converts the source rate
# (L2D = 10 Hz) to the World Model rate (1 Hz) -> stride 10.
self._wm_enabled = include_world_model_windows
self._wm_num_frames = wm_num_frames
self._wm_stride = stride_for_hz(source_hz, wm_hz)

# lerobot 0.5.x removed `local_files_only`; it now syncs from cache by
# default and only re-fetches when `force_cache_sync=True`. We map the
# legacy flag onto that: local_files_only=True means "don't force a
Expand Down Expand Up @@ -134,14 +163,24 @@ def _build_sample_index(self) -> list[tuple[int, int]]:
"""
samples = []

# A frame needs enough past/future for BOTH egomotion (64/64) and, when
# enabled, the World Model 1 Hz window. Take the max of the two margins.
past_margin = _HISTORY_TIMESTEPS
future_margin = _FUTURE_TIMESTEPS
if self._wm_enabled:
wm_past, wm_future = required_margins(self._wm_num_frames, self._wm_stride)
past_margin = max(past_margin, wm_past)
future_margin = max(future_margin, wm_future)
min_len = max(MIN_FRAMES, past_margin + future_margin + 1)

for ep_idx, (ep_start, ep_end) in sorted(self._episode_ranges.items()):
ep_len = ep_end - ep_start

if ep_len < MIN_FRAMES:
if ep_len < min_len:
continue

min_frame = _HISTORY_TIMESTEPS
max_frame = ep_len - _FUTURE_TIMESTEPS - 1
min_frame = past_margin
max_frame = ep_len - future_margin - 1

for frame_idx in range(min_frame, max_frame + 1):
samples.append((ep_idx, ep_start + frame_idx))
Expand All @@ -166,6 +205,21 @@ def _get_vehicle_states_window(self, ep_start: int, ep_end: int) -> np.ndarray:
)
return states

def _load_multiview_frame(self, row: int) -> torch.Tensor:
    """Decode and preprocess every camera view for one local row.

    Reads ``self.lerobot_dataset[row]`` (video decode, so this is the expensive
    path), then resizes each camera frame to ``self._input_size`` and
    normalizes it with ``self._mean`` / ``self._std``. It is used for the
    current frame and, when enabled, for every frame of the World Model 1 Hz
    windows.

    Args:
        row: local row index into ``self.lerobot_dataset``.

    Returns:
        The per-camera tensors stacked along a new leading dimension in
        ``CAMERA_NAMES`` order, i.e. ``(7, 3, H, W)`` per the shapes documented
        for this module.
    """
    item = self.lerobot_dataset[row]

    # Loop-invariant preprocessing arguments: compute once, not once per camera.
    target_size = list(self._input_size)
    mean = self._mean.squeeze()
    std = self._std.squeeze()

    # NOTE(review): frames are presumably CHW float in [0, 1] and TF is
    # presumably torchvision's functional transforms — confirm against the
    # file's imports.
    views = [
        TF.normalize(
            TF.resize(item[cam_name], target_size, antialias=True),
            mean,
            std,
        )
        for cam_name in CAMERA_NAMES
    ]
    return torch.stack(views, dim=0)

def __getitem__(self, idx: int) -> L2DSample:
# row is the local index into hf_dataset / lerobot_dataset.
ep_idx, row = self._samples[idx]
Expand All @@ -180,24 +234,28 @@ def __getitem__(self, idx: int) -> L2DSample:
vehicle_states, sample_idx=sample_idx_in_episode
)

# Load camera frames for the current timestep (decodes video)
item = self.lerobot_dataset[row]
tensors = []
for cam_name in CAMERA_NAMES:
frame = item[cam_name] # CHW float [0,1]
frame = TF.resize(frame, list(self._input_size), antialias=True)
frame = TF.normalize(frame, self._mean.squeeze(), self._std.squeeze())
tensors.append(frame)

visual_tiles = torch.stack(tensors, dim=0)
# Current 10 Hz multi-view frame (reactive model input).
visual_tiles = self._load_multiview_frame(row)

visual_history = torch.zeros(_VISUAL_HISTORY_DIM, dtype=torch.float32)

return L2DSample(
sample = L2DSample(
visual_tiles=visual_tiles,
egomotion_history=egomotion_history,
visual_history=visual_history,
trajectory_target=trajectory_target,
episode_index=ep_idx,
frame_index=sample_idx_in_episode,
)

# World Model (#16): the 1 Hz multi-view past/future windows for the JEPA
# loss (#13). The valid-index margins above guarantee the window fits.
if self._wm_enabled:
history_frames, future_frames = build_windows(
self._load_multiview_frame, row, ep_start, ep_end,
num_frames=self._wm_num_frames, stride=self._wm_stride,
)
sample["history_frames"] = history_frames
sample["future_frames"] = future_frames

return sample
90 changes: 90 additions & 0 deletions Model/data_parsing/l2d/world_model_windows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""1 Hz sequential multi-view windows for the World Model (#16, enables #13).

The reactive model runs at 10 Hz on the *current* frame; the World Model runs at
~1 Hz over a window of past frames and predicts *future* frame features (the JEPA
feature-reconstruction loss, #13). This module turns a higher-rate frame source
(L2D is 10 Hz) into the World Model's **past** and **future** windows by striding
(``stride = round(source_hz / world_model_hz)`` — e.g. 10 at 10 Hz → 1 Hz).

It is **dataset-agnostic**: it takes a ``load_frame(row) -> [V, 3, H, W]``
callable, so the windowing logic is unit-testable without the real (lerobot)
L2D dataset. ``L2DDataset`` wires its own frame loader into ``build_windows``.
"""

from __future__ import annotations

from collections.abc import Callable

import torch


def stride_for_hz(source_hz: float, world_model_hz: float) -> int:
    """Return the row stride that resamples ``source_hz`` to ``world_model_hz``.

    The stride is ``round(source_hz / world_model_hz)`` clamped to a minimum
    of 1: the number of source rows between two consecutive World Model
    samples (10 for a 10 Hz source sampled at 1 Hz). Rounding is Python's
    built-in ``round`` (exact halves go to the nearest even integer).

    Args:
        source_hz: frame rate of the underlying frame source; finite and > 0.
        world_model_hz: target World Model sampling rate; finite and > 0.

    Returns:
        Integer stride, always >= 1.

    Raises:
        ValueError: if either rate is non-positive, NaN or infinite.
    """
    import math  # local import keeps this block self-contained

    # Non-finite rates pass a plain sign check: an infinite source rate would
    # surface as OverflowError from round(), an infinite target rate would
    # silently collapse to stride 1, and NaN would fail with an unrelated
    # conversion error. Reject them up front with the documented exception.
    if not (math.isfinite(source_hz) and math.isfinite(world_model_hz)):
        raise ValueError("source_hz and world_model_hz must be finite")
    if source_hz <= 0 or world_model_hz <= 0:
        raise ValueError("source_hz and world_model_hz must be > 0")

    # A target rate above the source rate cannot be honoured; take every row
    # rather than return a zero stride.
    return max(1, round(source_hz / world_model_hz))


def window_offsets(num_frames: int, stride: int) -> tuple[list[int], list[int]]:
    """Compute row offsets, relative to the current row, for both windows.

    Both lists are ordered oldest → newest:
      - history: ``num_frames`` offsets ending on the current row,
        ``[-(N-1)*stride, …, -stride, 0]``
      - future: the ``num_frames`` offsets after it,
        ``[+stride, +2*stride, …, +N*stride]``

    Returns:
        ``(history_offsets, future_offsets)``.

    Raises:
        ValueError: if ``num_frames`` or ``stride`` is below 1.
    """
    if num_frames < 1 or stride < 1:
        raise ValueError("num_frames and stride must be >= 1")

    # Steps -(N-1)..0 cover the past (current row last); steps 1..N the future.
    past_steps = range(1 - num_frames, 1)
    future_steps = range(1, num_frames + 1)

    history = [step * stride for step in past_steps]
    future = [step * stride for step in future_steps]
    return history, future


def required_margins(num_frames: int, stride: int) -> tuple[int, int]:
    """Return how many rows a full window needs (before, after) the current row.

    Returns:
        ``(before, after)`` where ``before = (N-1)*stride`` reaches the oldest
        history frame and ``after = N*stride`` reaches the furthest future
        target.

    Raises:
        ValueError: if ``num_frames`` or ``stride`` is below 1.
    """
    if num_frames < 1 or stride < 1:
        raise ValueError("num_frames and stride must be >= 1")

    # History includes the current row, so it spans one stride less than the
    # future window.
    rows_before = (num_frames - 1) * stride
    rows_after = num_frames * stride
    return rows_before, rows_after


def build_windows(
    load_frame: Callable[[int], torch.Tensor],
    row: int,
    ep_start: int,
    ep_end: int,
    num_frames: int = 4,
    stride: int = 10,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Assemble the strided past and future multi-view windows around ``row``.

    Args:
        load_frame: multi-view frame loader returning ``[V, 3, H, W]``. It is
            called with ``row + offset``, i.e. in the same index space as
            ``row``, ``ep_start`` and ``ep_end``.
        row: row index of the current frame.
        ep_start, ep_end: half-open ``[start, end)`` row range of the episode
            containing ``row``. Every loaded row must lie inside it, so no
            window crosses an episode boundary.
        num_frames: frames per window (same count for past and future).
        stride: rows between consecutive window samples.

    Returns:
        ``(history_frames, future_frames)``, each ``[num_frames, V, 3, H, W]``
        and ordered oldest → newest. The last history frame is ``row`` itself.

    Raises:
        IndexError: if the window does not fit inside the episode. Callers are
            expected to enumerate only rows that satisfy
            :func:`required_margins`.
        ValueError: if ``num_frames`` or ``stride`` is below 1 (raised by
            :func:`window_offsets`).
    """
    past_offsets, future_offsets = window_offsets(num_frames, stride)

    # Only the two extreme rows need checking; everything else lies between.
    oldest_row = row + past_offsets[0]
    furthest_row = row + future_offsets[-1]
    if oldest_row < ep_start or furthest_row >= ep_end:
        raise IndexError(
            f"World-model window for row {row} exceeds episode "
            f"[{ep_start}, {ep_end}) (need {required_margins(num_frames, stride)} "
            f"frames before/after)."
        )

    def _stack(offsets: list[int]) -> torch.Tensor:
        # Load in the given (oldest → newest) order and add the time dimension.
        return torch.stack([load_frame(row + delta) for delta in offsets], dim=0)

    history_frames = _stack(past_offsets)
    future_frames = _stack(future_offsets)
    return history_frames, future_frames
100 changes: 100 additions & 0 deletions Model/tests/test_world_model_windows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Tests for the World Model 1 Hz sequential windows (#16, enables JEPA #13).

The windowing logic is dataset-agnostic (takes a frame-loader callable), so it
is fully tested here without the real lerobot/L2D dataset.
"""

import pytest
import torch

from data_parsing.l2d.world_model_windows import (
build_windows,
required_margins,
stride_for_hz,
window_offsets,
)


def _fake_loader(num_views=2):
    """Build a stand-in ``load_frame(row) -> [V, 3, 2, 2]`` callable.

    Every element of the returned tensor equals ``float(row)``, so a test can
    read back which rows a window loaded from the tensor contents alone.
    """
    frame_shape = (num_views, 3, 2, 2)

    def load(row: int) -> torch.Tensor:
        return torch.full(frame_shape, fill_value=float(row))

    return load


# --- stride_for_hz -------------------------------------------------------------

def test_stride_10hz_to_1hz_is_10():
    """A 10 Hz source sampled at 1 Hz uses every 10th row."""
    stride = stride_for_hz(10.0, 1.0)
    assert stride == 10

def test_stride_30hz_to_1hz_is_30():
    """A 30 Hz source sampled at 1 Hz uses every 30th row."""
    stride = stride_for_hz(30.0, 1.0)
    assert stride == 30

def test_stride_rounds_and_floors_at_one():
    """The stride is the rounded rate ratio, clamped to a minimum of 1."""
    assert stride_for_hz(10.0, 2.0) == 5
    # Inexact ratios exercise the rounding itself; exact .5 ties are avoided
    # on purpose so the test does not depend on the tie-breaking rule.
    assert stride_for_hz(10.0, 3.0) == 3  # 3.33 -> 3
    assert stride_for_hz(20.0, 3.0) == 7  # 6.67 -> 7
    assert stride_for_hz(1.0, 10.0) == 1  # never below 1

def test_stride_invalid_raises():
    """Zero or negative rates on either side are rejected with ValueError."""
    invalid_rate_pairs = ((0, 1), (10, 0), (-1, 1))
    for source_hz, world_model_hz in invalid_rate_pairs:
        with pytest.raises(ValueError):
            stride_for_hz(source_hz, world_model_hz)


# --- window_offsets / required_margins ----------------------------------------

def test_window_offsets_default():
    """Default 4-frame, stride-10 windows: history ends at 0, future starts at +10."""
    history_offsets, future_offsets = window_offsets(num_frames=4, stride=10)
    # Both lists run oldest -> newest; the current row (offset 0) closes history.
    expected_history = [-30, -20, -10, 0]
    expected_future = [10, 20, 30, 40]
    assert history_offsets == expected_history
    assert future_offsets == expected_future

def test_window_offsets_single_frame():
    """With one frame per window, history is the current row and future is +stride."""
    history_offsets, future_offsets = window_offsets(num_frames=1, stride=10)
    assert history_offsets == [0]
    assert future_offsets == [10]

def test_required_margins():
    """Margins are (N-1)*stride before and N*stride after the current row."""
    before, after = required_margins(4, 10)
    assert (before, after) == (30, 40)

def test_offsets_invalid_raises():
    """Both helpers reject a frame count or stride below 1 with ValueError."""
    with pytest.raises(ValueError):
        window_offsets(0, 10)
    with pytest.raises(ValueError):
        window_offsets(4, 0)
    # required_margins performs the same validation and needs its own coverage.
    with pytest.raises(ValueError):
        required_margins(0, 10)
    with pytest.raises(ValueError):
        required_margins(4, 0)


# --- build_windows ------------------------------------------------------------

def test_build_windows_shapes():
    """Both windows come back as [num_frames, V, 3, H, W]."""
    loader = _fake_loader(num_views=7)
    hist, fut = build_windows(
        loader, row=64, ep_start=0, ep_end=200, num_frames=4, stride=10
    )
    expected_shape = (4, 7, 3, 2, 2)
    assert hist.shape == expected_shape
    assert fut.shape == expected_shape

def test_build_windows_loads_correct_rows_oldest_to_newest():
    """Windows load the strided rows around the current one, oldest first."""
    hist, fut = build_windows(
        _fake_loader(), row=50, ep_start=0, ep_end=100, num_frames=4, stride=10
    )
    # The fake loader fills each frame with its row index, so one element per
    # time step identifies the row that was loaded.
    loaded_history = [hist[step, 0, 0, 0, 0].item() for step in range(4)]
    loaded_future = [fut[step, 0, 0, 0, 0].item() for step in range(4)]
    assert loaded_history == [20, 30, 40, 50]  # current row last
    assert loaded_future == [60, 70, 80, 90]

def test_build_windows_history_ends_at_current():
    """The newest history frame is the current row itself."""
    current_row = 33
    hist, _ = build_windows(
        _fake_loader(),
        row=current_row,
        ep_start=0,
        ep_end=100,
        num_frames=4,
        stride=10,
    )
    newest_history_row = hist[-1, 0, 0, 0, 0].item()
    assert newest_history_row == current_row

def test_build_windows_raises_when_past_exceeds_episode():
    """A window whose oldest history row precedes the episode start is rejected."""
    with pytest.raises(IndexError):
        build_windows(_fake_loader(), row=20, ep_start=0, ep_end=100,
                      num_frames=4, stride=10)  # needs row-30 = -10 < 0
    # Tightest failing case: one row short of the required past margin.
    with pytest.raises(IndexError):
        build_windows(_fake_loader(), row=29, ep_start=0, ep_end=100,
                      num_frames=4, stride=10)  # needs row-30 = -1 < 0

def test_build_windows_raises_when_future_exceeds_episode():
    """A window whose furthest future row reaches the episode end is rejected."""
    with pytest.raises(IndexError):
        build_windows(_fake_loader(), row=95, ep_start=0, ep_end=100,
                      num_frames=4, stride=10)  # needs row+40 = 135 >= 100
    # Exact boundary: ep_end is exclusive, so landing on it must also fail.
    with pytest.raises(IndexError):
        build_windows(_fake_loader(), row=60, ep_start=0, ep_end=100,
                      num_frames=4, stride=10)  # needs row+40 = 100 >= 100

def test_build_windows_respects_episode_start_offset():
    """Windows work for an episode that does not start at row 0."""
    # The episode occupies rows [100, 200); every row loaded around 150 must
    # stay inside that range.
    hist, fut = build_windows(
        _fake_loader(), row=150, ep_start=100, ep_end=200, num_frames=4, stride=10
    )
    loaded_history = [hist[step, 0, 0, 0, 0].item() for step in range(4)]
    loaded_future = [fut[step, 0, 0, 0, 0].item() for step in range(4)]
    assert loaded_history == [120, 130, 140, 150]
    assert loaded_future == [160, 170, 180, 190]