diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index f91077637..647f8d433 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -17,7 +17,7 @@ from collections.abc import Sequence import copy import functools -from typing import Any, TypeVar +from typing import Any, TypeVar, cast import weakref from absl import logging @@ -28,7 +28,6 @@ from grain._src.python.dataset import stats from grain._src.python.dataset.transformations import prefetch - T = TypeVar("T") @@ -92,6 +91,7 @@ def __init__( ] = [None] * self._cycle_length # Future states used for elastic iterators self._future_states: dict[int, Any] = {} + self._iterator_start_states: dict[int, Any] = {} @stats.record_next_duration_if_output @stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING) @@ -208,6 +208,41 @@ def get_state(self): } return state + def get_shard_states(self) -> Sequence[Any]: + state = self.get_state() + indices = state["iterators_in_use_indices"] + states = state["iterators_in_use_states"] + exhausted = state["exhausted"] + next_index_in_datasets = state["next_index_in_datasets"] + + shard_states = [None] * len(self._datasets) + + for i in range(len(self._datasets)): # pylint: disable=protected-access + # If the current shard index is greater than or equal to the next + # index in datasets, it means the current shard has not yet started + # to be iterated on. + if i >= next_index_in_datasets: + shard_states[i] = { + "exhausted": 0, + "state": self._get_iterator_start_state(i), # pylint: disable=protected-access + } + elif i not in indices: + # These shards are exhausted but should still create a state to maintain + # static state spec shapes. 
+ shard_states[i] = { + "exhausted": 1, + "state": self._get_iterator_start_state(i), # pylint: disable=protected-access + } + + for index, ds_state, is_exhausted in zip(indices, states, exhausted): + # These shards are currently being iterated on. + shard_states[index] = { + "exhausted": is_exhausted, + "state": ds_state, + } + + return shard_states + def set_state(self, state): exhausted = state["exhausted"] for index_in_cycle, (index_in_datasets, it_state) in enumerate( @@ -252,7 +287,45 @@ def set_state(self, state): self._next_index_in_cycle = state["next_index_in_cycle"] self._next_index_in_datasets = state["next_index_in_datasets"] self._iterators_in_use_indices = state["iterators_in_use_indices"] - self._future_states = state.get("future_states", {}) + self._future_states = cast(dict[int, Any], state.get("future_states", {})) + + def set_shard_states(self, state: Sequence[Any]) -> None: + active_states = [] + for ind, shard_state in enumerate(state): + if not shard_state["exhausted"]: + active_states.append((ind, shard_state["state"])) + + iterators_in_use_indices = [] + iterators_in_use_states = [] + exhausted = [] + count = 0 + future_states = {} + for ind, s in active_states: + if count < self._cycle_length: + iterators_in_use_indices.append(ind) + iterators_in_use_states.append(s) + exhausted.append(0) + count += 1 + elif s: + future_states[ind] = s + next_index_in_datasets = ( + max(iterators_in_use_indices) + 1 if iterators_in_use_indices else 0 + ) + while count < self._cycle_length: + iterators_in_use_indices.append(next_index_in_datasets) + iterators_in_use_states.append(None) + exhausted.append(1) + count += 1 + + new_state = { + "next_index_in_cycle": 0, + "next_index_in_datasets": next_index_in_datasets, + "iterators_in_use_indices": iterators_in_use_indices, + "iterators_in_use_states": iterators_in_use_states, + "exhausted": exhausted, + "future_states": future_states, + } + self.set_state(new_state) def _get_next_index(self) -> int: if 
len(self._datasets) == 1: @@ -334,14 +407,15 @@ def __str__(self) -> str: ) def _get_iterator_start_state(self, index: int) -> dict[str, Any]: - it = _add_prefetch_and_make_iterator( - self._datasets[index], - weakref.ref(self), - start_prefetch=False, - ) - state = it.get_state() - del it - return state + if index not in self._iterator_start_states: + it = _add_prefetch_and_make_iterator( + self._datasets[index], + weakref.ref(self), + start_prefetch=False, + ) + self._iterator_start_states[index] = it.get_state() + it.close() + return self._iterator_start_states[index] def _add_prefetch_and_make_iterator( diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py index 29cae4704..f5744c3b9 100644 --- a/grain/_src/python/dataset/transformations/interleave_test.py +++ b/grain/_src/python/dataset/transformations/interleave_test.py @@ -24,7 +24,6 @@ from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint import numpy as np - _INTERLEAVE_TEST_CASES = ( dict( testcase_name="cycle_length_1", @@ -339,6 +338,35 @@ def test_future_states(self): with self.assertRaises(StopIteration): next(ds_iter) + def test_get_set_shard_state(self): + ds1 = dataset.MapDataset.source([1, 2, 3]).to_iter_dataset() + ds2 = dataset.MapDataset.source([4, 5, 6]).to_iter_dataset() + ds = self._create_dataset([ds1, ds2], cycle_length=2) + ds = self._maybe_wrap_ds(ds) + it = cast(interleave.InterleaveDatasetIterator, ds.__iter__()) + + # Consume some elements to advance state. + self.assertEqual(next(it), 1) + self.assertEqual(next(it), 4) + + # Get the shard state. + shard_state = it.get_shard_states() + + # Check correctness of the individual shard statuses in scenarios without + # prefetch wrapping. 
+ expected_shard_state = [ + {"exhausted": 0, "state": {"next_index": 1}}, + {"exhausted": 0, "state": {"next_index": 1}}, + ] + self.assertEqual(shard_state, expected_shard_state) + + # Create a new iterator and restore state. + it2 = cast(interleave.InterleaveDatasetIterator, ds.__iter__()) + it2.set_shard_states(shard_state) + + # Verify it continues from the correct position. + self.assertSequenceEqual(list(it2), [2, 5, 3, 6]) + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 651b09275..b329459de 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -88,6 +88,32 @@ def set_slice(self, sl: slice, sequential_slice: bool = False) -> None: ... +@typing.runtime_checkable +class SupportsSlicedStateManagement(Protocol): + """Iterators that support setting a sliced state. + + This protocol is mainly used to support elastic resizing of iterators. + """ + + def get_shard_states(self) -> Sequence[Any]: + """Returns the states of all shards managed by this iterator. + + This is used for elastic resizing to capture the current progress of each + shard. + """ + ... + + def set_shard_states(self, state: Sequence[Any]): + """Sets the states of all shards managed by this iterator. + + This is used for elastic resizing to restore the progress of each shard. + + Args: + state: A sequence of states, one for each shard. + """ + ... + + class PrefetchIterDataset(dataset.IterDataset[T]): """Iterable dataset that uses a thread pool for prefetching."""