From 40318b4ab2df0efd51c82c6d4bee2b8a46f1aa1d Mon Sep 17 00:00:00 2001 From: Tom van der Weide Date: Fri, 20 Sep 2024 06:16:57 -0700 Subject: [PATCH] Add a ShardDatasetBuilder that creates shards directly. In certain cases, users have data available in different shards and they want to keep the same number of shards and in each shard the same order of examples (or they don't care about the ordering). In that case, our current dataset builder classes are much slower than necessary. The `ShardBasedBuilder` allows users to create dataset builders that process source data shard by shard. It can be run with or without Beam. In case of Beam, the resulting Beam pipeline is significantly simpler and therefore faster. PiperOrigin-RevId: 676820294 --- tensorflow_datasets/core/dataset_builder.py | 94 ++++++++++++++++--- .../core/dataset_builder_beam_test.py | 47 ++++++++++ .../core/dataset_builder_test.py | 48 ++++++++++ tensorflow_datasets/core/split_builder.py | 94 ++++++++++++++++++- tensorflow_datasets/core/writer.py | 60 +++++++++++- 5 files changed, 328 insertions(+), 15 deletions(-) diff --git a/tensorflow_datasets/core/dataset_builder.py b/tensorflow_datasets/core/dataset_builder.py index 1632a5541a4..90258909fbe 100644 --- a/tensorflow_datasets/core/dataset_builder.py +++ b/tensorflow_datasets/core/dataset_builder.py @@ -19,14 +19,14 @@ import abc import collections -from collections.abc import Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence import dataclasses import functools import inspect import json import os import sys -from typing import Any, ClassVar, Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union from absl import logging from etils import epy @@ -1445,6 +1445,17 @@ def builder_configs(cls) -> dict[str, BuilderConfig]: ) return config_dict + def _get_filename_template( + self, split_name: str + ) -> naming.ShardedFileTemplate: + """Returns a filename template for the given split.""" + return naming.ShardedFileTemplate( + split=split_name, + dataset_name=self.name, + data_dir=self.data_path, + filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error + ) + class FileReaderBuilder(DatasetBuilder): """Base class for datasets reading files. @@ -1675,17 +1686,6 @@ def _example_writer(self) -> writer_lib.ExampleWriter: """ return writer_lib.ExampleWriter(file_format=self.info.file_format) - def _get_filename_template( - self, split_name: str - ) -> naming.ShardedFileTemplate: - """Returns a filename template for the given split.""" - return naming.ShardedFileTemplate( - split=split_name, - dataset_name=self.name, - data_dir=self.data_path, - filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error - ) - def _generate_splits( self, dl_manager: download.DownloadManager, @@ -1852,6 +1852,74 @@ def read_tfrecord_beam( ) +class ShardBasedBuilder(FileReaderBuilder): + """Base class for datasets with data generated shard by shard.""" + + def _download_and_prepare( + self, + dl_manager: download.DownloadManager, + download_config: download.DownloadConfig | None = None, + ) -> None: + download_config = download_config or download.DownloadConfig() + + split_builder = split_builder_lib.SplitBuilder( + split_dict=self.info.splits, + features=self.info.features, + dataset_size=self.info.dataset_size, + beam_options=download_config.beam_options, + beam_runner=download_config.beam_runner, + example_writer=self._example_writer(), + # The following options are ignored by `ShardBasedBuilder`. + ignore_duplicates=None, + max_examples_per_split=None, + shard_config=None, + ) + + shard_iterators_per_split = self._shard_iterators_per_split(dl_manager) + split_info_futures = [] + for split_name, example_gen_per_shard in shard_iterators_per_split.items(): + logging.info("Generating split %s", split_name) + split_info_future = split_builder.submit_shard_based_generation( + split_name=split_name, + example_gen_per_shard=example_gen_per_shard, + filename_template=self._get_filename_template(split_name=split_name), + ) + split_info_futures.append(split_info_future) + + # Update the info object with the splits. + split_infos: list[splits_lib.SplitInfo] = [ + future.result() for future in split_info_futures + ] + split_dict = splits_lib.SplitDict(split_infos) + self.info.set_splits(split_dict) + + @abc.abstractmethod + @utils.docs.do_not_doc_in_subclasses + @utils.docs.doc_private + def _shard_iterators_per_split( + self, dl_manager: download.DownloadManager + ) -> Mapping[str, Sequence[split_builder_lib.ExampleGeneratorFn]]: + """Returns a mapping from split name to example generators per shard. + + The example generators are functions that take no parameters and return + an iterator of tuples of key + example. The order of the example generators + is the order in which the shards will be written. + + Args: + dl_manager: `tfds.download.DownloadManager` used to download/extract the + data. + """ + raise NotImplementedError() + + def _example_writer(self) -> writer_lib.ExampleWriter: + """Returns an example writer. + + If datasets should be written to a custom storage, e.g., a database, then + implement a custom `ExampleWriter` and inject it here. + """ + return writer_lib.ExampleWriter(file_format=self.info.file_format) + + @utils.docs.deprecated class BeamBasedBuilder(GeneratorBasedBuilder): """Beam based Builder. diff --git a/tensorflow_datasets/core/dataset_builder_beam_test.py b/tensorflow_datasets/core/dataset_builder_beam_test.py index 890a5ffc6a9..ddf30c99d9c 100644 --- a/tensorflow_datasets/core/dataset_builder_beam_test.py +++ b/tensorflow_datasets/core/dataset_builder_beam_test.py @@ -15,6 +15,8 @@ """Tests for tensorflow_datasets.core.dataset_builder.""" +from collections.abc import Iterator, Mapping, Sequence +import functools import pathlib from typing import Callable from unittest import mock @@ -102,6 +104,31 @@ def _generate_examples(self, examples, num_examples): return examples +class ShardBuilderBeam(dataset_builder.ShardBasedBuilder): + VERSION = utils.Version('0.0.1') + + def _info(self): + return dataset_info.DatasetInfo( + builder=self, + features=features.FeaturesDict({'x': np.int64}), + ) + + def _shard_iterators_per_split(self, dl_manager): + del dl_manager + + def gen_examples(start: int, end: int): + for i in range(start, end): + yield i, {'x': i} + + return { + 'train': [ + functools.partial(gen_examples, start=0, end=10), + functools.partial(gen_examples, start=10, end=20), + ], + 'test': [functools.partial(gen_examples, start=100, end=110)], + } + + def _gen_example(x): return ( x, @@ -198,6 +225,26 @@ def _assert_values_equal(nested_lhs, nested_rhs): np.testing.assert_array_equal(lhs, rhs) +@pytest.mark.parametrize( + 'make_dl_config', + [ + make_default_config, + ], +) +def test_beam_shard_builder_dataset( + tmp_path: pathlib.Path, + make_dl_config: Callable[[], download.DownloadConfig], +): + builder = ShardBuilderBeam(data_dir=tmp_path, version='0.0.1') + builder.download_and_prepare( + file_format='array_record', download_config=make_dl_config() + ) + actual_train_data = list(builder.as_data_source(split='train')) + assert actual_train_data == [{'x': i} for i in range(20)] + actual_test_data = list(builder.as_data_source(split='test')) + assert actual_test_data == [{'x': i} for i in range(100, 110)] + + def test_read_tfrecord_beam(): builder = DummyBeamDataset() with mock.patch.object( diff --git a/tensorflow_datasets/core/dataset_builder_test.py b/tensorflow_datasets/core/dataset_builder_test.py index 9059abc282b..f7051e5c416 100644 --- a/tensorflow_datasets/core/dataset_builder_test.py +++ b/tensorflow_datasets/core/dataset_builder_test.py @@ -15,7 +15,9 @@ """Tests for tensorflow_datasets.core.dataset_builder.""" +from collections.abc import Iterator, Mapping, Sequence import dataclasses +import functools import os import tempfile from unittest import mock @@ -37,9 +39,11 @@ from tensorflow_datasets.core import load from tensorflow_datasets.core import naming from tensorflow_datasets.core import read_only_builder +from tensorflow_datasets.core import split_builder from tensorflow_datasets.core import splits as splits_lib from tensorflow_datasets.core import utils from tensorflow_datasets.core.data_sources import array_record +from tensorflow_datasets.core.download import download_manager from tensorflow_datasets.core.utils import file_utils from tensorflow_datasets.core.utils import read_config as read_config_lib from tensorflow_datasets.testing.dummy_config_based_datasets.dummy_ds_1 import dummy_ds_1_dataset_builder @@ -147,6 +151,50 @@ def _split_generators(self, _): return {"all": self._generate_examples(range(5))} +class ShardBuilder(dataset_builder.ShardBasedBuilder): + VERSION = utils.Version("0.0.1") + BUILDER_CONFIGS = [DummyBuilderConfig(name="cfg1")] + + def _info(self): + return dataset_info.DatasetInfo( + builder=self, + features=features.FeaturesDict({"x": np.int64}), + ) + + def _shard_iterators_per_split( + self, dl_manager: download_manager.DownloadManager + ) -> Mapping[str, Sequence[Iterator[split_builder.KeyExample]]]: + del dl_manager + + def gen_examples( + start: int, end: int + ) -> Iterator[split_builder.KeyExample]: + for i in range(start, end): + yield i, {"x": i} + + return { + # train split has 2 shards + "train": [ + functools.partial(gen_examples, start=0, end=10), + functools.partial(gen_examples, start=10, end=20), + ], + "test": [functools.partial(gen_examples, start=100, end=110)], + } + + +class ShardBuilderTest(testing.TestCase): + + def test_download_and_prepare(self): + with testing.tmp_dir(self.get_temp_dir()) as tmp_dir: + builder = ShardBuilder(data_dir=tmp_dir, config="cfg1", version="0.0.1") + builder.download_and_prepare(file_format="array_record") + actual_data = list(builder.as_data_source(split="train")) + self.assertEqual( + actual_data, + [{"x": i} for i in range(20)], + ) + + class GetBuilderDatadirPathTest(testing.TestCase): def test_builder_data_dir_path_is_correct(self): diff --git a/tensorflow_datasets/core/split_builder.py b/tensorflow_datasets/core/split_builder.py index 3a9618b0518..c5184da86b3 100644 --- a/tensorflow_datasets/core/split_builder.py +++ b/tensorflow_datasets/core/split_builder.py @@ -15,15 +15,17 @@ """Dataset generator code.""" -from collections.abc import Iterable, Iterator +from collections.abc import Iterable, Iterator, Sequence import contextlib import dataclasses import functools import itertools +import json import sys from typing import Any, Callable, Optional, Union from absl import logging +from etils import epath from tensorflow_datasets.core import example_serializer from tensorflow_datasets.core import features as features_lib from tensorflow_datasets.core import naming @@ -49,6 +51,7 @@ 'beam.PTransform', 'beam.PCollection[KeyExample]', ] +ExampleGeneratorFn = Callable[[], Iterator[KeyExample]] @utils.docs.deprecated @@ -147,6 +150,95 @@ def __init__( self._ignore_duplicates = ignore_duplicates self._example_writer = example_writer + def submit_shard_based_generation( + self, + split_name: str, + filename_template: naming.ShardedFileTemplate, + example_gen_per_shard: Sequence[ExampleGeneratorFn], + ) -> _SplitInfoFuture: + """Creates the shards for the split with the given example generators. + + If a Beam runner was added when initializing the `SplitBuilder`, then + the `example_gen_per_shard` will be run in parallel using Beam. Otherwise, + they will be run sequentially in the current process. + + Args: + split_name: Name of the split to generate + filename_template: Template to format the filename for a shard. + example_gen_per_shard: List of example generators, one per shard. Must be + in the same order as the shards. + + Returns: + a future with the split info. + """ + num_shards = len(example_gen_per_shard) + filename_template = filename_template.replace(split=split_name) + serialized_info = self._features.get_serialized_info() + serializer = example_serializer.ExampleSerializer(serialized_info) + + shard_writer = writer_lib.ShardWriter( + serializer=serializer, + example_writer=self._example_writer, + ) + + shard_paths = [] + shard_lengths = [] + if self._beam_runner is None: + for shard_index, example_gen in enumerate(example_gen_per_shard): + shard_path = filename_template.sharded_filepath( + shard_index=shard_index, num_shards=num_shards + ) + shard_paths.append(shard_path) + num_examples = shard_writer.write( + path=shard_path, examples=example_gen() + ) + shard_lengths.append(num_examples) + else: + shard_infos_path = filename_template.data_dir / 'shard_infos.json' + with self.maybe_beam_pipeline(): + shard_infos = [] + for shard_index, example_gen in enumerate(example_gen_per_shard): + shard_path = filename_template.sharded_filepath( + shard_index=shard_index, num_shards=num_shards + ) + shard_paths.append(shard_path) + shard_info = shard_writer.write_with_beam( + path=shard_path, + example_gen=example_gen, + shard_index=shard_index, + pipeline=self.beam_pipeline, + ) + shard_infos.append(shard_info) + + def write_shard_infos( + shard_infos: list[tuple[int, int]], path: epath.Path + ) -> None: + shard_infos_dict = {index: length for index, length in shard_infos} + path.write_text(json.dumps(shard_infos_dict)) + + _ = ( + shard_infos + | f'FlattenShardInfos_{split_name}' >> beam.Flatten() + | f'CombineShardInfos_{split_name}' + >> beam.CombineGlobally(beam.combiners.ToListCombineFn()) + | f'WriteShardInfos_{split_name}' + >> beam.Map(write_shard_infos, path=shard_infos_path) + ) + + shard_infos_dict = json.loads(shard_infos_path.read_text()) + shard_lengths = [ + num_examples for _, num_examples in sorted(shard_infos_dict.items()) + ] + + total_size = sum([shard_path.stat().length for shard_path in shard_paths]) + split_info = splits_lib.SplitInfo( + name=split_name, + shard_lengths=shard_lengths, + num_bytes=total_size, + filename_template=filename_template, + ) + return _SplitInfoFuture(lambda: split_info) + @contextlib.contextmanager def maybe_beam_pipeline(self) -> Iterator[PipelineProxy]: """Context manager wrapping the beam pipeline. diff --git a/tensorflow_datasets/core/writer.py b/tensorflow_datasets/core/writer.py index 23074127e06..e4128b68e25 100644 --- a/tensorflow_datasets/core/writer.py +++ b/tensorflow_datasets/core/writer.py @@ -24,7 +24,7 @@ import json import os import threading -from typing import Any +from typing import Any, Callable from etils import epy from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam @@ -48,6 +48,8 @@ # TODO(tfds): Should be `TreeDict[FeatureValue]` Example = Any +Key = int | bytes +KeyExample = tuple[Key, Example] _INDEX_PATH_SUFFIX = "_index.json" @@ -256,6 +258,62 @@ def write( write_fn() +class ShardWriter: + """Writes examples to a single shard.""" + + def __init__( + self, + serializer: example_serializer.Serializer, + example_writer: ExampleWriter, + ): + """Initializes Writer. + + Args: + serializer: class that can serialize examples. + example_writer: class that writes examples to disk or elsewhere. + """ + self._serializer = serializer + self._example_writer = example_writer + + def write( + self, + examples: Iterable[KeyExample], + path: epath.Path, + ) -> int: + """Returns the number of examples written to the given path.""" + serialized_examples = [ + (k, self._serializer.serialize_example(v)) for k, v in examples + ] + self._example_writer.write(path=path, examples=serialized_examples) + + return len(serialized_examples) + + def write_with_beam( + self, + example_gen: Callable[[], Iterable[KeyExample]], + path: epath.Path, + shard_index: int, + pipeline: beam.Pipeline, + ) -> None: + """Writes a PCollection of examples to a file.""" + + def write_examples(dummy_value: Any) -> tuple[int, int]: + # The dummy value is needed to make the pipeline work with + # `beam.Create([None])`. + del dummy_value + num_examples = self.write( + examples=example_gen(), + path=path, + ) + return shard_index, num_examples + + return ( + pipeline + | f"CreateShard{path.name}" >> beam.Create([None]) + | f"WriteShard{path.name}" >> beam.Map(write_examples) + ) + + class Writer: """Shuffles and writes Examples to sharded files.