Skip to content

Commit

Permalink
internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 633691869
  • Loading branch information
gauravmishra authored and SeqIO committed May 14, 2024
1 parent 20e5f45 commit 4caad71
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
3 changes: 3 additions & 0 deletions seqio/beam_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import importlib
import json
import operator
import os
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

from absl import logging
Expand Down Expand Up @@ -205,6 +206,8 @@ def expand(self, pcoll):
)




class _ArrayRecordSink(beam.io.filebasedsink.FileBasedSink):
"""Sink Class for use in Arrayrecord PTransform."""

Expand Down
1 change: 1 addition & 0 deletions seqio/beam_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def test_write_example_tf_record(self):
parsed_example = tf.train.Example.FromString(next(iter(ds)).numpy())
self.assertEqual(parsed_example, seqio.dict_to_tfexample(example))


def test_write_json(self):
output_path = os.path.join(self.test_data_dir, "output.json")
data = {
Expand Down
10 changes: 7 additions & 3 deletions seqio/scripts/cache_tasks_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,11 @@
)


_OUTPUT_FORMATS = ["arrayrecord", "tfrecord"]
flags.DEFINE_enum(
"output_format",
"tfrecord",
["arrayrecord", "tfrecord"],
_OUTPUT_FORMATS,
"Output format of the cached tasks.",
)
flags.DEFINE_boolean(
Expand Down Expand Up @@ -173,6 +174,8 @@ def run_pipeline(
overwrite=False,
ignore_other_caches=False,
completed_file_contents="",
store_metadata_proto: bool = False, # GOOGLE-INTERNAL,
output_format: str = "tfrecord",
):
"""Run preprocess pipeline."""
output_dirs = []
Expand Down Expand Up @@ -304,7 +307,7 @@ def run_pipeline(
| "%s_global_example_shuffle" % label >> beam.Reshuffle()
)

if FLAGS.output_format == "arrayrecord":
if output_format == "arrayrecord":
completion_values.append(
examples
| "%s_write_arrayrecord" % label
Expand All @@ -316,7 +319,7 @@ def run_pipeline(
preserve_random_access=FLAGS.preserve_random_access,
)
)
elif FLAGS.output_format == "tfrecord":
elif output_format == "tfrecord":
completion_values.append(
examples
| "%s_write_tfrecord" % label
Expand Down Expand Up @@ -389,6 +392,7 @@ def main(_):
FLAGS.module_import,
FLAGS.overwrite,
FLAGS.ignore_other_caches,
FLAGS.output_format,
)


Expand Down
2 changes: 2 additions & 0 deletions seqio/scripts/cache_tasks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
from absl import flags
from absl.testing import absltest
from apache_beam.testing.test_pipeline import TestPipeline
import bagz
import seqio
from seqio.scripts import cache_tasks_main
import tensorflow.compat.v2 as tf


tf.compat.v1.enable_eager_execution()

flags.FLAGS.min_shards = 0
Expand Down

0 comments on commit 4caad71

Please sign in to comment.