diff --git a/tensorflow_datasets/core/data_sources/array_record.py b/tensorflow_datasets/core/data_sources/array_record.py
index 88cf3cb1c82..c8f3ca8fabe 100644
--- a/tensorflow_datasets/core/data_sources/array_record.py
+++ b/tensorflow_datasets/core/data_sources/array_record.py
@@ -20,8 +20,13 @@
 """
 
 import dataclasses
+from typing import Any, Optional
 
+from tensorflow_datasets.core import dataset_info as dataset_info_lib
+from tensorflow_datasets.core import decode
+from tensorflow_datasets.core import splits as splits_lib
 from tensorflow_datasets.core.data_sources import base
+from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source
 
 
@@ -37,9 +42,18 @@ class ArrayRecordDataSource(base.BaseDataSource):
   source.
   """
 
+  dataset_info: dataset_info_lib.DatasetInfo
+  split: splits_lib.Split = None
+  decoders: Optional[type_utils.TreeDict[decode.partial_decode.DecoderArg]] = (
+      None
+  )
+  # In order to lazy load array_record, we don't load
+  # `array_record_data_source.ArrayRecordDataSource` here.
+  data_source: Any = dataclasses.field(init=False)
+  length: int = dataclasses.field(init=False)
+
   def __post_init__(self):
-    dataset_info = self.dataset_builder.info
-    file_instructions = base.file_instructions(dataset_info, self.split)
+    file_instructions = base.file_instructions(self.dataset_info, self.split)
     self.data_source = array_record_data_source.ArrayRecordDataSource(
         file_instructions
     )
diff --git a/tensorflow_datasets/core/data_sources/base.py b/tensorflow_datasets/core/data_sources/base.py
index bb6c44ad2dc..c70f736b92c 100644
--- a/tensorflow_datasets/core/data_sources/base.py
+++ b/tensorflow_datasets/core/data_sources/base.py
@@ -17,14 +17,12 @@
 
 from collections.abc import MappingView, Sequence
 import dataclasses
-import functools
 import typing
 from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar
 
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import splits as splits_lib
-from tensorflow_datasets.core.features import top_level_feature
 from tensorflow_datasets.core.utils import shard_utils
 from tensorflow_datasets.core.utils import type_utils
 from tensorflow_datasets.core.utils.lazy_imports_utils import tree
