Skip to content

[bugfix] restore dataloader state before create_dataloader forks workers#544

Open
tiankongdeguiji wants to merge 9 commits into
masterfrom
fix_dataloader_state_restore_order
Open

[bugfix] restore dataloader state before create_dataloader forks workers#544
tiankongdeguiji wants to merge 9 commits into
masterfrom
fix_dataloader_state_restore_order

Conversation

@tiankongdeguiji

@tiankongdeguiji tiankongdeguiji commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Problem

train_and_evaluate restores the dataloader checkpoint state after create_dataloader, but create_dataloader ends with an eager iter(dataloader) (persistent_workers=True) that forks the DataLoader workers immediately. Workers keep a fork-time copy of the dataset with _checkpoint_state=None; the later load_state_dict mutates only the parent process, and persistent workers are never re-pickled (_ResumeIteration recreates the fetcher from the worker's own dataset copy). The restored state therefore never reaches the processes that actually read data.

Production symptom: on resume with --fine_tune_checkpoint + --continue_train, Kafka consumers ignored dataloader_state.json and re-seeked via start.timestamp.ms (days of data re-trained). The same ordering bug affects Parquet/ODPS resume.

The eager iter() itself must stay: forking workers later in the lifecycle (after CUDA-holding garbage cycles exist in the heap) makes worker GC abort with CUDA error: initialization error — the original reason it was added in v0.7.0.

Fix

  • create_dataloader accepts an optional checkpoint_state and applies it to the dataset before the DataLoader is built, so the eager iter() forks workers that already carry the state.
  • train_and_evaluate hoists CheckpointManager construction, ckpt-path resolution and restore_dataloader_state above create_dataloader (logic unchanged; the manager's event-time watermark seeding is preserved) and passes the state through the new argument.

With the state present at fork time, every dataloader reset (e.g. the train loop's iter()) re-seeks to the checkpointed offsets, so prefetched-then-discarded batches are re-read — correct at-least-once semantics.

Tests

  • New regression test test_create_dataloader_checkpoint_state_reaches_workers (parquet, num_workers=2): state passed through create_dataloader is honored by forked workers. Verified it catches the bug: with the old set-after-create pattern, workers replay from row 0.
  • Existing parquet/kafka dataset tests and test_multi_tower_din_fg_encoded_finetune integration pass.

🤖 Generated with Claude Code

Resume semantics (added after review)

The restored state is consumed by the first completed pass: BaseDataset.__iter__ clears the reader state when its to_batches generator exhausts normally (generator close() keeps it, so mid-pass dataloader resets still re-seek). Consequences:

  • a mid-pass resume finishes the remaining intervals, then subsequent epochs read full passes;
  • the parent's accumulated dataloader_state is likewise reset at the pass boundary (train loop StopIteration), so checkpoints taken during later passes record positions within that pass instead of max-merging back the old fully-consumed state — epoch-boundary and final checkpoints store an empty state, and resuming from them starts a fresh full pass directly (no empty pass);
  • resuming from a legacy checkpoint that recorded a fully-consumed state yields one empty pass and then proceeds normally (previously: zero batches forever);
  • streaming sources (Kafka) are unaffected — their generators never exhaust.

The checkpoint also records the completed-pass count (reserved key __epochs_completed__, incremented at each pass boundary), and num_epochs jobs resume with range(completed, num_epochs) — so a job that dies mid-pass-N trains exactly the remaining budget (skip_steps already handled num_steps jobs). Fine-tune checkpoints keep their data positions but drop the counter: the epoch budget belongs to the job, not the warm-start weights. Old checkpoints without the key resume with the previous behavior.

tiankongdeguiji and others added 3 commits June 11, 2026 13:34
create_dataloader eagerly starts persistent workers via iter(dataloader);
workers keep a fork-time copy of the dataset, so load_state_dict applied to
the returned dataloader's dataset afterwards never reaches them. Kafka
consumers therefore ignored the restored offsets and re-seeked via
start.timestamp.ms / committed offsets on resume.

Hoist CheckpointManager construction, ckpt_path resolution and
restore_dataloader_state above create_dataloader, and pass the state via a
new create_dataloader(checkpoint_state=...) argument applied to the dataset
before the DataLoader is built.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Regression test for the restore ordering: resume state passed through
create_dataloader(checkpoint_state=...) is honored by persistent workers
(num_workers=2); with the old set-after-create pattern workers replay from
row 0.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label Jun 11, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Jun 11, 2026
Comment thread tzrec/datasets/dataset.py Outdated
Comment on lines +792 to +793
if checkpoint_state:
dataset.load_state_dict(checkpoint_state)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BaseReader.load_state_dict stores this dict by reference (dataset.py:555), and train_and_evaluate passes the very same object on to _train_and_evaluate, which mutates it in place on every step via update_dataloder_state (main.py:462) and at save time. With num_workers >= 1 the fork-time copy makes this accidentally safe, but with data_config.num_workers < 1 the DataLoader runs in-process, so the reader aliases the live, ever-growing dict — the next iter(train_dataloader) (epoch 2) recomputes intervals from offsets accumulated during this run and skips that data. KafkaReader's on_assign rebalance callback also reads _checkpoint_state at arbitrary later times, cross-thread. A defensive copy breaks the aliasing:

Suggested change
if checkpoint_state:
dataset.load_state_dict(checkpoint_state)
if checkpoint_state:
dataset.load_state_dict(dict(checkpoint_state))

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5389e7fcreate_dataloader now passes dict(checkpoint_state) to load_state_dict, decoupling the reader from the training loop's accumulating dict.

Comment thread tzrec/main.py
Comment on lines +632 to +639
# Build dataloader
train_dataloader = create_dataloader(
data_config,
features,
pipeline_config.train_input_path,
mode=Mode.TRAIN,
checkpoint_state=dataloader_state,
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that the state actually reaches the workers, note it is re-applied on every epoch, not just the resumed one: nothing ever clears _checkpoint_state, and each epoch's iter(train_dataloader) re-enters to_batchescalc_slice_intervals(checkpoint_state=...) (parquet_dataset.py:252-261). For num_epochs > 1 + --continue_train:

  1. After a mid-epoch resume, every subsequent epoch reads only the unconsumed tail instead of a full pass.
  2. dataloader_state accumulates across the whole run with no per-pass reset (main.py:462), so a checkpoint saved after one full pass marks the dataset fully consumed — resuming from it makes calc_remaining_intervals return [] and every epoch yields zero batches (with num_steps, the loop spins on immediate StopIteration and i_step never advances).

Pre-PR this was latent (the state never reached forked workers at all); this PR activates it in the default multi-worker config. One subtlety if you fix it with consume-once semantics: clearing on entry to to_batches would break the at-least-once re-seek this PR relies on (the eager iter() in create_dataloader plus the train loop's iter() are two resets that must both see the state) — clearing after the generator is fully exhausted (end of to_batches; a closed prefetch generator skips it) gives "finish the interrupted pass, then full passes". Alternatively, if resume is intentionally single-pass/streaming-only (kafka/odps), an explicit guard or a docstring caveat here and on create_dataloader would prevent silent data loss for multi-epoch configs.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 5389e7f with the consume-once-on-exhaustion approach: BaseDataset.__iter__ clears the reader state after the wrapped to_batches generator exhausts normally (generator close() skips the clear, preserving the mid-pass re-seek that the eager iter() + train-loop iter() double reset relies on). Resumed pass finishes the tail, later epochs are full passes, and resuming from a fully-consumed checkpoint degrades to one empty pass instead of spinning. Verified the clear is safe for all readers (OdpsReader.load_state_dict guards if state:; Kafka's generator never exhausts, so streaming is unchanged) and covered by a new second-pass-reads-full-dataset assertion in the test.

Comment thread tzrec/datasets/parquet_dataset_test.py Outdated
Comment on lines +277 to +280
for key, new_offset in batch2.checkpoint_info.items():
if key in checkpoint_state:
# without the pre-fork state, workers replay from row 0
self.assertGreater(new_offset, checkpoint_state[key])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This guard makes the green path vacuous: when the fix works, resumed source keys are f"{input_path}:{consumed+1}" (calc_remaining_intervals starts remaining intervals at consumed + 1), which can never equal a saved key f"{input_path}:{start}" — so key in checkpoint_state is never true and the loop asserts nothing on all 4 batches. The test does still catch the targeted regression (without pre-fork state, workers emit the original keys :0/:10000 with offsets ≤ the saved ones, failing assertGreater), but adjacent regressions (off-by-N resume position, key-format drift) would pass silently. Asserting the affirmative invariant keeps the detection power and makes the pass path meaningful:

resumed_starts = set()
for _ in range(4):
    batch2 = next(iterator2)
    self.assertTrue(set(batch2.checkpoint_info).isdisjoint(checkpoint_state))
    resumed_starts |= {int(k.rsplit(":", 1)[1]) for k in batch2.checkpoint_info}
self.assertEqual(resumed_starts, {v + 1 for v in checkpoint_state.values()})

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 51f6cf4 — phase 2 now iterates to exhaustion and asserts the affirmative invariants: key disjointness, presence of every {path}:{consumed+1} resumed start, and row-count conservation against an empirically measured baseline (robust to slicing details). Re-verified detection power: under the old set-after-create pattern all three assertions fail (full 20000-row replay vs expected 19488).

Comment thread tzrec/datasets/parquet_dataset_test.py Outdated
Comment on lines +250 to +251
# 20000 rows at max_rows_per_file=5000 -> 4 files, so
# data_config.num_workers=2 survives the num_files clamp.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment justifies the fixture size with a constraint that doesn't apply: ParquetReader.num_files() returns None when rebalance=True (the default, and ParquetDataset doesn't override it), so the num_files clamp in create_dataloader never fires for ParquetDataset. ~2k rows would exercise the same paths faster; either shrink the fixture or fix the comment.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 51f6cf4 — comment now justifies the fixture by rebalanced multi-worker intervals + a remaining tail; kept the 20000 rows since the reworked assertions iterate full passes.

@github-actions

Copy link
Copy Markdown

Review summary

The fix is correct and well-motivated: applying checkpoint_state before the eager iter(dataloader) is the only way state can reach persistent workers, and the root-cause analysis in the description checks out. Things I verified:

  • Hoisting is safe. CheckpointManager.__init__ is side-effect-free (no collectives, no filesystem writes, no threads — the prune daemon and gloo group are lazy), and restore_dataloader_state is a rank-local file read, so moving them above create_dataloader changes no cross-rank ordering. The continue-train RuntimeError now fires before workers fork — a fail-fast improvement.
  • No other call site has the bug. dataset.load_state_dict outside tests now exists only inside create_dataloader; evaluate/predict/export are one-shot jobs with no resume.
  • The new regression test has real detection power for the targeted bug (replay-from-0 reproduces the saved keys with smaller offsets and fails assertGreater).

Two issues worth addressing before merge (details in inline comments):

  1. Multi-epoch resume semantics (main.py): the fork-time state is re-applied on every epoch's iter(), and dataloader_state accumulates with no per-pass reset — so after a resume, later epochs read only the unconsumed tail, and resuming from a post-full-pass checkpoint yields zero batches (or an idle spin with num_steps). Latent pre-PR; activated now that state reaches workers.
  2. Shared-dict aliasing (dataset.py:793): the reader stores the caller's dict by reference while the train loop mutates it per step — corrupts epoch 2+ when num_workers=0, and races with Kafka's cross-thread on_assign. A one-line copy fixes it.

Minor: the pre-existing comment at main.py:611 ("Restore dataloader state if continuing training") now describes code that moved to L624-630 — it sits above the latest-checkpoint resolution and should be reworded. Optional hardening: ranks read dataloader_state.json independently; now that the state affects per-rank intervals, a rank-0 read + broadcast would rule out divergent epoch lengths (and a collective-save hang) on skewed shared filesystems.

tiankongdeguiji and others added 2 commits June 11, 2026 14:53
…mpletes

Review follow-ups for the restore-ordering fix:
- create_dataloader copies the state dict: the reader keeps it by
  reference while the train loop mutates the caller's dict per step
  (aliases the live dict with num_workers=0; Kafka on_assign reads it
  cross-thread).
- BaseDataset.__iter__ clears the reader state after the wrapped
  to_batches generator is exhausted: a resumed pass finishes the
  remaining intervals, subsequent epochs read full passes, and resuming
  from a fully-consumed checkpoint degrades to one empty pass instead of
  yielding zero batches forever. Generator close() skips the clear, so
  mid-pass dataloader resets still re-seek from the restored state.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…e test

Resumed source keys ({path}:{consumed+1}) never collide with saved keys,
so the previous green path asserted nothing. Now: full-pass iteration
asserts key disjointness, presence of every resumed start, row-count
conservation vs an empirically measured baseline, and that a second
iter() after the completed pass reads the full dataset again
(state consumed on exhaustion).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@tiankongdeguiji

Copy link
Copy Markdown
Collaborator Author

Review addressed in 5389e7f + 51f6cf4 (details in the inline replies):

  • defensive copy of the state dict in create_dataloader
  • consume-once semantics: reader state cleared when a pass completes (normal generator exhaustion only), so resumed runs finish the interrupted pass and later epochs read full passes; resuming from a fully-consumed checkpoint degrades to one empty pass
  • test asserts affirmative invariants (disjoint keys, resumed starts at consumed+1, row conservation, second pass reads full dataset) — re-verified it fails on the old set-after-create pattern
  • stale # Restore dataloader state if continuing training comment reworded

Deferred: rank-0 read + broadcast of dataloader_state.json. The file is written by rank 0 at a previous checkpoint, long before restore, so stale-read divergence needs an unusually skewed shared FS; adding a collective here also couples dataloader construction to PG state. Can follow up separately if we see divergent epoch lengths in practice.

The reader-side clear left the parent's dataloader_state dict intact;
since update_dataloder_state max-merges per key, every save during the
next pass (or between passes) wrote back the old fully-consumed
positions, so crash-resume cycles never recorded forward progress.
Clear the dict at the train loop's StopIteration: epoch-boundary and
final checkpoints now store an empty state (resume = fresh full pass,
no empty pass), and mid-pass checkpoints of later epochs record real
positions within that pass.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@tiankongdeguiji

Copy link
Copy Markdown
Collaborator Author

1e472f6 closes a gap in the consume-once fix: the reader-side clear left the parent's dataloader_state intact, and since update_dataloder_state max-merges per key, any checkpoint taken during the next pass (or at the pass boundary) wrote back the old fully-consumed positions — crash-resume cycles would never record forward progress. The train loop now clears the dict at StopIteration, so epoch-boundary/final checkpoints store an empty state (resume = fresh full pass, no empty pass) and mid-pass checkpoints of later epochs record real positions.

Record __epochs_completed__ in dataloader_state at every pass boundary
(where the per-pass positions are cleared) and resume num_epochs jobs
with epoch_iter = range(completed, num_epochs), so a job that dies
mid-pass-N trains exactly the remaining budget instead of num_epochs
fresh passes. skip_steps already gives exact continuation for num_steps
jobs; this closes the epoch-mode gap. Checkpoints from fine-tune sources
keep their data positions but drop the counter -- the epoch budget
belongs to the job, not the warm-start weights. Reserved keys carry no
':' so per-source consumers skip them; old checkpoints without the key
resume with the previous behavior.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@tiankongdeguiji

Copy link
Copy Markdown
Collaborator Author

faa42a1 adds the epoch-budget follow-up: __epochs_completed__ is persisted in dataloader_state at each pass boundary and num_epochs jobs resume from range(completed, num_epochs). Verified end-to-end on mock data: a 1-epoch run records {"__epochs_completed__": 1}; resuming the same model_dir with num_epochs: 2 logs "resume training after 1 completed epochs" and trains exactly one more pass (steps 8-15, final counter 2); fine-tuning from that checkpoint into a fresh model_dir ignores the counter and trains the full budget from step 0. Old checkpoints without the key keep the previous behavior; the key has no ':' so all per-source state consumers skip it (verified in calc_remaining_intervals, kafka on_assign, ODPS _restore_sessions).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label Jun 11, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Jun 11, 2026
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Comment thread tzrec/main.py
Comment on lines +350 to +355
epochs_completed = dataloader_state.get(checkpoint_util.EPOCHS_COMPLETED, 0)
epoch_iter = (
range(min(epochs_completed, train_config.num_epochs), train_config.num_epochs)
if use_epoch
else itertools.count(0, 0)
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Major: when epochs_completed >= num_epochs (e.g. a scheduler auto-restarts a job that already finished, with --continue_train), epoch_iter is empty and the loop body never runs — but the model/optimizer restore lives inside the epoch loop (if i_step == 0 and ckpt_path is not None, L446). The post-loop maybe_save(i_step=0, ..., final=True) still passes the dedupe (0 != _last_ckpt_step == -1) and writes model.ckpt-0 from freshly-initialized weights into a model_dir full of real checkpoints, then run_eval appends metrics evaluated on random weights, and prune() may evict good checkpoints.

Suggest an explicit early return (with a log line, before set_save_policy/any save) when use_epoch and epochs_completed >= train_config.num_epochs. Note the min(...) clamp is dead code — range(7, 5) is already empty — and only masks this condition.

Minor, same line: the value comes straight from json.load with no validation; a non-int ("2", 2.0, hand-edited file) makes min()/range() raise an opaque TypeError. Worth coercing/validating in restore_dataloader_state.

Comment thread tzrec/main.py
Comment on lines +488 to +492
# pass completed: later saves should record positions
# within the next pass, on top of the completed-pass count.
epochs_completed += 1
dataloader_state.clear()
dataloader_state[checkpoint_util.EPOCHS_COMPLETED] = epochs_completed

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Major: the bumped state can fail to persist due to the per-step dedupe. If save_checkpoints_steps divides the pass length, the step-triggered save fires at the pass's final step s (with fully-consumed offsets); StopIteration then hits at s+1, i_step -= 1 brings it back to s, and both the epoch-boundary save (L511) and the final save (L541) return False on step == _last_ckpt_step in maybe_save. The cleared state + EPOCHS_COMPLETED bump is never written: the job's last checkpoint holds fully-consumed offsets without the bump.

Resuming such a checkpoint (or any legacy pre-1.2.19 checkpoint saved at pass end) makes the resumed pass empty, and the optimizer-warmup peek peek_batch = next(train_iterator) (L454) raises an uncaught StopIteration — it sits outside the try block. So the PR description's "legacy fully-consumed state yields one empty pass and then proceeds normally" only holds on the ignore_restore_optimizer path; the default --continue_train path crashes.

Suggested fixes: guard the peek (fall through to this handled branch on StopIteration), and make the epoch-boundary bookkeeping immune to the step dedupe (e.g. allow rewriting dataloader_state.json of the existing model.ckpt-s when the state changed).

Comment thread tzrec/datasets/dataset.py
Comment on lines +309 to +312
# Normal exhaustion = the resumed pass is complete; clear the state so
# subsequent epochs do full passes. close()/GeneratorExit skips this
# line, so dataloader resets mid-pass still re-seek from the state.
self._reader.load_state_dict(None)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edge case: the eager iter(dataloader) in create_dataloader makes persistent workers prefetch immediately (~prefetch_factor batches each). If a worker's remaining resumed interval fits inside that prefetch window — exactly the situation for checkpoints taken near pass end — its generator exhausts normally during the eager iter and this line clears the state before training starts. The real first epoch's iter(train_dataloader) then resets persistent workers, and that worker silently re-reads its full share of the pass (duplicated data, uneven batch counts across ranks).

Consider deferring the clear, e.g. set a flag here and drop the state at the start of the next __iter__ instead of immediately on exhaustion.

@github-actions

Copy link
Copy Markdown

Review summary

The core fix is correct and well-executed: applying checkpoint_state inside create_dataloader before the eager iter() forks persistent workers is the right place, and the new parquet regression test is deterministic and genuinely catches the old set-after-create bug. Verified details that hold up: reserved no-: keys are skipped by all three per-source consumers (parquet calc_remaining_intervals, Kafka exact-key lookup, ODPS _restore_sessions); load_state_dict(None) is safe for every reader; the dict(checkpoint_state) copy at dataset.py:797 is load-bearing (it shields the num_workers=0 in-process dataset from the trainer's later dataloader_state.clear()) — worth a one-line comment so it doesn't get "simplified" away.

Three correctness issues are posted inline, all in the new epoch-resume logic (empty epoch range resumes save an unrestored model; the step-save dedupe can drop the EPOCHS_COMPLETED bump, whose fully-consumed checkpoint then crashes the unguarded next(train_iterator) peek on resume; eager-prefetch exhaustion can clear the resumed state before training starts).

Test coverage: the dataset-layer half is well tested, but the main.py orchestration half ships untested — no test in the repo exercises --continue_train, so a regression re-ordering the restore below create_dataloader would pass CI. An integration test in rank_integration_test.py (train with num_epochs=2 + small save_checkpoints_steps, kill/rerun same model_dir with --continue_train, assert __epochs_completed__ round-trips) would protect the actual bug being fixed. The fine-tune EPOCHS_COMPLETED pop is also unreachable by existing tests (test_multi_tower_din_fg_encoded_finetune doesn't pass --continue_train).

Docs (minor): the resume semantics are user-visible but docs/source/feature/data.md (断点续训) and docs/source/usage/train.md (--continue_train) aren't updated — a brief note on the epoch-budget resume and the fine-tune behavior (positions kept, budget reset) would help. The load_state_dict docstrings could also mention that None clears state and that reserved keys are ignored, since the PR now relies on both.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant