[bugfix] restore dataloader state before create_dataloader forks workers #544
tiankongdeguiji wants to merge 9 commits into
create_dataloader eagerly starts persistent workers via iter(dataloader); workers keep a fork-time copy of the dataset, so load_state_dict applied to the returned dataloader's dataset afterwards never reaches them. Kafka consumers therefore ignored the restored offsets and re-seeked via start.timestamp.ms / committed offsets on resume. Hoist CheckpointManager construction, ckpt_path resolution and restore_dataloader_state above create_dataloader, and pass the state via a new create_dataloader(checkpoint_state=...) argument applied to the dataset before the DataLoader is built. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
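The ordering problem described in this commit can be reproduced without PyTorch. The following is a minimal sketch, not the project's code: ToyDataset, start_worker and the "consumed" key are hypothetical, and copy.deepcopy stands in for the memory snapshot a forked worker receives.

```python
import copy


class ToyDataset:
    """Hypothetical stand-in for a dataset whose reader resumes from a state."""

    def __init__(self, num_rows):
        self.num_rows = num_rows
        self._checkpoint_state = None

    def load_state_dict(self, state):
        self._checkpoint_state = state

    def first_row(self):
        # Resume one past the last consumed row, or from row 0 without state.
        if not self._checkpoint_state:
            return 0
        return self._checkpoint_state["consumed"] + 1


def start_worker(dataset):
    # Assumption: deepcopy models the snapshot a forked worker gets of the
    # parent's dataset; later changes in the parent are invisible to it.
    return copy.deepcopy(dataset)


# Old ordering: the worker is started first, the state is restored afterwards.
late = ToyDataset(num_rows=100)
late_worker = start_worker(late)
late.load_state_dict({"consumed": 41})

# New ordering: the state is applied before the worker exists.
early = ToyDataset(num_rows=100)
early.load_state_dict({"consumed": 41})
early_worker = start_worker(early)

print(late_worker.first_row())   # worker started before the restore
print(early_worker.first_row())  # worker started after the restore
```

In this toy model only the worker created after load_state_dict resumes past the consumed rows, which is the behavior the new checkpoint_state argument is meant to guarantee.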
Regression test for the restore ordering: resume state passed through create_dataloader(checkpoint_state=...) is honored by persistent workers (num_workers=2); with the old set-after-create pattern workers replay from row 0. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
if checkpoint_state:
    dataset.load_state_dict(checkpoint_state)
BaseReader.load_state_dict stores this dict by reference (dataset.py:555), and train_and_evaluate passes the very same object on to _train_and_evaluate, which mutates it in place on every step via update_dataloder_state (main.py:462) and at save time. With num_workers >= 1 the fork-time copy makes this accidentally safe, but with data_config.num_workers < 1 the DataLoader runs in-process, so the reader aliases the live, ever-growing dict — the next iter(train_dataloader) (epoch 2) recomputes intervals from offsets accumulated during this run and skips that data. KafkaReader's on_assign rebalance callback also reads _checkpoint_state at arbitrary later times, cross-thread. A defensive copy breaks the aliasing:
-if checkpoint_state:
-    dataset.load_state_dict(checkpoint_state)
+if checkpoint_state:
+    dataset.load_state_dict(dict(checkpoint_state))
Fixed in 5389e7f — create_dataloader now passes dict(checkpoint_state) to load_state_dict, decoupling the reader from the training loop's accumulating dict.
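The aliasing discussed in this thread can be illustrated in a few lines. This is a sketch under assumptions: AliasingReader and the source key are hypothetical, and the only behavior borrowed from the review is that load_state_dict stores the dict by reference.

```python
class AliasingReader:
    """Hypothetical reader that keeps the state dict by reference."""

    def __init__(self):
        self._checkpoint_state = None

    def load_state_dict(self, state):
        self._checkpoint_state = state

    def resume_row(self, key):
        return self._checkpoint_state[key] + 1


restored = {"data.parquet:0": 99}

aliased = AliasingReader()
aliased.load_state_dict(restored)        # shares the caller's dict

copied = AliasingReader()
copied.load_state_dict(dict(restored))   # shallow copy, as in the fix

# The train loop keeps updating its own dict while training progresses.
restored["data.parquet:0"] = 499

print(aliased.resume_row("data.parquet:0"))  # follows the live dict
print(copied.resume_row("data.parquet:0"))   # still the restored position
```

A shallow dict(...) copy is enough here because the values are plain integers; nested mutable values would need a deep copy.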
# Build dataloader
train_dataloader = create_dataloader(
    data_config,
    features,
    pipeline_config.train_input_path,
    mode=Mode.TRAIN,
    checkpoint_state=dataloader_state,
)
Now that the state actually reaches the workers, note it is re-applied on every epoch, not just the resumed one: nothing ever clears _checkpoint_state, and each epoch's iter(train_dataloader) re-enters to_batches → calc_slice_intervals(checkpoint_state=...) (parquet_dataset.py:252-261). For num_epochs > 1 + --continue_train:
- After a mid-epoch resume, every subsequent epoch reads only the unconsumed tail instead of a full pass.
- dataloader_state accumulates across the whole run with no per-pass reset (main.py:462), so a checkpoint saved after one full pass marks the dataset fully consumed. Resuming from it makes calc_remaining_intervals return [] and every epoch yields zero batches (with num_steps, the loop spins on immediate StopIteration and i_step never advances).
Pre-PR this was latent (the state never reached forked workers at all); this PR activates it in the default multi-worker config. One subtlety if you fix it with consume-once semantics: clearing on entry to to_batches would break the at-least-once re-seek this PR relies on (the eager iter() in create_dataloader plus the train loop's iter() are two resets that must both see the state) — clearing after the generator is fully exhausted (end of to_batches; a closed prefetch generator skips it) gives "finish the interrupted pass, then full passes". Alternatively, if resume is intentionally single-pass/streaming-only (kafka/odps), an explicit guard or a docstring caveat here and on create_dataloader would prevent silent data loss for multi-epoch configs.
Fixed in 5389e7f with the consume-once-on-exhaustion approach: BaseDataset.__iter__ clears the reader state after the wrapped to_batches generator exhausts normally (generator close() skips the clear, preserving the mid-pass re-seek that the eager iter() + train-loop iter() double reset relies on). Resumed pass finishes the tail, later epochs are full passes, and resuming from a fully-consumed checkpoint degrades to one empty pass instead of spinning. Verified the clear is safe for all readers (OdpsReader.load_state_dict guards if state:; Kafka's generator never exhausts, so streaming is unchanged) and covered by a new second-pass-reads-full-dataset assertion in the test.
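The consume-once-on-exhaustion behavior relies on how Python generators handle close(). A minimal sketch, assuming a hypothetical TailReader and ConsumeOnceDataset rather than the real BaseReader and BaseDataset:

```python
class TailReader:
    """Hypothetical reader that yields rows after the restored position."""

    def __init__(self, num_rows):
        self.num_rows = num_rows
        self._state = None

    def load_state_dict(self, state):
        self._state = state

    def to_batches(self):
        # The state is read when the generator body starts (first next()).
        start = self._state["consumed"] + 1 if self._state else 0
        yield from range(start, self.num_rows)


class ConsumeOnceDataset:
    """Hypothetical dataset that clears the state after a completed pass."""

    def __init__(self, reader):
        self._reader = reader

    def __iter__(self):
        yield from self._reader.to_batches()
        # Reached only on normal exhaustion. close() raises GeneratorExit at
        # the yield above, so an abandoned pass leaves the state in place.
        self._reader.load_state_dict(None)


reader = TailReader(num_rows=10)
reader.load_state_dict({"consumed": 6})
dataset = ConsumeOnceDataset(reader)

abandoned = iter(dataset)
first = next(abandoned)   # row 7, then the iterator is reset mid-pass
abandoned.close()

resumed_pass = list(dataset)  # finishes the interrupted pass
next_pass = list(dataset)     # full pass, the state was consumed

print(first, resumed_pass, next_pass)
```

The abandoned iterator models the eager iter() whose prefetched batches are discarded; the pass that follows still starts from the restored position.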
for key, new_offset in batch2.checkpoint_info.items():
    if key in checkpoint_state:
        # without the pre-fork state, workers replay from row 0
        self.assertGreater(new_offset, checkpoint_state[key])
This guard makes the green path vacuous: when the fix works, resumed source keys are f"{input_path}:{consumed+1}" (calc_remaining_intervals starts remaining intervals at consumed + 1), which can never equal a saved key f"{input_path}:{start}" — so key in checkpoint_state is never true and the loop asserts nothing on all 4 batches. The test does still catch the targeted regression (without pre-fork state, workers emit the original keys :0/:10000 with offsets ≤ the saved ones, failing assertGreater), but adjacent regressions (off-by-N resume position, key-format drift) would pass silently. Asserting the affirmative invariant keeps the detection power and makes the pass path meaningful:
resumed_starts = set()
for _ in range(4):
    batch2 = next(iterator2)
    self.assertTrue(set(batch2.checkpoint_info).isdisjoint(checkpoint_state))
    resumed_starts |= {int(k.rsplit(":", 1)[1]) for k in batch2.checkpoint_info}
self.assertEqual(resumed_starts, {v + 1 for v in checkpoint_state.values()})
Fixed in 51f6cf4 — phase 2 now iterates to exhaustion and asserts the affirmative invariants: key disjointness, presence of every {path}:{consumed+1} resumed start, and row-count conservation against an empirically measured baseline (robust to slicing details). Re-verified detection power: under the old set-after-create pattern all three assertions fail (full 20000-row replay vs expected 19488).
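Why the guarded loop asserted nothing on the green path can be checked with toy data. The keys and offsets below are hypothetical; only the key formats ("{path}:{start}" for saved keys, "{path}:{consumed+1}" for resumed keys) come from the review.

```python
# Hypothetical saved state: source key "{path}:{start}" -> last consumed row.
checkpoint_state = {"part.parquet:0": 255, "part.parquet:10000": 10255}

# Hypothetical resumed batches, keyed by "{path}:{consumed + 1}".
resumed_batches = [
    {"part.parquet:256": 511},
    {"part.parquet:10256": 10511},
]


def guarded_checks(batches, saved):
    """Count how many assertions the guarded loop actually executes."""
    executed = 0
    for info in batches:
        for key, new_offset in info.items():
            if key in saved:
                assert new_offset > saved[key]
                executed += 1
    return executed


def resumed_starts(batches):
    """Collect the start offsets encoded in the resumed source keys."""
    return {int(key.rsplit(":", 1)[1]) for info in batches for key in info}


expected_starts = {v + 1 for v in checkpoint_state.values()}

print(guarded_checks(resumed_batches, checkpoint_state))
print(resumed_starts(resumed_batches) == expected_starts)
```

With these inputs the guarded loop runs zero assertions, while the affirmative check on resumed starts compares real values.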
# 20000 rows at max_rows_per_file=5000 -> 4 files, so
# data_config.num_workers=2 survives the num_files clamp.
This comment justifies the fixture size with a constraint that doesn't apply: ParquetReader.num_files() returns None when rebalance=True (the default, and ParquetDataset doesn't override it), so the num_files clamp in create_dataloader never fires for ParquetDataset. ~2k rows would exercise the same paths faster; either shrink the fixture or fix the comment.
Fixed in 51f6cf4 — comment now justifies the fixture by rebalanced multi-worker intervals + a remaining tail; kept the 20000 rows since the reworked assertions iterate full passes.
Review summary
The fix is correct and well-motivated: applying
Two issues worth addressing before merge (details in inline comments):
Minor: the pre-existing comment at
…mpletes
Review follow-ups for the restore-ordering fix:
- create_dataloader copies the state dict: the reader keeps it by reference while the train loop mutates the caller's dict per step (aliases the live dict with num_workers=0; Kafka on_assign reads it cross-thread).
- BaseDataset.__iter__ clears the reader state after the wrapped to_batches generator is exhausted: a resumed pass finishes the remaining intervals, subsequent epochs read full passes, and resuming from a fully-consumed checkpoint degrades to one empty pass instead of yielding zero batches forever. Generator close() skips the clear, so mid-pass dataloader resets still re-seek from the restored state.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…e test
Resumed source keys ({path}:{consumed+1}) never collide with saved keys,
so the previous green path asserted nothing. Now: full-pass iteration
asserts key disjointness, presence of every resumed start, row-count
conservation vs an empirically measured baseline, and that a second
iter() after the completed pass reads the full dataset again
(state consumed on exhaustion).
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Review addressed in 5389e7f + 51f6cf4 (details in the inline replies):
Deferred: rank-0 read + broadcast of
The reader-side clear left the parent's dataloader_state dict intact; since update_dataloder_state max-merges per key, every save during the next pass (or between passes) wrote back the old fully-consumed positions, so crash-resume cycles never recorded forward progress. Clear the dict at the train loop's StopIteration: epoch-boundary and final checkpoints now store an empty state (resume = fresh full pass, no empty pass), and mid-pass checkpoints of later epochs record real positions within that pass. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
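The effect of max-merging into a stale dict can be shown with a small sketch. max_merge is a hypothetical stand-in for the per-key merge this commit attributes to update_dataloder_state; the keys and offsets are made up.

```python
def max_merge(state, batch_info):
    """Keep the largest offset seen per source key (hypothetical stand-in)."""
    for key, offset in batch_info.items():
        state[key] = max(state.get(key, offset), offset)


# State at the end of a completed pass: the source is fully consumed.
stale = {"part.parquet:0": 9999}

# The same state after the train loop clears it at StopIteration.
cleared = {}

# The first batch of the next pass reports an early position again.
next_pass_batch = {"part.parquet:0": 127}
max_merge(stale, next_pass_batch)
max_merge(cleared, next_pass_batch)

print(stale)    # still pinned to the previous pass
print(cleared)  # records progress within the new pass
```

Without the clear, every later save would keep writing the old fully-consumed offset, which matches the lack of forward progress described in the commit message.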
1e472f6 closes a gap in the consume-once fix: the reader-side clear left the parent's
Record __epochs_completed__ in dataloader_state at every pass boundary (where the per-pass positions are cleared) and resume num_epochs jobs with epoch_iter = range(completed, num_epochs), so a job that dies mid-pass-N trains exactly the remaining budget instead of num_epochs fresh passes. skip_steps already gives exact continuation for num_steps jobs; this closes the epoch-mode gap. Checkpoints from fine-tune sources keep their data positions but drop the counter -- the epoch budget belongs to the job, not the warm-start weights. Reserved keys carry no ':' so per-source consumers skip them; old checkpoints without the key resume with the previous behavior. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
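The pass-boundary bookkeeping from this commit can be sketched as follows. The helper names (finish_pass, source_positions, remaining_epochs) are hypothetical; the reserved key name and the range(completed, num_epochs) resume rule are taken from the commit message.

```python
EPOCHS_COMPLETED = "__epochs_completed__"  # reserved key named in the commit


def finish_pass(state):
    """Drop per-pass positions and bump the completed-pass counter."""
    completed = state.get(EPOCHS_COMPLETED, 0) + 1
    state.clear()
    state[EPOCHS_COMPLETED] = completed


def source_positions(state):
    # Per-source keys look like "{path}:{start}"; reserved keys carry no ':'.
    return {k: v for k, v in state.items() if ":" in k}


def remaining_epochs(state, num_epochs):
    # Old checkpoints without the key fall back to a full epoch budget.
    return range(state.get(EPOCHS_COMPLETED, 0), num_epochs)


state = {"part.parquet:0": 9999}
finish_pass(state)               # pass 0 completed
finish_pass(state)               # pass 1 completed
state["part.parquet:0"] = 4000   # the job dies in the middle of pass 2

print(list(remaining_epochs(state, 5)))
print(source_positions(state))
print(list(remaining_epochs({}, 5)))
```

Because the reserved key has no ':', a consumer that only looks at per-source keys never sees the counter.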
faa42a1 adds the epoch-budget follow-up:
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
epochs_completed = dataloader_state.get(checkpoint_util.EPOCHS_COMPLETED, 0)
epoch_iter = (
    range(min(epochs_completed, train_config.num_epochs), train_config.num_epochs)
    if use_epoch
    else itertools.count(0, 0)
)
Major: when epochs_completed >= num_epochs (e.g. a scheduler auto-restarts a job that already finished, with --continue_train), epoch_iter is empty and the loop body never runs — but the model/optimizer restore lives inside the epoch loop (if i_step == 0 and ckpt_path is not None, L446). The post-loop maybe_save(i_step=0, ..., final=True) still passes the dedupe (0 != _last_ckpt_step == -1) and writes model.ckpt-0 from freshly-initialized weights into a model_dir full of real checkpoints, then run_eval appends metrics evaluated on random weights, and prune() may evict good checkpoints.
Suggest an explicit early return (with a log line, before set_save_policy/any save) when use_epoch and epochs_completed >= train_config.num_epochs. Note the min(...) clamp is dead code — range(7, 5) is already empty — and only masks this condition.
Minor, same line: the value comes straight from json.load with no validation; a non-int ("2", 2.0, hand-edited file) makes min()/range() raise an opaque TypeError. Worth coercing/validating in restore_dataloader_state.
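One possible shape for the two suggestions above (explicit early exit, validated counter) is sketched here. Both helpers are hypothetical and not part of the PR; the accepted input forms are an assumption about what a hand-edited JSON file might contain.

```python
import itertools


def coerce_epochs_completed(value):
    """Validate a counter loaded from JSON (hypothetical helper)."""
    if isinstance(value, bool):
        raise ValueError(f"invalid epochs_completed: {value!r}")
    if isinstance(value, int):
        result = value
    elif isinstance(value, float) and value.is_integer():
        result = int(value)
    elif isinstance(value, str):
        result = int(value)  # raises ValueError for non-numeric text
    else:
        raise ValueError(f"invalid epochs_completed: {value!r}")
    if result < 0:
        raise ValueError(f"invalid epochs_completed: {value!r}")
    return result


def plan_epochs(epochs_completed, num_epochs, use_epoch):
    """Return an epoch iterator, or None when no epoch is left to train."""
    if not use_epoch:
        return itertools.count(0, 0)  # step-bounded jobs loop until stopped
    completed = coerce_epochs_completed(epochs_completed)
    if completed >= num_epochs:
        return None  # the caller would log and return before any save
    return range(completed, num_epochs)


print(list(range(7, 5)))          # empty without any min() clamp
print(plan_epochs(7, 5, True))    # finished job: nothing to run
print(list(plan_epochs("2", 5, True)))
```

Returning None makes the finished-job case explicit for the caller, instead of letting an empty range fall through to the post-loop save.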
# pass completed: later saves should record positions
# within the next pass, on top of the completed-pass count.
epochs_completed += 1
dataloader_state.clear()
dataloader_state[checkpoint_util.EPOCHS_COMPLETED] = epochs_completed
Major: the bumped state can fail to persist due to the per-step dedupe. If save_checkpoints_steps divides the pass length, the step-triggered save fires at the pass's final step s (with fully-consumed offsets); StopIteration then hits at s+1, i_step -= 1 brings it back to s, and both the epoch-boundary save (L511) and the final save (L541) return False on step == _last_ckpt_step in maybe_save. The cleared state + EPOCHS_COMPLETED bump is never written: the job's last checkpoint holds fully-consumed offsets without the bump.
Resuming such a checkpoint (or any legacy pre-1.2.19 checkpoint saved at pass end) makes the resumed pass empty, and the optimizer-warmup peek peek_batch = next(train_iterator) (L454) raises an uncaught StopIteration — it sits outside the try block. So the PR description's "legacy fully-consumed state yields one empty pass and then proceeds normally" only holds on the ignore_restore_optimizer path; the default --continue_train path crashes.
Suggested fixes: guard the peek (fall through to this handled branch on StopIteration), and make the epoch-boundary bookkeeping immune to the step dedupe (e.g. allow rewriting dataloader_state.json of the existing model.ckpt-s when the state changed).
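The dedupe interaction and the two suggested fixes can be modeled with a toy. ToyCheckpointManager, the rewrite_state flag and peek are hypothetical and only imitate the behavior described in this comment; they are not the project's CheckpointManager API.

```python
class ToyCheckpointManager:
    """Toy model of a per-step save dedupe (not the real CheckpointManager)."""

    def __init__(self):
        self._last_ckpt_step = -1
        self.saved = {}  # step -> dataloader state written at that step

    def maybe_save(self, step, state, rewrite_state=False):
        if step == self._last_ckpt_step:
            if rewrite_state and self.saved[step] != state:
                self.saved[step] = dict(state)  # refresh only the state
                return True
            return False  # deduped: nothing is written
        self.saved[step] = dict(state)
        self._last_ckpt_step = step
        return True


def run_pass_end(manager, rewrite_state):
    state = {"part.parquet:0": 9999}   # fully consumed at the final step
    manager.maybe_save(100, state)     # step-triggered save at step 100
    state.clear()                      # pass-boundary bookkeeping
    state["__epochs_completed__"] = 1
    return manager.maybe_save(100, state, rewrite_state=rewrite_state)


_EMPTY = object()


def peek(iterator):
    """Return the next batch, or None if the pass is empty."""
    batch = next(iterator, _EMPTY)
    return None if batch is _EMPTY else batch


deduped = ToyCheckpointManager()
print(run_pass_end(deduped, rewrite_state=False), deduped.saved[100])

rewriting = ToyCheckpointManager()
print(run_pass_end(rewriting, rewrite_state=True), rewriting.saved[100])

print(peek(iter([])))  # empty resumed pass, no StopIteration escapes
```

In the deduped run the checkpoint keeps the fully-consumed offsets without the counter, which is the state that later leads to an empty resumed pass.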
# Normal exhaustion = the resumed pass is complete; clear the state so
# subsequent epochs do full passes. close()/GeneratorExit skips this
# line, so dataloader resets mid-pass still re-seek from the state.
self._reader.load_state_dict(None)
Edge case: the eager iter(dataloader) in create_dataloader makes persistent workers prefetch immediately (~prefetch_factor batches each). If a worker's remaining resumed interval fits inside that prefetch window — exactly the situation for checkpoints taken near pass end — its generator exhausts normally during the eager iter and this line clears the state before training starts. The real first epoch's iter(train_dataloader) then resets persistent workers, and that worker silently re-reads its full share of the pass (duplicated data, uneven batch counts across ranks).
Consider deferring the clear, e.g. set a flag here and drop the state at the start of the next __iter__ instead of immediately on exhaustion.
Review summary
The core fix is correct and well-executed: applying
Three correctness issues are posted inline, all in the new epoch-resume logic (empty epoch range resumes save an unrestored model; the step-save dedupe can drop the
Test coverage: the dataset-layer half is well tested, but the
Docs (minor): the resume semantics are user-visible but
Problem
train_and_evaluate restores the dataloader checkpoint state after create_dataloader, but create_dataloader ends with an eager iter(dataloader) (persistent_workers=True) that forks the DataLoader workers immediately. Workers keep a fork-time copy of the dataset with _checkpoint_state=None; the later load_state_dict mutates only the parent process, and persistent workers are never re-pickled (_ResumeIteration recreates the fetcher from the worker's own dataset copy). The restored state therefore never reaches the processes that actually read data.

Production symptom: on resume with --fine_tune_checkpoint + --continue_train, Kafka consumers ignored dataloader_state.json and re-seeked via start.timestamp.ms (days of data re-trained). The same ordering bug affects Parquet/ODPS resume.

The eager iter() itself must stay: forking workers later in the lifecycle (after CUDA-holding garbage cycles exist in the heap) makes worker GC abort with CUDA error: initialization error, which is the original reason it was added in v0.7.0.

Fix
- create_dataloader accepts an optional checkpoint_state and applies it to the dataset before the DataLoader is built, so the eager iter() forks workers that already carry the state.
- train_and_evaluate hoists CheckpointManager construction, ckpt-path resolution and restore_dataloader_state above create_dataloader (logic unchanged; the manager's event-time watermark seeding is preserved) and passes the state through the new argument.

With the state present at fork time, every dataloader reset (e.g. the train loop's iter()) re-seeks to the checkpointed offsets, so prefetched-then-discarded batches are re-read, which gives correct at-least-once semantics.

Tests
- test_create_dataloader_checkpoint_state_reaches_workers (parquet, num_workers=2): state passed through create_dataloader is honored by forked workers. Verified it catches the bug: with the old set-after-create pattern, workers replay from row 0.
- test_multi_tower_din_fg_encoded_finetune integration pass.

🤖 Generated with Claude Code
Resume semantics (added after review)
The restored state is consumed by the first completed pass: BaseDataset.__iter__ clears the reader state when its to_batches generator exhausts normally (generator close() keeps it, so mid-pass dataloader resets still re-seek). Consequences:
- dataloader_state is likewise reset at the pass boundary (train loop StopIteration), so checkpoints taken during later passes record positions within that pass instead of max-merging back the old fully-consumed state. Epoch-boundary and final checkpoints store an empty state, and resuming from them starts a fresh full pass directly (no empty pass).

The checkpoint also records the completed-pass count (reserved key __epochs_completed__, incremented at each pass boundary), and num_epochs jobs resume with range(completed, num_epochs), so a job that dies mid-pass-N trains exactly the remaining budget (skip_steps already handled num_steps jobs). Fine-tune checkpoints keep their data positions but drop the counter: the epoch budget belongs to the job, not the warm-start weights. Old checkpoints without the key resume with the previous behavior.