diff --git a/tf2onnx/optimizer/transpose_optimizer.py b/tf2onnx/optimizer/transpose_optimizer.py index 76f897828..ffcf06311 100644 --- a/tf2onnx/optimizer/transpose_optimizer.py +++ b/tf2onnx/optimizer/transpose_optimizer.py @@ -511,14 +511,15 @@ def _add_handler(self, trans, node): return True return self._handle_node_having_branches(trans, node) - def _output_node_has_single_consumer_node(self, node): + def _output_has_no_multiple_consumers(self, node): output_node = self._g.get_node_by_name(node.output[0]) - return output_node and output_node.output and self._nodes_has_single_consumer_node([output_node]) + return True if output_node is None \ + else (output_node.output and self._nodes_has_single_consumer_node([output_node])) def _transpose_handler(self, trans, node): perm = trans.get_attr_value("perm") perm_inv = invert_perm(perm) - if is_tranpose_of_type(node, perm_inv) and self._output_node_has_single_consumer_node(node): + if is_tranpose_of_type(node, perm_inv) and self._output_has_no_multiple_consumers(node): for g in {self._g, node.graph}: g.replace_all_inputs(node.output[0], trans.input[0]) # ops=g.get_nodes()