@@ -56,14 +54,6 @@ def file_instructions(
   return split_dict[split].file_instructions
 
 
-class _DatasetBuilder(Protocol):
-  """Protocol for the DatasetBuilder to avoid cyclic imports."""
-
-  @property
-  def info(self) -> dataset_info_lib.DatasetInfo:
-    ...
-
-
 @dataclasses.dataclass
 class BaseDataSource(MappingView, Sequence):
   """Base DataSource to override all dunder methods with the deserialization.
@@ -74,28 +64,22 @@ class BaseDataSource(MappingView, Sequence):
   deserialization/decoding.
 
   Attributes:
-    dataset_builder: The dataset builder.
+    dataset_info: The DatasetInfo of the dataset.
     split: The split to load in the data source.
    decoders: Optional decoders for decoding.
     data_source: The underlying data source to initialize in the __post_init__.
   """
 
-  dataset_builder: _DatasetBuilder
+  dataset_info: dataset_info_lib.DatasetInfo
   split: splits_lib.Split | None = None
   decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
   data_source: DataSource[Any] = dataclasses.field(init=False)
 
-  @functools.cached_property
-  def _features(self) -> top_level_feature.TopLevelFeature:
-    """Caches features because we log the use of dataset_builder.info."""
-    features = self.dataset_builder.info.features
-    if not features:
-      raise ValueError('No feature defined in the dataset buidler.')
-    return features
-
   def __getitem__(self, key: SupportsIndex) -> Any:
     record = self.data_source[key.__index__()]
-    return self._features.deserialize_example_np(record, decoders=self.decoders)
+    return self.dataset_info.features.deserialize_example_np(
+        record, decoders=self.decoders
+    )
 
   def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
     """Retrieves items by batch.
@@ -114,6 +98,7 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
     if not keys:
       return []
     records = self.data_source.__getitems__(keys)
+    features = self.dataset_info.features
     if len(keys) != len(records):
       raise IndexError(
           f'Requested {len(keys)} records but got'
@@ -121,7 +106,7 @@
           f'{keys=}, {records=}'
       )
     return [
-        self._features.deserialize_example_np(record, decoders=self.decoders)
+        features.deserialize_example_np(record, decoders=self.decoders)
         for record in records
     ]
 
@@ -129,9 +114,8 @@ def __repr__(self) -> str:
     decoders_repr = (
         tree.map_structure(type, self.decoders) if self.decoders else None
     )
-    name = self.dataset_builder.info.name
     return (
-        f'{self.__class__.__name__}(name={name}, '
+        f'{self.__class__.__name__}(name={self.dataset_info.name}, '
         f'split={self.split!r}, '
         f'decoders={decoders_repr})'
     )
diff --git a/tensorflow_datasets/core/data_sources/base_test.py b/tensorflow_datasets/core/data_sources/base_test.py
index 12b70392044..6891a2cff7b 100644
--- a/tensorflow_datasets/core/data_sources/base_test.py
+++ b/tensorflow_datasets/core/data_sources/base_test.py
@@ -15,15 +15,13 @@
 
 """Tests for all data sources."""
 
-import pickle
 from unittest import mock
 
-import cloudpickle
 from etils import epath
 import pytest
 import tensorflow_datasets as tfds
 from tensorflow_datasets import testing
-from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
+from tensorflow_datasets.core import dataset_builder
 from tensorflow_datasets.core import dataset_info as dataset_info_lib
 from tensorflow_datasets.core import decode
 from tensorflow_datasets.core import file_adapters
@@ -79,7 +77,7 @@ def mocked_parquet_dataset():
 )
 def test_read_write(
     tmp_path: epath.Path,
-    builder_cls: dataset_builder_lib.DatasetBuilder,
+    builder_cls: dataset_builder.DatasetBuilder,
     file_format: file_adapters.FileFormat,
 ):
   builder = builder_cls(data_dir=tmp_path, file_format=file_format)
@@ -108,7 +106,7 @@ def test_read_write(
   ]
 
 
-def create_dataset_builder(file_format: file_adapters.FileFormat):
+def create_dataset_info(file_format: file_adapters.FileFormat):
   with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
     split_mock.return_value.name = 'train'
     split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
@@ -116,11 +114,7 @@ def create_dataset_builder(file_format: file_adapters.FileFormat):
     dataset_info.file_format = file_format
     dataset_info.splits = {'train': split_mock()}
     dataset_info.name = 'dataset_name'
-
-    dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
-    dataset_builder.info = dataset_info
-
-    return dataset_builder
+    return dataset_info
 
 
 @pytest.mark.parametrize(
@@ -128,14 +122,12 @@ def create_dataset_builder(file_format: file_adapters.FileFormat):
     _DATA_SOURCE_CLS,
 )
 def test_missing_split_raises_error(data_source_cls):
-  dataset_builder = create_dataset_builder(
-      file_adapters.FileFormat.ARRAY_RECORD
-  )
+  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
   with pytest.raises(
       ValueError,
       match="Unknown split 'doesnotexist'.",
   ):
-    data_source_cls(dataset_builder, split='doesnotexist')
+    data_source_cls(dataset_info, split='doesnotexist')
 
 
 @pytest.mark.usefixtures(*_FIXTURES)
@@ -144,10 +136,8 @@ def test_missing_split_raises_error(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
-  dataset_builder = create_dataset_builder(
-      file_adapters.FileFormat.ARRAY_RECORD
-  )
-  source = data_source_cls(dataset_builder, split='train')
+  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
+  source = data_source_cls(dataset_info, split='train')
   name = data_source_cls.__name__
   assert (
       repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"
   )
@@ -160,11 +150,9 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
     _DATA_SOURCE_CLS,
 )
 def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
-  dataset_builder = create_dataset_builder(
-      file_adapters.FileFormat.ARRAY_RECORD
-  )
+  dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
   source = data_source_cls(
-      dataset_builder,
+      dataset_info,
       split='train',
       decoders={'my_feature': decode.SkipDecoding()},
   )
@@ -193,18 +181,3 @@ def test_data_source_is_sliceable():
     file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
     assert file_instructions[0].skip == 0
     assert file_instructions[0].take == 30000
-
-
-# PyGrain requires that data sources are picklable.
-@pytest.mark.parametrize(
-    'file_format',
-    file_adapters.FileFormat.with_random_access(),
-)
-@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
-def test_data_source_is_picklable_after_use(file_format, pickle_module):
-  with tfds.testing.tmp_dir() as data_dir:
-    builder = tfds.testing.DummyDataset(data_dir=data_dir)
-    builder.download_and_prepare(file_format=file_format)
-    data_source = builder.as_data_source(split='train')
-    assert data_source[0] == {'id': 0}
-    assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}
diff --git a/tensorflow_datasets/core/data_sources/parquet.py b/tensorflow_datasets/core/data_sources/parquet.py
index 048bf18994e..7fe8b19b85e 100644
--- a/tensorflow_datasets/core/data_sources/parquet.py
+++ b/tensorflow_datasets/core/data_sources/parquet.py
@@ -57,8 +57,7 @@ class ParquetDataSource(base.BaseDataSource):
   """ParquetDataSource to read from a ParquetDataset."""
 
   def __post_init__(self):
-    dataset_info = self.dataset_builder.info
-    file_instructions = base.file_instructions(dataset_info, self.split)
+    file_instructions = base.file_instructions(self.dataset_info, self.split)
     filenames = [
         file_instruction.filename for file_instruction in file_instructions
     ]
diff --git a/tensorflow_datasets/core/dataset_builder.py b/tensorflow_datasets/core/dataset_builder.py
index 39d713871e9..b834cb970c8 100644
--- a/tensorflow_datasets/core/dataset_builder.py
+++ b/tensorflow_datasets/core/dataset_builder.py
@@ -774,13 +774,13 @@ def build_single_data_source(
     file_format = self.info.file_format
     if file_format == file_adapters.FileFormat.ARRAY_RECORD:
       return array_record.ArrayRecordDataSource(
-          self,
+          self.info,
           split=split,
           decoders=decoders,
       )
     elif file_format == file_adapters.FileFormat.PARQUET:
       return parquet.ParquetDataSource(
-          self,
+          self.info,
           split=split,
           decoders=decoders,
       )
diff --git a/tensorflow_datasets/testing/mocking.py b/tensorflow_datasets/testing/mocking.py
index 81c7998c751..9378204504c 100644
--- a/tensorflow_datasets/testing/mocking.py
+++ b/tensorflow_datasets/testing/mocking.py
@@ -83,25 +83,13 @@ class PickableDataSourceMock(mock.MagicMock):
   """Makes MagicMock pickable in order to work with multiprocessing in Grain."""
 
   def __getstate__(self):
-    return {
-        'num_examples': len(self),
-        'generator': self._generator,
-        'serialize_example': self._serialize_example,
-    }
+    return {'num_examples': len(self), 'generator': self._generator}
 
   def __setstate__(self, state):
-    num_examples, generator, serialize_example = (
-        state['num_examples'],
-        state['generator'],
-        state['serialize_example'],
-    )
+    num_examples, generator = state['num_examples'], state['generator']
     self.__len__.return_value = num_examples
-    self.__getitem__ = functools.partial(
-        _getitem, generator=generator, serialize_example=serialize_example
-    )
-    self.__getitems__ = functools.partial(
-        _getitems, generator=generator, serialize_example=serialize_example
-    )
+    self.__getitem__ = functools.partial(_getitem, generator=generator)
+    self.__getitems__ = functools.partial(_getitems, generator=generator)
 
   def __reduce__(self):
     return (PickableDataSourceMock, (), self.__getstate__())
@@ -111,14 +99,13 @@ def _getitem(
     self,
     record_key: int,
     generator: RandomFakeGenerator,
-    serialize_example=None,
+    serialized: bool = False,
 ) -> Any:
   """Function to overwrite __getitem__ in data sources."""
-  del self
   example = generator[record_key]
-  if serialize_example:
+  if serialized:
     # Return serialized raw bytes
-    return serialize_example(example)
+    return self.dataset_info.features.serialize_example(example)
   return example
 
 
@@ -126,18 +113,36 @@ def _getitems(
     self,
     record_keys: Sequence[int],
     generator: RandomFakeGenerator,
-    serialize_example=None,
+    serialized: bool = False,
 ) -> Sequence[Any]:
   """Function to overwrite __getitems__ in data sources."""
   items = [
-      _getitem(self, record_key, generator, serialize_example=serialize_example)
+      _getitem(self, record_key, generator, serialized=serialized)
      for record_key in record_keys
   ]
-  if serialize_example:
+  if serialized:
    return np.array(items)
  return items
 
 
+def _deserialize_example_np(serialized_example, *, decoders=None):
+  """Function to overwrite dataset_info.features.deserialize_example_np.
+
+  Warning: this has to be defined in the outer scope in order for the function
+  to be pickable.
+
+  Args:
+    serialized_example: the example to deserialize.
+    decoders: optional decoders.
+
+  Returns:
+    The serialized example, because deserialization is taken care by
+    RandomFakeGenerator.
+  """
+  del decoders
+  return serialized_example
+
+
 class MockPolicy(enum.Enum):
   """Strategy to use with `tfds.testing.mock_data` to mock the dataset.
 
@@ -380,27 +385,21 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
         # Force ARRAY_RECORD as the default file_format.
         return_value=file_adapters.FileFormat.ARRAY_RECORD,
     ):
