Commit

Allow data sources to specify that they can be shuffled without a buffer.

PiperOrigin-RevId: 633740168
SeqIO Team authored and SeqIO committed May 15, 2024
1 parent 20e5f45 commit 254fb9d
Showing 2 changed files with 32 additions and 8 deletions.
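
For orientation, a minimal usage sketch of the new flag (the file pattern, task name, and the empty output_features/preprocessors are placeholders, not part of this commit): a FileDataSource constructed with allow_no_shuffle_buffer=True lets a Task whose shuffle_buffer_size is None still be requested with shuffle=True, in which case only filenames are shuffled.

import tensorflow as tf
from seqio import dataset_providers

# Hypothetical file pattern; only the flag wiring is the point here.
source = dataset_providers.FileDataSource(
    read_file_fn=tf.data.TFRecordDataset,
    split_to_filepattern={"train": "/tmp/data/train-*.tfrecord"},
    allow_no_shuffle_buffer=True,  # records assumed pre-mixed on disk
)

task = dataset_providers.Task(
    "no_buffer_task",
    source=source,
    output_features={},  # placeholder; a real task declares its features
    preprocessors=[],
    shuffle_buffer_size=None,  # no in-memory example buffer
)

# Before this change the call below raised ValueError; with the flag set on
# the source it proceeds, and only filename-level shuffling is applied.
ds = task.get_dataset(sequence_length=None, split="train", shuffle=True)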
34 changes: 26 additions & 8 deletions seqio/dataset_providers.py
@@ -288,12 +288,14 @@ def __init__(
splits: Iterable[str],
num_input_examples: Optional[Mapping[str, int]] = None,
caching_permitted: bool = True,
allow_no_shuffle_buffer: bool = False,
):
self._splits = tuple(splits)
self._num_input_examples = (
dict(num_input_examples) if num_input_examples is not None else None
)
self._caching_permitted = caching_permitted
self._allow_no_shuffle_buffer = allow_no_shuffle_buffer

@property
def caching_permitted(self) -> bool:
@@ -319,6 +321,15 @@ def output_features(self) -> Mapping[str, Feature]:
"""Override unused property of `DatasetProviderBase`."""
raise NotImplementedError

@property
def allow_no_shuffle_buffer(self) -> bool:
"""Indicates whether this data source be shuffled without a buffer.
Some datasets may provide internal shuffling mechanisms that could allow
the dataset to be shuffled without calling ds.shuffle().
"""
return self._allow_no_shuffle_buffer

@abc.abstractmethod
def list_shards(self, split: str) -> Sequence[str]:
"""Returns string identifiers of input shards."""
@@ -590,6 +601,7 @@ def __init__(
file_shuffle_buffer_size: Optional[int] = None,
cycle_length: int = 16,
block_length: int = 16,
allow_no_shuffle_buffer: bool = False,
):
"""FileDataSource constructor.
@@ -609,6 +621,9 @@ def __init__(
replicate earlier behavior.
cycle_length: The cycle_length to pass to tf.data.Dataset.interleave.
block_length: The block_length to pass to tf.data.Dataset.interleave.
allow_no_shuffle_buffer: Allow enclosing task to call get_dataset with
shuffle_buffer_size=None. In this case, only filename shuffling will be
performed when shuffle==True.
"""
self._split_to_filepattern = split_to_filepattern
self._reader = read_file_fn
@@ -619,6 +634,7 @@ def __init__(
splits=split_to_filepattern.keys(),
num_input_examples=num_input_examples,
caching_permitted=caching_permitted,
allow_no_shuffle_buffer=allow_no_shuffle_buffer,
)

@property
@@ -1663,14 +1679,16 @@ def get_dataset(
ds = self._trim_output_features(ds, sequence_length=sequence_length)
if shuffle:
if self._shuffle_buffer_size is None:
raise ValueError(
f"Shuffling is disallowed for Task '{self.name}' since its "
"`shuffle_buffer_size` was set to `None` on construction."
)
shuffle_buffer_size = shuffle_buffer_size or self._shuffle_buffer_size
# Shuffle before mixing since preprocessor can output multiple
# (correlated) examples per input.
ds = ds.shuffle(shuffle_buffer_size, seed=seed)
if not self.source.allow_no_shuffle_buffer:
raise ValueError(
f"Shuffling is disallowed for Task '{self.name}' since its "
"`shuffle_buffer_size` was set to `None` on construction."
)
else:
shuffle_buffer_size = shuffle_buffer_size or self._shuffle_buffer_size
# Shuffle before mixing since preprocessor can output multiple
# (correlated) examples per input.
ds = ds.shuffle(shuffle_buffer_size, seed=seed)


return ds.prefetch(tf.data.experimental.AUTOTUNE)
6 changes: 6 additions & 0 deletions seqio/dataset_providers_test.py
@@ -302,6 +302,12 @@ def test_disallow_shuffle(self):

task.get_dataset(None, shuffle=False)

# When the source specifies allow_no_shuffle_buffer, it should be possible
# to call get_dataset with shuffle=True and shuffle_buffer_size=None. In
# this case, only the source's internal shuffling mechanism will be active.
self.function_source._allow_no_shuffle_buffer = True
task.get_dataset(None, shuffle=True)

def test_supports_caching(self):
self.assertFalse(
dataset_providers.Task(
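
For completeness, the two branches of the new check can also be exercised outside the test fixture; a self-contained sketch (the trivial dataset_fn, empty output_features, and task name are placeholders, and the private-attribute assignment mirrors the test above, since FunctionDataSource does not take the new constructor argument in this commit):

import tensorflow as tf
from seqio import dataset_providers

def dataset_fn(split, shuffle_files, seed=None):
  # Placeholder data; a real source would read pre-mixed records.
  return tf.data.Dataset.from_tensor_slices({"inputs": ["a", "b", "c"]})

source = dataset_providers.FunctionDataSource(
    dataset_fn=dataset_fn, splits=["train"])

task = dataset_providers.Task(
    "no_buffer_check",
    source=source,
    output_features={},  # placeholder; a real task declares its features
    preprocessors=[],
    shuffle_buffer_size=None,
)

# Default: a buffer-less task still refuses to shuffle.
try:
  task.get_dataset(None, shuffle=True)
except ValueError:
  pass

# With the source opting in, the same call is allowed and ds.shuffle() is
# skipped entirely.
source._allow_no_shuffle_buffer = True
task.get_dataset(None, shuffle=True)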
