From 958249d55e4e7bed62285cc63379ca282d7e68fc Mon Sep 17 00:00:00 2001 From: SeqIO Team Date: Fri, 13 Sep 2024 10:42:22 -0700 Subject: [PATCH] Support ragged tensor in seqio.evaluation in calculating the max seq length. PiperOrigin-RevId: 674353863 --- seqio/evaluation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/seqio/evaluation.py b/seqio/evaluation.py index 5012e8ca..89708849 100644 --- a/seqio/evaluation.py +++ b/seqio/evaluation.py @@ -131,7 +131,12 @@ def _cache_and_measure_examples( for ex in tfds.as_numpy(ds): for k in max_sequence_length: sequence_dim = sequence_dims.get(k, 0) - sequence_length = ex[k].shape[sequence_dim] + if isinstance(ex[k], tf.RaggedTensor): + sequence_length = tf.reduce_max( + ex[k].row_lengths(axis=sequence_dim) + ).numpy() + else: + sequence_length = ex[k].shape[sequence_dim] max_sequence_length[k] = max(max_sequence_length[k], sequence_length) cnt += 1