Skip to content

Commit

Permalink
minor changes.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 633691869
  • Loading branch information
gauravmishra authored and SeqIO committed May 15, 2024
1 parent f09ed20 commit 32c9000
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 3 deletions.
3 changes: 3 additions & 0 deletions seqio/beam_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import importlib
import json
import operator
import os
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

from absl import logging
Expand Down Expand Up @@ -207,6 +208,8 @@ def expand(self, pcoll):
)




class _ArrayRecordSink(beam.io.filebasedsink.FileBasedSink):
"""Sink Class for use in Arrayrecord PTransform."""

Expand Down
1 change: 1 addition & 0 deletions seqio/beam_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def test_write_example_tf_record(self):
parsed_example = tf.train.Example.FromString(next(iter(ds)).numpy())
self.assertEqual(parsed_example, seqio.dict_to_tfexample(example))


def test_write_json(self):
output_path = os.path.join(self.test_data_dir, "output.json")
data = {
Expand Down
12 changes: 9 additions & 3 deletions seqio/scripts/cache_tasks_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@
flags.DEFINE_enum(
"output_format",
"tfrecord",
["arrayrecord", "tfrecord"],
[
"arrayrecord",
"tfrecord",
],
"Output format of the cached tasks.",
)
flags.DEFINE_boolean(
Expand Down Expand Up @@ -173,6 +176,8 @@ def run_pipeline(
overwrite=False,
ignore_other_caches=False,
completed_file_contents="",
store_metadata_proto: bool = False, # GOOGLE-INTERNAL,
output_format: str = "tfrecord",
):
"""Run preprocess pipeline."""
output_dirs = []
Expand Down Expand Up @@ -304,7 +309,7 @@ def run_pipeline(
| "%s_global_example_shuffle" % label >> beam.Reshuffle()
)

if FLAGS.output_format == "arrayrecord":
if output_format == "arrayrecord":
completion_values.append(
examples
| "%s_write_arrayrecord" % label
Expand All @@ -316,7 +321,7 @@ def run_pipeline(
preserve_random_access=FLAGS.preserve_random_access,
)
)
elif FLAGS.output_format == "tfrecord":
elif output_format == "tfrecord":
completion_values.append(
examples
| "%s_write_tfrecord" % label
Expand Down Expand Up @@ -389,6 +394,7 @@ def main(_):
FLAGS.module_import,
FLAGS.overwrite,
FLAGS.ignore_other_caches,
FLAGS.output_format,
)


Expand Down
1 change: 1 addition & 0 deletions seqio/scripts/cache_tasks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from seqio.scripts import cache_tasks_main
import tensorflow.compat.v2 as tf


tf.compat.v1.enable_eager_execution()

flags.FLAGS.min_shards = 0
Expand Down

0 comments on commit 32c9000

Please sign in to comment.