Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 85 additions & 11 deletions grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections.abc import Sequence
import copy
import functools
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
import weakref

from absl import logging
Expand All @@ -28,7 +28,6 @@
from grain._src.python.dataset import stats
from grain._src.python.dataset.transformations import prefetch


T = TypeVar("T")


Expand Down Expand Up @@ -92,6 +91,7 @@ def __init__(
] = [None] * self._cycle_length
# Future states used for elastic iterators
self._future_states: dict[int, Any] = {}
self._iterator_start_states: dict[int, Any] = {}

@stats.record_next_duration_if_output
@stats.trace_input_pipeline_next(stage_category=stats.IPL_CAT_PREPROCESSING)
Expand Down Expand Up @@ -208,6 +208,41 @@ def get_state(self):
}
return state

def get_shard_states(self) -> Sequence[Any]:
state = self.get_state()
indices = state["iterators_in_use_indices"]
states = state["iterators_in_use_states"]
exhausted = state["exhausted"]
next_index_in_datasets = state["next_index_in_datasets"]

shard_states = [None] * len(self._datasets)

for i in range(len(self._datasets)): # pylint: disable=protected-access
# If the current shard index is greater than or equal to the next
# index in datasets, it means the current shard has not yet started
# to be iterated on.
if i >= next_index_in_datasets:
shard_states[i] = {
"exhausted": 0,
"state": self._get_iterator_start_state(i), # pylint: disable=protected-access
}
elif i not in indices:
# These shards are exhausted but should still create a state to maintain
# static state spec shapes.
shard_states[i] = {
"exhausted": 1,
"state": self._get_iterator_start_state(i), # pylint: disable=protected-access
}

for index, ds_state, is_exhausted in zip(indices, states, exhausted):
# These shards are currently being iterated on.
shard_states[index] = {
"exhausted": is_exhausted,
"state": ds_state,
}

return shard_states

def set_state(self, state):
exhausted = state["exhausted"]
for index_in_cycle, (index_in_datasets, it_state) in enumerate(
Expand Down Expand Up @@ -252,7 +287,45 @@ def set_state(self, state):
self._next_index_in_cycle = state["next_index_in_cycle"]
self._next_index_in_datasets = state["next_index_in_datasets"]
self._iterators_in_use_indices = state["iterators_in_use_indices"]
self._future_states = state.get("future_states", {})
self._future_states = cast(dict[int, Any], state.get("future_states", {}))

def set_shard_states(self, state: Sequence[Any]) -> None:
active_states = []
for ind, shard_state in enumerate(state):
if not shard_state["exhausted"]:
active_states.append((ind, shard_state["state"]))

iterators_in_use_indices = []
iterators_in_use_states = []
exhausted = []
count = 0
future_states = {}
for ind, s in active_states:
if count < self._cycle_length:
iterators_in_use_indices.append(ind)
iterators_in_use_states.append(s)
exhausted.append(0)
count += 1
elif s:
future_states[ind] = s
next_index_in_datasets = (
max(iterators_in_use_indices) + 1 if iterators_in_use_indices else 0
)
while count < self._cycle_length:
iterators_in_use_indices.append(next_index_in_datasets)
iterators_in_use_states.append(None)
exhausted.append(1)
count += 1

new_state = {
"next_index_in_cycle": 0,
"next_index_in_datasets": next_index_in_datasets,
"iterators_in_use_indices": iterators_in_use_indices,
"iterators_in_use_states": iterators_in_use_states,
"exhausted": exhausted,
"future_states": future_states,
}
self.set_state(new_state)

def _get_next_index(self) -> int:
if len(self._datasets) == 1:
Expand Down Expand Up @@ -334,14 +407,15 @@ def __str__(self) -> str:
)

def _get_iterator_start_state(self, index: int) -> dict[str, Any]:
it = _add_prefetch_and_make_iterator(
self._datasets[index],
weakref.ref(self),
start_prefetch=False,
)
state = it.get_state()
del it
return state
if index not in self._iterator_start_states:
it = _add_prefetch_and_make_iterator(
self._datasets[index],
weakref.ref(self),
start_prefetch=False,
)
self._iterator_start_states[index] = it.get_state()
it.close()
return self._iterator_start_states[index]


def _add_prefetch_and_make_iterator(
Expand Down
30 changes: 29 additions & 1 deletion grain/_src/python/dataset/transformations/interleave_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint
import numpy as np


_INTERLEAVE_TEST_CASES = (
dict(
testcase_name="cycle_length_1",
Expand Down Expand Up @@ -339,6 +338,35 @@ def test_future_states(self):
with self.assertRaises(StopIteration):
next(ds_iter)

def test_get_set_shard_state(self):
ds1 = dataset.MapDataset.source([1, 2, 3]).to_iter_dataset()
ds2 = dataset.MapDataset.source([4, 5, 6]).to_iter_dataset()
ds = self._create_dataset([ds1, ds2], cycle_length=2)
ds = self._maybe_wrap_ds(ds)
it = cast(interleave.InterleaveDatasetIterator, ds.__iter__())

# Consume some elements to advance state.
self.assertEqual(next(it), 1)
self.assertEqual(next(it), 4)

# Get the shard state.
shard_state = it.get_shard_states()

# Check correctness of the individual shard statuses in scenarios without
# preprefetch wrapping.
expected_shard_state = [
{"exhausted": 0, "state": {"next_index": 1}},
{"exhausted": 0, "state": {"next_index": 1}},
]
self.assertEqual(shard_state, expected_shard_state)

# Create a new iterator and restore state.
it2 = cast(interleave.InterleaveDatasetIterator, ds.__iter__())
it2.set_shard_states(shard_state)

# Verify it continues from the correct position.
self.assertSequenceEqual(list(it2), [2, 5, 3, 6])


if __name__ == "__main__":
absltest.main()
26 changes: 26 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,32 @@ def set_slice(self, sl: slice, sequential_slice: bool = False) -> None:
...


@typing.runtime_checkable
class SupportsSlicedStateManagement(Protocol):
"""Iterators that support setting a sliced state.

This protocol is mainly used to support elastic resizing of iterators.
"""

def get_shard_states(self) -> Sequence[Any]:
"""Returns the states of all shards managed by this iterator.

This is used for elastic resizing to capture the current progress of each
shard.
"""
...

def set_shard_states(self, state: Sequence[Any]):
"""Sets the states of all shards managed by this iterator.

This is used for elastic resizing to restore the progress of each shard.

Args:
state: A sequence of states, one for each shard.
"""
...


class PrefetchIterDataset(dataset.IterDataset[T]):
"""Iterable dataset that uses a thread pool for prefetching."""

Expand Down
Loading