Merged
39 commits
6f47c0d
Merge pull request #3 from galilai-group/main
MarcelMatsal May 19, 2026
21780f9
implemented state persistence
MarcelMatsal May 19, 2026
fa2423a
Merge branch 'galilai-group:main' into main
MarcelMatsal Jun 4, 2026
2ddada9
Merge pull request #4 from MarcelMatsal/main
MarcelMatsal Jun 4, 2026
7df08bd
many new features for spt web, more coming soon
MarcelMatsal Jun 4, 2026
985605a
making it possible to update the names internally and have persistence
MarcelMatsal Jun 4, 2026
f708f62
now you can add notes per note that persist throughout, in addition, …
MarcelMatsal Jun 4, 2026
6009554
added a small live display to logs outputting, in addition, automatic…
MarcelMatsal Jun 4, 2026
5bcbfc1
added duration for the runs
MarcelMatsal Jun 5, 2026
ae86b50
added functionality to detect stale runs, that would be shown as runn…
MarcelMatsal Jun 5, 2026
b40300f
added scatterplot to compare different metrics across runs
MarcelMatsal Jun 5, 2026
fc11d27
csv download functionality added
MarcelMatsal Jun 5, 2026
e235294
added resizing functionality for the tables allowing to easily reset …
MarcelMatsal Jun 5, 2026
dd3d97a
added resizeable sidebar
MarcelMatsal Jun 5, 2026
9095392
added keyboard shortcuts
MarcelMatsal Jun 5, 2026
4d1b488
added functionality to add new tags
MarcelMatsal Jun 5, 2026
c95b0cf
fixing format
MarcelMatsal Jun 5, 2026
b0b69c9
Merge branch 'main' into spt_web_update
MarcelMatsal Jun 5, 2026
bd06035
made it so that removing a zoom, zooms out of all graphs
MarcelMatsal Jun 5, 2026
2599ffe
added button to table that will hide all the things that are the same
MarcelMatsal Jun 5, 2026
e9dffb4
functionality to download tables
MarcelMatsal Jun 5, 2026
fef791f
functionality to combine different metrics into a single graph for a run
MarcelMatsal Jun 6, 2026
5eff3a9
update to light mode, making it more pleasing to the eyes
MarcelMatsal Jun 6, 2026
58faa18
feature to be able to choose if min or max is better for a graph
MarcelMatsal Jun 7, 2026
d4a8c33
added graceful failure when files are deleted
MarcelMatsal Jun 7, 2026
0769506
added more graceful error catching to renaming
MarcelMatsal Jun 7, 2026
7dfc150
updates for agents to understand how to interact with spt web
MarcelMatsal Jun 7, 2026
e55a34c
added lots of unit tests
MarcelMatsal Jun 7, 2026
2714a7d
fixing precommit errors
MarcelMatsal Jun 7, 2026
155e9ef
fixed final precommit errors from the jax implementation
MarcelMatsal Jun 7, 2026
9daba1a
small updates to pass all tests
MarcelMatsal Jun 7, 2026
29c0474
fixing failing integration tests
MarcelMatsal Jun 8, 2026
d271b05
Revert "fixing failing integration tests"
MarcelMatsal Jun 8, 2026
4a5bb07
Merge branch 'main' into spt_web_update
RandallBalestriero Jun 16, 2026
1020dd7
first round of additional tests and fixes
MarcelMatsal Jun 18, 2026
d70d5d0
Merge branch 'spt_web_update' of https://github.com/MarcelMatsal/stab…
MarcelMatsal Jun 18, 2026
be2201c
new safety fix to frontend
MarcelMatsal Jun 18, 2026
ff9ddf0
Merge branch 'main' into spt_web_update
RandallBalestriero Jun 19, 2026
f788bdd
Merge branch 'main' into spt_web_update
MarcelMatsal Jun 19, 2026
5 changes: 4 additions & 1 deletion .gitignore
@@ -159,9 +159,12 @@ docs/source/modules.rst
outputs/
multirun/
./data/
mock_runs/
examples/data/
wandb/
DISCOVERABILITY_PLAN.md
SPT_WEB_PLAN.md
scripts/
.DS_Store

*.ckpt
*.pt
113 changes: 113 additions & 0 deletions AGENTS.md
@@ -19,6 +19,7 @@ stable_pretraining/
data/ # HFDataset, MultiViewTransform, RepeatedRandomSampler
loggers/ # WandB, Trackio, SwanLab integrations
registry/ # filesystem-first run registry (sidecars + SQLite)
web/ # local run viewer (spt web); reads RegistryLogger output
optim/ # optimizer and scheduler factories
utils/ # atomic checkpointing, lightning patch, error handling
_config.py # global config: spt.set(key, value) / spt.get_config()
@@ -224,6 +225,118 @@ spt registry export sweep.csv # export to CSV
spt registry scan --full # rebuild SQLite cache
```

## spt web — local run viewer

`spt web` is a dependency-free local viewer for `RegistryLogger` runs — a
WandB alternative requiring no account, no internet, and no extra dependencies
beyond what the library ships. It reads `sidecar.json` + `metrics.csv` produced
by `RegistryLogger` and serves a browser UI over stdlib `http.server` with
Server-Sent Events for live updates.

### Launching

```bash
# Scan any directory (scanner walks the tree to any depth)
spt web runs/

# No argument → uses {cache_dir}/runs automatically
spt web

# Custom host / port / poll interval (mtime-based, no inotify — NFS-safe)
spt web runs/ --host 0.0.0.0 --port 8080 --poll 2.0
```

The viewer opens at `http://127.0.0.1:4242` by default. The page updates in
real time as training writes new metrics — no reload needed.

### Connection to RegistryLogger

`Manager` automatically injects a `RegistryLogger` into every run. When using
`Manager`, no extra configuration is needed: `spt web {cache_dir}/runs` will
show all runs. The logger writes to `{run_dir}/`:

| File | Written by | Displayed as |
|------|------------|-------------|
| `sidecar.json` | RegistryLogger | run metadata, hparams, summary, status, tags, notes |
| `metrics.csv` | RegistryLogger | per-step charts (figures tab) |
| `heartbeat` | RegistryLogger | mtime → stale detection (>5 min old without terminal status = ⚠ stale) |
| `media.jsonl` | RegistryLogger | image/video media panel |
| `*.out` / `*.err` | training process | log tab (auto-tails live runs) |
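
The stale rule in the `heartbeat` row can be sketched in a few lines of Python. This is an illustration only, not the viewer's code: the function name `effective_status` and the choice of which statuses count as terminal are assumptions, and the 5-minute threshold is the figure quoted in the table.

```python
import time
from pathlib import Path

STALE_AFTER_S = 5 * 60  # threshold quoted in the table above
# Assumption: these sidecar statuses are treated as terminal.
TERMINAL = {"completed", "failed", "orphaned"}


def effective_status(run_dir, status, now=None):
    """Return "stale" for a non-terminal run whose heartbeat mtime is too old."""
    if status in TERMINAL:
        return status
    heartbeat = Path(run_dir) / "heartbeat"
    if not heartbeat.exists():
        # Assumption: without a heartbeat file staleness cannot be judged.
        return status
    now = time.time() if now is None else now
    age = now - heartbeat.stat().st_mtime
    return "stale" if age > STALE_AFTER_S else status
```

Because the check only reads the file's mtime, it works on network filesystems where change notifications are unavailable.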

To use `RegistryLogger` directly (outside `Manager`):
```python
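import lightning as pl  # assumes Lightning is installed and imported as `pl`

from stable_pretraining.registry import RegistryLogger

logger = RegistryLogger(run_dir="runs/my_run", run_id="my_run")
trainer = pl.Trainer(logger=logger)  # add your other Trainer arguments here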
```

### Run directory layout (what spt web expects)

```
{run_dir}/                 ← any leaf directory that contains a sidecar.json
  sidecar.json             ← required: run metadata
  metrics.csv              ← optional: columns = step, epoch, <metric_name>, ...
  heartbeat                ← optional: empty file; mtime = last alive timestamp
  media.jsonl              ← optional: one JSON object per line, image/video events
  train.out / train.err    ← optional: log files shown in the .out / .err tabs
  checkpoints/             ← optional: ignored by the viewer
```

Runs do not need to be at the top level of the scanned directory — the scanner
recurses to any depth.
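
As a rough sketch of what "any leaf directory that contains a `sidecar.json`" means in practice, the helper below walks a tree and yields every run it finds. The name `find_runs` is hypothetical, and skipping unreadable or half-written sidecars is an assumption about sensible behaviour, not a description of the real scanner.

```python
import json
from pathlib import Path


def find_runs(root):
    """Yield (run_dir, sidecar_dict) for every sidecar.json below root."""
    for sidecar_path in sorted(Path(root).rglob("sidecar.json")):
        try:
            sidecar = json.loads(sidecar_path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            # Assumption: a sidecar that cannot be parsed is skipped, not fatal.
            continue
        yield sidecar_path.parent, sidecar
```

Calling `list(find_runs("runs/"))` would return runs nested at any depth, for example under a `sweep/<config>/<seed>/` layout.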

### sidecar.json schema

Key fields agents may read or patch:

| Field | Type | Notes |
|-------|------|-------|
| `run_id` | `str` | Unique identifier, usually path relative to cache_dir |
| `display_name` | `str` | Human-readable label shown in the sidebar (editable via UI or `PATCH /api/run-meta`) |
| `status` | `str` | `"running"` \| `"completed"` \| `"failed"` \| `"orphaned"` |
| `hparams` | `dict` | Hyperparameters; logged via `log_hyperparams` |
| `summary` | `dict` | Final scalars (e.g. best val_acc); populated by `RegistryLogger` at `finalize` |
| `tags` | `list[str]` | Labels for filtering and grouping (editable via UI) |
| `notes` | `str` | Free-text notes (editable via UI) |
| `created_at` | `float` | Unix timestamp of run start |
| `ended_at` | `float \| None` | Unix timestamp of run end; `None` while still running |
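
The two timestamp fields are enough to derive a run's duration. The sketch below is one way an agent might compute it from a loaded sidecar; `run_duration_s` is a hypothetical helper, and falling back to the current time while `ended_at` is `None` is an assumption about how elapsed time for live runs would be measured.

```python
import time


def run_duration_s(sidecar, now=None):
    """Seconds from created_at to ended_at, or to `now` while still running."""
    start = sidecar.get("created_at")
    if start is None:
        return None  # nothing to measure from
    end = sidecar.get("ended_at")
    if end is None:
        # Run has not finished: measure elapsed time so far.
        end = time.time() if now is None else now
    return max(0.0, end - start)
```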

The sidecar can be patched programmatically via the server's HTTP API while
`spt web` is running:
```python
import requests

requests.patch("http://127.0.0.1:4242/api/run-meta", json={
    "run_id": "runs/my_run",
    "display_name": "experiment-v2",
    "tags": ["sweep", "lr-1e-3"],
    "notes": "Increased weight decay.",
})
```
Allowed mutable fields: `display_name`, `notes`, `tags`, `archived`.

### UI features (for agents reasoning about what users can do)

| Tab / panel | What it shows | Key interactions |
|-------------|--------------|-----------------|
| **Figures** | One uPlot chart per metric, all visible runs overlaid | Drag to zoom (synced across all charts); ↓ min / ↑ max direction toggle; `+` to combine metrics into one panel; `⬇` to download as PNG |
| **Table** | Runs × (hparams + summary) comparison grid | Sortable columns; column search; "hide same" collapses identical columns; amber diff highlighting |
| **.out / .err** | Last ~4 MiB of log files | Auto-refreshes every 10 s while run is live; pause button stops auto-refresh |
| **Detail modal** | Full hparams / summary / tags / notes for one run | Notes and tags are editable in-place; double-click run name in sidebar to rename |
| **Sidebar** | Run list | Search by name/tag/hparam; filter by field value; group-by any hparam key; sort by any metric |
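
The min/max direction toggle in the Figures row has a default: per the release notes in this change, lower-is-better is inferred for metric names containing `loss`, `err`, `perplexity`, or `ppl`. A Python sketch of that rule follows. The viewer's own logic lives in the frontend and may differ; the helper names and the case-insensitive match are assumptions.

```python
LOWER_IS_BETTER_HINTS = ("loss", "err", "perplexity", "ppl")


def lower_is_better(metric_name):
    """Guess the optimisation direction from the metric name."""
    name = metric_name.lower()  # assumption: matching ignores case
    return any(hint in name for hint in LOWER_IS_BETTER_HINTS)


def best_value(metric_name, values):
    """Pick the best value of a series under the inferred direction."""
    return min(values) if lower_is_better(metric_name) else max(values)
```

A substring rule like this will misread some names (anything that merely contains `err`, for instance), which is presumably why the UI exposes a manual toggle.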

State (selected runs, filters, active tab, smoothing, theme) persists to
`localStorage` across reloads. Visible run IDs and active tab are also written
into the URL fragment (`#runs=id1,id2&tab=figures`) so a shared URL restores the
exact selection.
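
An agent that receives a shared URL can recover the selection from the fragment without opening the viewer. The sketch below assumes the fragment uses ordinary query-string syntax exactly as shown above, that run IDs contain no commas, and that any escaping is standard percent-encoding; `parse_fragment` is a hypothetical name.

```python
from urllib.parse import parse_qs


def parse_fragment(fragment):
    """Split '#runs=id1,id2&tab=figures' into (run_ids, tab)."""
    params = parse_qs(fragment.lstrip("#"))
    runs_raw = params.get("runs", [""])[0]
    run_ids = [run_id for run_id in runs_raw.split(",") if run_id]
    tab = params.get("tab", [None])[0]
    return run_ids, tab
```

For example, `parse_fragment("#runs=id1,id2&tab=figures")` yields the two run IDs and the tab name `figures`.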

### When to suggest `spt web`

- User wants to inspect or compare runs **locally** (no WandB account needed).
- User wants to **monitor a live training run** — the viewer updates as metrics arrive.
- User is running a **hyperparameter sweep** and wants to compare configs (table tab + hide-same).
- User wants to see **log output** from a run: `spt web {parent_dir}`, then open the .out tab.
- User asks "how do I see my training curves?" — `spt web` is the answer unless they already have WandB configured.

## Callback ordering

Lightning runs `trainer.callbacks` in registration order. Within a single hook, callbacks fire in that order; across hooks, Lightning completes each hook for **every** callback before moving to the next.
81 changes: 81 additions & 0 deletions RELEASES.rst
@@ -57,6 +57,87 @@ Unreleased
- New API pages: :doc:`api/registry` (RegistryLogger + Registry query)
and :doc:`api/web` (``serve`` entry point).

**spt web — viewer improvements**

- **State persistence & shareable URLs**: all interactive state
(selected runs, filters, group-by, sort, x-axis, smoothing, log-y, active
tab, theme) is serialised to ``localStorage`` on every mutation and restored
on load. The URL fragment is kept in sync (``#runs=…&tab=…``) so copying the
address bar produces a link that reopens the exact same selection.

- **Chart value annotations**: each metric chart now has a compact table below
it showing the last and best value for every visible series. The best row is
highlighted; clicking a row toggles that series on/off. Lower-is-better is
inferred automatically from metric names containing ``loss``, ``err``,
``perplexity``, or ``ppl``.

- **Runs table view**: a fourth **table** tab renders a horizontally-scrollable
grid with one row per visible run and one column per hparam/summary key. Cells
whose values differ across runs are highlighted in amber; identical cells are
dimmed. Columns are sortable by click and filterable by a search box. The
run-ID column and header row are sticky.

- **Run display names**: the sidebar and table now show ``display_name`` as the
primary label with the raw path shown as a dimmed hint. Double-clicking the
label opens an inline editor; changes are persisted via ``PATCH /api/run-meta``.

- **Notes editing**: the run detail modal has an editable notes textarea that
auto-resizes to its content. Notes are saved on blur or ``Ctrl+Enter`` via
``PATCH /api/run-meta``.

- **Log auto-refresh / live tail**: the ``.out`` and ``.err`` tabs auto-refresh
every 10 s while the selected run is running or stale, showing a pulsing
**live** badge. A pause button stops the timer; a refresh button triggers a
manual fetch. Scroll position is preserved when the user has scrolled up.

- **Elapsed time / duration**: each run row and the detail modal now show the
elapsed or total duration. Running runs tick forward every 60 s without a full
re-render. ``RegistryLogger.finalize()`` now records ``ended_at`` in the
sidecar so completed durations survive restarts.

- **Heartbeat staleness**: runs whose heartbeat file is more than 5 minutes old
are flagged as **stale** with an amber ⚠ indicator in the sidebar, topbar
stats, detail modal, figures overview, and activity timeline. Filters,
group-by, and sort all treat stale as a distinct status value via a centralised
``effectiveStatus()`` helper.

- **Scatter plot**: the figures tab renders a scatter plot below the metric
charts when two or more runs are visible. X and Y axes are independently
selectable from any numeric ``hparams.*`` or ``summary.*`` key across visible
runs. Each run contributes one point (its final summary scalar). Axis
selection is persisted to ``localStorage``.

- **CSV export**: a **download CSV** button in the figures toolbar downloads all
visible metrics for the currently selected runs as a flat CSV
(``run_id, run_name, metric, step, epoch, value``). Only the runs and metrics
currently in view (respecting the metric-search filter) are included. No
server round-trip — the data is built entirely from the in-memory metrics
cache.

- **Zoom reset**: drag-selecting a region on any chart zooms all charts
simultaneously (shared sync key). A **⤢** reset button appears in the chart
title bar after zooming and resets all charts to their full x-range in one
click. The drag-selection region is now styled with an accent-coloured border
and fill.

- **Sidebar resize + virtual scroll**: the sidebar can be dragged to any width
between 160 px and 600 px; the chosen width is persisted. The
search/filter/sort controls are now fixed at the top of the sidebar while the
run list scrolls independently below them. When more than 300 ungrouped runs
are visible, the run list switches to a virtual-scroll window so SSE updates
remain smooth at scale.

- **Keyboard shortcuts**: ``/`` focuses metric search, ``r`` focuses run
search, ``t`` cycles tabs, ``Shift+A`` selects all, ``Shift+C`` clears all,
``?`` opens a help popover listing all shortcuts. ``Esc`` exits focused inputs
and dismisses all modals and popovers.

- **Tag editing**: run tags are now displayed as editable pills in the detail
modal. Clicking **star** on a pill removes the tag; a dashed ``+ tag`` input
with autocomplete (populated from tags on other runs) adds new ones. All
changes are persisted via ``PATCH /api/run-meta`` with optimistic updates and
revert on failure.

Version 0.1
-----------

40 changes: 20 additions & 20 deletions stable_pretraining/jax/backbone/resnet.py
@@ -23,12 +23,12 @@
 from flax import nnx


-def _conv(din, dout, k, stride, rngs, padding, dtype=None):
+def _conv(din, d_out, k, stride, rngs, padding, dtype=None):
     # ``dtype`` is the *computation* dtype (e.g. bfloat16 for mixed precision);
     # ``param_dtype`` stays float32 so the weights are kept in full precision.
     return nnx.Conv(
         din,
-        dout,
+        d_out,
         kernel_size=(k, k),
         strides=(stride, stride),
         padding=padding,
@@ -46,12 +46,12 @@ def _bn(features, rngs, axis_name):
     )


-def _maybe_downsample(din, dout, stride, rngs, axis_name, dtype=None):
+def _maybe_downsample(din, d_out, stride, rngs, axis_name, dtype=None):
     """1x1-conv + BN projection shortcut, or ``None`` when shapes already match."""
-    if stride != 1 or din != dout:
+    if stride != 1 or din != d_out:
         return nnx.Sequential(
-            _conv(din, dout, 1, stride, rngs, "VALID", dtype),
-            _bn(dout, rngs, axis_name),
+            _conv(din, d_out, 1, stride, rngs, "VALID", dtype),
+            _bn(d_out, rngs, axis_name),
         )
     return None

@@ -61,15 +61,15 @@ class BasicBlock(nnx.Module):

     expansion = 1

-    def __init__(self, din, dout, stride, rngs, axis_name=None, dtype=None):
-        self.conv1 = _conv(din, dout, 3, stride, rngs, [(1, 1), (1, 1)], dtype)
-        self.bn1 = _bn(dout, rngs, axis_name)
-        self.conv2 = _conv(dout, dout, 3, 1, rngs, [(1, 1), (1, 1)], dtype)
-        self.bn2 = _bn(dout, rngs, axis_name)
+    def __init__(self, din, d_out, stride, rngs, axis_name=None, dtype=None):
+        self.conv1 = _conv(din, d_out, 3, stride, rngs, [(1, 1), (1, 1)], dtype)
+        self.bn1 = _bn(d_out, rngs, axis_name)
+        self.conv2 = _conv(d_out, d_out, 3, 1, rngs, [(1, 1), (1, 1)], dtype)
+        self.bn2 = _bn(d_out, rngs, axis_name)
         # Assign exactly once: mixing a static ``None`` then a module would
         # flip the attribute's pytree status and NNX rejects that.
         self.downsample = _maybe_downsample(
-            din, dout * self.expansion, stride, rngs, axis_name, dtype
+            din, d_out * self.expansion, stride, rngs, axis_name, dtype
         )

     def __call__(self, x):
@@ -84,15 +84,15 @@ class Bottleneck(nnx.Module):

     expansion = 4

-    def __init__(self, din, dout, stride, rngs, axis_name=None, dtype=None):
-        self.conv1 = _conv(din, dout, 1, 1, rngs, "VALID", dtype)
-        self.bn1 = _bn(dout, rngs, axis_name)
-        self.conv2 = _conv(dout, dout, 3, stride, rngs, [(1, 1), (1, 1)], dtype)
-        self.bn2 = _bn(dout, rngs, axis_name)
-        self.conv3 = _conv(dout, dout * self.expansion, 1, 1, rngs, "VALID", dtype)
-        self.bn3 = _bn(dout * self.expansion, rngs, axis_name)
+    def __init__(self, din, d_out, stride, rngs, axis_name=None, dtype=None):
+        self.conv1 = _conv(din, d_out, 1, 1, rngs, "VALID", dtype)
+        self.bn1 = _bn(d_out, rngs, axis_name)
+        self.conv2 = _conv(d_out, d_out, 3, stride, rngs, [(1, 1), (1, 1)], dtype)
+        self.bn2 = _bn(d_out, rngs, axis_name)
+        self.conv3 = _conv(d_out, d_out * self.expansion, 1, 1, rngs, "VALID", dtype)
+        self.bn3 = _bn(d_out * self.expansion, rngs, axis_name)
         self.downsample = _maybe_downsample(
-            din, dout * self.expansion, stride, rngs, axis_name, dtype
+            din, d_out * self.expansion, stride, rngs, axis_name, dtype
         )

     def __call__(self, x):
6 changes: 3 additions & 3 deletions stable_pretraining/jax/backbone/small.py
@@ -10,18 +10,18 @@
 from flax import nnx


-def _conv_bn(din, dout, rngs, k=3, stride=1):
+def _conv_bn(din, d_out, rngs, k=3, stride=1):
     return nnx.Sequential(
         nnx.Conv(
             din,
-            dout,
+            d_out,
             kernel_size=(k, k),
             strides=(stride, stride),
             padding=[(k // 2, k // 2)] * 2,
             use_bias=False,
             rngs=rngs,
         ),
-        nnx.BatchNorm(dout, momentum=0.9, rngs=rngs),
+        nnx.BatchNorm(d_out, momentum=0.9, rngs=rngs),
     )

2 changes: 2 additions & 0 deletions stable_pretraining/registry/_sidecar.py
@@ -56,6 +56,7 @@ def make_sidecar(
     run_dir: str,
     status: str = "running",
     created_at: Optional[float] = None,
+    ended_at: Optional[float] = None,
     hparams: Optional[Dict[str, Any]] = None,
     summary: Optional[Dict[str, Any]] = None,
     tags: Optional[List[str]] = None,
@@ -70,6 +71,7 @@ def make_sidecar(
         "run_dir": run_dir,
         "status": status,
         "created_at": created_at if created_at is not None else now,
+        "ended_at": ended_at,
         "updated_at": now,
         "tags": list(tags or []),
         "notes": notes or "",
3 changes: 3 additions & 0 deletions stable_pretraining/registry/logger.py
@@ -177,6 +177,7 @@ def __init__(
         # the registry can order runs chronologically regardless of how
         # often we flush.
         self._created_at: Optional[float] = None
+        self._ended_at: Optional[float] = None
         # First-write flag for summary.json — used to log a one-shot info
         # line on creation, then debug lines on subsequent rewrites so we
         # don't spam every flush.
@@ -262,6 +263,7 @@ def save(self) -> None:
     def finalize(self, status: str) -> None:
         # Map Lightning status strings to our canonical vocabulary.
         self._status = {"success": "completed", "failed": "failed"}.get(status, status)
+        self._ended_at = time.time()
         # Parent writes CSVs. We don't call super().finalize first
         # because _experiment may be None on rank-zero callers that
         # never logged — super() handles that no-op correctly.
@@ -406,6 +408,7 @@ def _write_sidecar(self) -> None:
             run_dir=str(self._run_dir),
             status=self._status,
             created_at=self._created_at,
+            ended_at=self._ended_at,
             hparams=self._hparams,
             summary=self._summary,
             tags=self._tags,