Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Model/model_components/losses/trajectory_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
class TrajectoryImitationLoss(nn.Module):
"""Primary task loss: imitation loss over predicted trajectory."""

# Class-level annotations so mypy resolves these to their real types
# instead of nn.Module's ``__getattr__ -> Tensor | Module`` fallback
# (otherwise ``self.loss_fn(...)`` is flagged "Tensor not callable").
loss_fn: nn.Module
temporal_weights: torch.Tensor

def __init__(self, loss_type: str = "smooth_l1", temporal_decay: float = 0.95,
num_timesteps: int = 64, num_signals: int = 2):
# temporal_decay defaults to 0.95 so near-future predictions are
# weighted more heavily than far-future ones; near-future accuracy
# is more safety-critical for planning.
super().__init__()
self.loss_fn: nn.Module
if loss_type == "smooth_l1":
self.loss_fn = nn.SmoothL1Loss(reduction="none")
elif loss_type == "mse":
Expand Down
7 changes: 6 additions & 1 deletion Model/model_components/map_encoder/raster_map_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
features, so the two can be fused directly.
"""

from typing import Any

import torch
import torch.nn as nn
import timm
Expand Down Expand Up @@ -40,8 +42,11 @@ def __init__(
in_chans=in_channels,
)

# timm's FeatureInfo is exposed via __getattr__ (typed Tensor | Module),
# so bind through Any to iterate it without a mypy union-attr error.
backbone: Any = self._backbone
self._feature_channels = [
stage["num_chs"] for stage in self._backbone.feature_info
stage["num_chs"] for stage in backbone.feature_info
]
backbone_channels = sum(self._feature_channels)

Expand Down
36 changes: 17 additions & 19 deletions Model/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,30 +76,30 @@ def forward(self, image):
return self.backbone(image)


def _build_model_with_mock_backbone(num_views, fusion_mode, device,
def _build_model_with_mock_backbone(num_views, fusion_mode="bev", device=None,
num_timesteps=64, map_fusion_mode="residual",
planner_mode="gru", planner_kwargs=None):
planner_mode="bezier", planner_kwargs=None,
**model_kwargs):
"""Construct AutoE2E with the mock backbone injected.

Patches Backbone at the module level during construction to avoid
loading pretrained weights entirely. Forces BEV fusion to a small
8x8 grid so tests stay fast and memory-light; the production default
(450x300) is exercised by dedicated configuration tests.
Post-refactor (#86): the model is ``AutoE2E`` -> ``Reactive_E2E`` and the
image backbone now lives in ``reactive_e2e``; fusion is always BEV
(concat/cross_attn were removed) and GRU was dropped. ``fusion_mode`` is
accepted for backward-compatibility with existing call sites but ignored
(always BEV at a small 8x8 grid). Extra ``model_kwargs`` are forwarded.
"""
from unittest.mock import patch
from model_components.auto_e2e import AutoE2E

view_fusion_kwargs = {"bev_h": 8, "bev_w": 8} if fusion_mode == "bev" else None

with patch('model_components.auto_e2e.Backbone', MockBackbone):
with patch('model_components.reactive_e2e.Backbone', MockBackbone):
model = AutoE2E(
num_views=num_views,
fusion_mode=fusion_mode,
view_fusion_kwargs=view_fusion_kwargs,
view_fusion_kwargs={"bev_h": 8, "bev_w": 8},
num_timesteps=num_timesteps,
planner_mode=planner_mode,
planner_kwargs=planner_kwargs,
map_fusion_mode=map_fusion_mode,
**model_kwargs,
)
return model.to(device)

Expand All @@ -115,13 +115,13 @@ def build_mock_model():
return _build_model_with_mock_backbone


@pytest.fixture(scope="session", params=["concat", "cross_attn", "bev"])
@pytest.fixture(scope="session", params=["bev"])
def model(request, device):
"""Session-scoped model with mock backbone — shared across all tests.

Built once per fusion mode to avoid redundant 1.5s construction overhead
per test. Gradient state is reset before each test via the autouse
_reset_model_grads fixture below.
Post-refactor only BEV fusion exists. Built once to avoid redundant
construction overhead; gradient state is reset before each test via the
autouse fixture below.
"""
return _build_model_with_mock_backbone(
num_views=7, fusion_mode=request.param, device=device
Expand All @@ -138,16 +138,14 @@ def _reset_model_state(request):
model.train()


@pytest.fixture(params=["concat", "cross_attn", "bev"])
@pytest.fixture(params=["bev"])
def full_model(request, device):
"""Full model with real backbone — use only for integration tests."""
from model_components.auto_e2e import AutoE2E

view_fusion_kwargs = {"bev_h": 8, "bev_w": 8} if request.param == "bev" else None
try:
model = AutoE2E(
num_views=7, fusion_mode=request.param,
view_fusion_kwargs=view_fusion_kwargs,
num_views=7, view_fusion_kwargs={"bev_h": 8, "bev_w": 8},
)
except (FileNotFoundError, OSError) as e:
pytest.skip(f"Pretrained weights unavailable: {e}")
Expand Down
Loading
Loading