-      # Make mock_data_source pickable with a given len:
+      self.info.features.deserialize_example_np = _deserialize_example_np
       mock_data_source.return_value.__len__.return_value = num_examples
-      # Make mock_data_source pickable with a given generator:
       mock_data_source.return_value._generator = (  # pylint:disable=protected-access
           generator
       )
-      # Make mock_data_source pickable with a given serialize_example:
-      mock_data_source.return_value._serialize_example = (  # pylint:disable=protected-access
-          self.info.features.serialize_example
-      )
-      serialize_example = self.info.features.serialize_example
       mock_data_source.return_value.__getitem__ = functools.partial(
-          _getitem, generator=generator, serialize_example=serialize_example
+          _getitem, generator=generator
       )
       mock_data_source.return_value.__getitems__ = functools.partial(
-          _getitems, generator=generator, serialize_example=serialize_example
+          _getitems, generator=generator
       )
 
       def build_single_data_source(split):
         single_data_source = array_record.ArrayRecordDataSource(
-            dataset_builder=self, split=split, decoders=decoders
+            dataset_info=self.info, split=split, decoders=decoders
         )
         return single_data_source
diff --git a/tensorflow_datasets/testing/mocking_test.py b/tensorflow_datasets/testing/mocking_test.py
index d707e810cbf..3280e166512 100644
--- a/tensorflow_datasets/testing/mocking_test.py
+++ b/tensorflow_datasets/testing/mocking_test.py
@@ -392,12 +392,3 @@ def test_as_data_source_fn():
   assert imagenet[0] == 'foo'
   assert imagenet[1] == 'bar'
   assert imagenet[2] == 'baz'
-
-
-# PyGrain requires that data sources are picklable.
-def test_mocked_data_source_is_pickable():
-  with tfds.testing.mock_data(num_examples=2):
-    data_source = tfds.data_source('imagenet2012', split='train')
-    pickled_and_unpickled_data_source = pickle.loads(pickle.dumps(data_source))
-    assert len(pickled_and_unpickled_data_source) == 2
-    assert isinstance(pickled_and_unpickled_data_source[0]['image'], np.ndarray)