diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index 79b53410..57675a15 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -288,12 +288,14 @@ def __init__( splits: Iterable[str], num_input_examples: Optional[Mapping[str, int]] = None, caching_permitted: bool = True, + allow_no_shuffle_buffer: bool = False, ): self._splits = tuple(splits) self._num_input_examples = ( dict(num_input_examples) if num_input_examples is not None else None ) self._caching_permitted = caching_permitted + self._allow_no_shuffle_buffer = allow_no_shuffle_buffer @property def caching_permitted(self) -> bool: @@ -319,6 +321,15 @@ def output_features(self) -> Mapping[str, Feature]: """Override unused property of `DatasetProviderBase`.""" raise NotImplementedError + @property + def allow_no_shuffle_buffer(self) -> bool: + """Indicates whether this data source be shuffled without a buffer. + + Some datasets may provide internal shuffling mechanisms that could allow + the dataset to be shuffled without calling ds.shuffle(). + """ + return self._allow_no_shuffle_buffer + @abc.abstractmethod def list_shards(self, split: str) -> Sequence[str]: """Returns string identifiers of input shards.""" @@ -590,6 +601,7 @@ def __init__( file_shuffle_buffer_size: Optional[int] = None, cycle_length: int = 16, block_length: int = 16, + allow_no_shuffle_buffer: bool = False, ): """FileDataSource constructor. @@ -609,6 +621,9 @@ def __init__( replicate earlier behavior. cycle_length: The cycle_length to pass to tf.data.Dataset.interleave. block_length: The block_length to pass to tf.data.Dataset.interleave. + allow_no_shuffle_buffer: Allow enclosing task to call get_dataset with + shuffle_buffer_size=None. In this case, only filename shuffling will be + performed when shuffle==True. """ self._split_to_filepattern = split_to_filepattern self._reader = read_file_fn @@ -619,6 +634,7 @@ def __init__( splits=split_to_filepattern.keys(), num_input_examples=num_input_examples, caching_permitted=caching_permitted, + allow_no_shuffle_buffer=allow_no_shuffle_buffer, ) @property @@ -1663,14 +1679,16 @@ def get_dataset( ds = self._trim_output_features(ds, sequence_length=sequence_length) if shuffle: if self._shuffle_buffer_size is None: - raise ValueError( - f"Shuffling is disallowed for Task '{self.name}' since its " - "`shuffle_buffer_size` was set to `None` on construction." - ) - shuffle_buffer_size = shuffle_buffer_size or self._shuffle_buffer_size - # Shuffle before mixing since preprocessor can output multiple - # (correlated) examples per input. - ds = ds.shuffle(shuffle_buffer_size, seed=seed) + if not self.source.allow_no_shuffle_buffer: + raise ValueError( + f"Shuffling is disallowed for Task '{self.name}' since its " + "`shuffle_buffer_size` was set to `None` on construction." + ) + else: + shuffle_buffer_size = shuffle_buffer_size or self._shuffle_buffer_size + # Shuffle before mixing since preprocessor can output multiple + # (correlated) examples per input. + ds = ds.shuffle(shuffle_buffer_size, seed=seed) return ds.prefetch(tf.data.experimental.AUTOTUNE) diff --git a/seqio/dataset_providers_test.py b/seqio/dataset_providers_test.py index 5fbd247d..599c7243 100644 --- a/seqio/dataset_providers_test.py +++ b/seqio/dataset_providers_test.py @@ -302,6 +302,12 @@ def test_disallow_shuffle(self): task.get_dataset(None, shuffle=False) + # When the source specifies allow_no_shuffle_buffer, it should be possible + # to call get_dataset with shuffle=True and shuffle_buffer_size=None. In + # this case, only the source's internal shuffling mechanism will be active. + self.function_source._allow_no_shuffle_buffer = True + task.get_dataset(None, shuffle=True) + def test_supports_caching(self): self.assertFalse( dataset_providers.Task(