diff --git a/torchdata/nodes/prefetch.py b/torchdata/nodes/prefetch.py index 29f64b083..7ca4059d1 100644 --- a/torchdata/nodes/prefetch.py +++ b/torchdata/nodes/prefetch.py @@ -34,9 +34,11 @@ def __init__(self, source: BaseNode[T], prefetch_factor: int, snapshot_frequency def reset(self, initial_state: Optional[Dict[str, Any]] = None): super().reset(initial_state) - if self._it is not None: + if hasattr(self, "_it") and self._it is not None: self._it._shutdown() del self._it + + # This can throw, so _it may be deleted self._it = _SingleThreadedMapper( source=self.source, prefetch_factor=self.prefetch_factor,