Skip to content

Commit

Permalink
Add option to TfdsDataSource to specify only the data dir pointing to…
Browse files Browse the repository at this point in the history
… a single dataset

PiperOrigin-RevId: 647673240
  • Loading branch information
tomvdw authored and SeqIO committed Jun 28, 2024
1 parent 568f9c4 commit 7f79267
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
12 changes: 7 additions & 5 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,12 +499,14 @@ def __init__(
Args:
tfds_name: The name and version number of a TFDS dataset, optionally with
a config. If `tfds_name` is not specified then `splits` values must be
instances of `TfdsSplit`.
a config. If `tfds_name` is not specified then either `tfds_data_dir`
must point to a folder that contains the data (e.g.,
`/data/tfds/dataset/config/1.2.3`), or `splits` values must be instances
of `TfdsSplit`.
tfds_data_dir: An optional path to a specific TFDS data directory to use.
If provided `tfds_name` must be a valid dataset in the directory. If
`tfds_name` is empty `tfds_data_dir` must point to the directory with
one dataset.
If provided, `tfds_name` must either be a valid dataset in the
directory, or if `tfds_name` is empty, `tfds_data_dir` must point to the
directory with one dataset (e.g., `/data/tfds/dataset/config/1.2.3`).
splits: an iterable of allowable string split names, a dict mapping
allowable canonical splits (e.g., 'validation') to TFDS splits or slices
(e.g., 'train[:1%]'), or `TfdsSplit` (e.g. `TfdsSplit(dataset='mnist',
Expand Down
9 changes: 9 additions & 0 deletions seqio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,15 @@ def load(
)
read_config.shuffle_seed = seed
read_config.skip_prefetch = True
if dataset is None and data_dir is not None:
# Load directly from the data dir.
builder = self._get_builder(split=split)
return builder.as_dataset(
split=dataset_split,
shuffle_files=shuffle_files,
read_config=read_config,
decoders=self._decoders,
)
return tfds.load(
dataset,
split=dataset_split,
Expand Down

0 comments on commit 7f79267

Please sign in to comment.