onnx2tf 1.29.9__py3-none-any.whl → 1.29.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2tf/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from onnx2tf.onnx2tf import convert, main
2
2
 
3
- __version__ = '1.29.9'
3
+ __version__ = '1.29.10'
onnx2tf/onnx2tf.py CHANGED
@@ -1467,82 +1467,89 @@ def convert(
1467
1467
  SIGNATURE_KEY = 'serving_default'
1468
1468
 
1469
1469
  # saved_model
1470
+ saved_model_log_level = get_log_level()
1470
1471
  try:
1471
- # concrete_func
1472
- info(Color.REVERSE(f'saved_model output started'), '=' * 58)
1473
- if not output_signaturedefs and not output_integer_quantized_tflite:
1474
- tf.saved_model.save(model, output_folder_path)
1475
- else:
1476
- export_archive = tf_keras.export.ExportArchive()
1477
- export_archive.add_endpoint(
1478
- name=SIGNATURE_KEY,
1479
- fn=lambda *inputs : model(inputs),
1480
- input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
1481
- )
1482
- export_archive.write_out(output_folder_path)
1483
- info(Color.GREEN(f'saved_model output complete!'))
1484
- except TypeError as e:
1485
- # Switch to .pb
1486
- info(Color.GREEN(f'Switch to the output of an optimized protocol buffer file (.pb).'))
1487
- except (KeyError, AssertionError) as e:
1488
- msg_list = [s for s in e.args if isinstance(s, str)]
1489
- if len(msg_list) > 0:
1490
- try:
1491
- for s in msg_list:
1492
- if 'Failed to add concrete function' in s \
1493
- or "Tried to export a function which references an 'untracked' resource" in s:
1494
- export_archive = tf_keras.export.ExportArchive()
1495
- export_archive.add_endpoint(
1496
- name=SIGNATURE_KEY,
1497
- fn=lambda *inputs : model(inputs),
1498
- input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
1499
- )
1500
- export_archive.write_out(output_folder_path)
1501
- break
1502
- except ValueError as e:
1503
- msg_list = [s for s in e.args if isinstance(s, str)]
1504
- if len(msg_list) > 0:
1472
+ if saved_model_log_level <= LOG_LEVELS['debug']:
1473
+ set_log_level('info')
1474
+ try:
1475
+ # concrete_func
1476
+ info(Color.REVERSE(f'saved_model output started'), '=' * 58)
1477
+ if not output_signaturedefs and not output_integer_quantized_tflite:
1478
+ tf.saved_model.save(model, output_folder_path)
1479
+ else:
1480
+ export_archive = tf_keras.export.ExportArchive()
1481
+ export_archive.add_endpoint(
1482
+ name=SIGNATURE_KEY,
1483
+ fn=lambda *inputs : model(inputs),
1484
+ input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
1485
+ )
1486
+ export_archive.write_out(output_folder_path)
1487
+ info(Color.GREEN(f'saved_model output complete!'))
1488
+ except TypeError as e:
1489
+ # Switch to .pb
1490
+ info(Color.GREEN(f'Switch to the output of an optimized protocol buffer file (.pb).'))
1491
+ except (KeyError, AssertionError) as e:
1492
+ msg_list = [s for s in e.args if isinstance(s, str)]
1493
+ if len(msg_list) > 0:
1494
+ try:
1505
1495
  for s in msg_list:
1506
- if 'A root scope name has to match the following pattern' in s:
1507
- error(
1508
- f'Generation of saved_model failed because the OP name does not match the following pattern. ^[A-Za-z0-9.][A-Za-z0-9_.\\\\/>-]*$'
1496
+ if 'Failed to add concrete function' in s \
1497
+ or "Tried to export a function which references an 'untracked' resource" in s:
1498
+ export_archive = tf_keras.export.ExportArchive()
1499
+ export_archive.add_endpoint(
1500
+ name=SIGNATURE_KEY,
1501
+ fn=lambda *inputs : model(inputs),
1502
+ input_signature=[tf.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in model.inputs],
1509
1503
  )
1510
- matches = re.findall(r"'([^']*)'", s)
1511
- error(f'{matches[0]}')
1512
- error(
1513
- f'Please convert again with the `-osd` or `--output_signaturedefs` option.'
1514
- )
1515
- sys.exit(1)
1516
- else:
1517
- error(e)
1518
- import traceback
1519
- error(traceback.format_exc(), prefix=False)
1520
- else:
1521
- error(e)
1522
- import traceback
1523
- error(traceback.format_exc(), prefix=False)
1524
- except ValueError as e:
1525
- msg_list = [s for s in e.args if isinstance(s, str)]
1526
- if len(msg_list) > 0:
1527
- for s in msg_list:
1528
- if 'A root scope name has to match the following pattern' in s:
1529
- error(
1530
- f'Generation of saved_model failed because the OP name does not match the following pattern. ^[A-Za-z0-9.][A-Za-z0-9_.\\\\/>-]*$'
1531
- )
1532
- matches = re.findall(r"'([^']*)'", s)
1533
- error(f'{matches[0]}')
1534
- error(
1535
- f'Please convert again with the `-osd` or `--output_signaturedefs` option.'
1536
- )
1537
- sys.exit(1)
1538
- else:
1504
+ export_archive.write_out(output_folder_path)
1505
+ break
1506
+ except ValueError as e:
1507
+ msg_list = [s for s in e.args if isinstance(s, str)]
1508
+ if len(msg_list) > 0:
1509
+ for s in msg_list:
1510
+ if 'A root scope name has to match the following pattern' in s:
1511
+ error(
1512
+ f'Generation of saved_model failed because the OP name does not match the following pattern. ^[A-Za-z0-9.][A-Za-z0-9_.\\\\/>-]*$'
1513
+ )
1514
+ matches = re.findall(r"'([^']*)'", s)
1515
+ error(f'{matches[0]}')
1516
+ error(
1517
+ f'Please convert again with the `-osd` or `--output_signaturedefs` option.'
1518
+ )
1519
+ sys.exit(1)
1520
+ else:
1521
+ error(e)
1522
+ import traceback
1523
+ error(traceback.format_exc(), prefix=False)
1524
+ else:
1525
+ error(e)
1526
+ import traceback
1527
+ error(traceback.format_exc(), prefix=False)
1528
+ except ValueError as e:
1529
+ msg_list = [s for s in e.args if isinstance(s, str)]
1530
+ if len(msg_list) > 0:
1531
+ for s in msg_list:
1532
+ if 'A root scope name has to match the following pattern' in s:
1533
+ error(
1534
+ f'Generation of saved_model failed because the OP name does not match the following pattern. ^[A-Za-z0-9.][A-Za-z0-9_.\\\\/>-]*$'
1535
+ )
1536
+ matches = re.findall(r"'([^']*)'", s)
1537
+ error(f'{matches[0]}')
1538
+ error(
1539
+ f'Please convert again with the `-osd` or `--output_signaturedefs` option.'
1540
+ )
1541
+ sys.exit(1)
1542
+ else:
1543
+ error(e)
1544
+ import traceback
1545
+ error(traceback.format_exc(), prefix=False)
1546
+ except Exception as e:
1539
1547
  error(e)
1540
1548
  import traceback
1541
1549
  error(traceback.format_exc(), prefix=False)
1542
- except Exception as e:
1543
- error(e)
1544
- import traceback
1545
- error(traceback.format_exc(), prefix=False)
1550
+ finally:
1551
+ if get_log_level() != saved_model_log_level:
1552
+ set_log_level(saved_model_log_level)
1546
1553
 
1547
1554
  # TFv1 .pb
1548
1555
  if output_tfv1_pb:
@@ -1581,9 +1588,12 @@ def convert(
1581
1588
  Name: flatbuffers
1582
1589
  Version: 22.10.26
1583
1590
  """
1584
- converter = tf.lite.TFLiteConverter.from_concrete_functions(
1585
- [concrete_func]
1586
- )
1591
+ try:
1592
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
1593
+ except Exception as e:
1594
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
1595
+ [concrete_func]
1596
+ )
1587
1597
  converter.target_spec.supported_ops = [
1588
1598
  tf.lite.OpsSet.TFLITE_BUILTINS,
1589
1599
  tf.lite.OpsSet.SELECT_TF_OPS,
onnx2tf/ops/If.py CHANGED
@@ -54,6 +54,8 @@ def make_node(
54
54
  graph_node_outputs = [] + graph_node.outputs
55
55
 
56
56
  # Then branch
57
+ subgraph_kwargs = dict(kwargs)
58
+ subgraph_kwargs['suppress_log'] = True
57
59
  then_branch_graph: gs.Graph = graph_node.attrs['then_branch']
58
60
  then_branch_graph_outputs = then_branch_graph.outputs
59
61
  for then_branch_graph_node in then_branch_graph.nodes:
@@ -73,7 +75,7 @@ def make_node(
73
75
  op.make_node(
74
76
  graph_node=then_branch_graph_node,
75
77
  tf_layers_dict=tf_layers_dict,
76
- **kwargs,
78
+ **subgraph_kwargs,
77
79
  )
78
80
  # Then branch - Resister constant
79
81
  for output in then_branch_graph_outputs:
@@ -115,7 +117,7 @@ def make_node(
115
117
  op.make_node(
116
118
  graph_node=else_branch_graph_node,
117
119
  tf_layers_dict=tf_layers_dict,
118
- **kwargs,
120
+ **subgraph_kwargs,
119
121
  )
120
122
  # Else branch - Resister constant
121
123
  for output in else_branch_graph_outputs:
onnx2tf/ops/Loop.py ADDED
@@ -0,0 +1,392 @@
1
+ import re
2
+ import sys
3
+ import random
4
+ random.seed(0)
5
+ import numpy as np
6
+ np.random.seed(0)
7
+ import tensorflow as tf
8
+ import tf_keras
9
+ import onnx_graphsurgeon as gs
10
+ from onnx2tf.utils.common_functions import (
11
+ get_constant_or_variable,
12
+ print_node_info,
13
+ inverted_operation_enable_disable,
14
+ make_tf_node_info,
15
+ )
16
+ from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
17
+ import importlib
18
+ from onnx2tf.utils.logging import *
19
+
20
+
21
class While_Loop_CustomLayer(tf_keras.layers.Layer):
    """Thin Keras layer wrapper around ``tf.while_loop``.

    Every argument of ``call`` is forwarded unchanged to ``tf.while_loop``.
    """

    def __init__(self):
        super().__init__()

    def call(
        self,
        cond,
        body,
        loop_vars,
        shape_invariants=None,
        maximum_iterations=None,
    ):
        """Run ``tf.while_loop(cond, body, loop_vars, ...)`` and return its result.

        Parameters
        ----------
        cond: callable
            Loop-continuation predicate taking the unpacked loop variables.
        body: callable
            Loop body taking the unpacked loop variables and returning the
            next values.
        loop_vars: list
            Initial loop variables.
        shape_invariants: list, optional
            Per-loop-variable shape invariants. ``None`` (the
            ``tf.while_loop`` default) means no relaxed invariants.
        maximum_iterations: scalar int tensor, optional
            Iteration cap. ``None`` (the ``tf.while_loop`` default) means
            unlimited. It is optional so that callers without a trip count
            can omit it instead of triggering a missing-argument TypeError.
        """
        return tf.while_loop(
            cond=cond,
            body=body,
            loop_vars=loop_vars,
            shape_invariants=shape_invariants,
            maximum_iterations=maximum_iterations,
        )
33
+
34
+
35
+ def _to_tf_dtype(dtype):
36
+ return NUMPY_DTYPES_TO_TF_DTYPES[dtype] if isinstance(dtype, np.dtype) else dtype
37
+
38
+
39
+ def _as_tensor(value):
40
+ if isinstance(value, np.ndarray):
41
+ return tf.convert_to_tensor(value)
42
+ if isinstance(value, (np.generic, int, float, bool)):
43
+ return tf.convert_to_tensor(value)
44
+ return value
45
+
46
+
47
def _shape_invariant(value):
    """Build a relaxed ``tf.while_loop`` shape invariant for ``value``.

    The rank of ``value`` is preserved, but every dimension becomes ``None``.
    This lets the loop body change dimension sizes between iterations. A fully
    unknown ``tf.TensorShape(None)`` is returned when ``value`` has no
    ``shape`` attribute, its shape is ``None``, or it is a ``tf.TensorShape``
    of unknown rank.
    """
    try:
        shape = value.shape
    except Exception:
        shape = None

    if shape is None:
        return tf.TensorShape(None)

    if isinstance(shape, tf.TensorShape):
        rank = shape.rank
        if rank is None:
            return tf.TensorShape(None)
    else:
        rank = len(shape)

    return tf.TensorShape([None] * rank)
59
+
60
+
61
@print_node_info
@inverted_operation_enable_disable
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """Loop

    Convert an ONNX ``Loop`` node into a ``tf.while_loop`` executed through
    ``While_Loop_CustomLayer``.

    Inputs are ``M`` (maximum trip count, optional), ``cond`` (initial
    continuation condition, optional) and the loop-carried initial values.
    The ``body`` subgraph is converted op by op inside the while-loop body.
    Body outputs after the loop-carried values are treated as scan outputs:
    they are appended to ``tf.TensorArray`` objects each iteration and stacked
    after the loop.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph

    kwargs: dict
        Conversion options. They are forwarded to every body op's
        ``make_node`` with ``suppress_log`` forced to True.
        ``output_signaturedefs`` also affects how body names are sanitized.

    Raises
    ------
    SystemExit
        If a body op has no converter module, if both ``M`` and ``cond`` are
        omitted, or if the number of produced outputs does not match the node.
    """
    before_op_output_shape_trans_1 = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_2 = \
        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_1 \
        and before_op_output_shape_trans_2

    graph_node_input_1 = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_input_2 = get_constant_or_variable(
        graph_node.inputs[1],
        before_op_output_shape_trans,
    )
    graph_node_input_n_list = [
        get_constant_or_variable(
            graph_node_input,
            before_op_output_shape_trans,
        )
        for graph_node_input in graph_node.inputs[2:]
    ]

    def _sanitize(name):
        # Normalize body input/node names: ':' -> '__', and with
        # output_signaturedefs a leading '/' becomes 'wa/'. The transform is
        # idempotent, so re-applying it on every body trace is harmless.
        name = name.replace(':', '__')
        if kwargs.get('output_signaturedefs', False):
            name = re.sub('^/', 'wa/', name)
        return name

    def _resolve_optional_input(graph_node_input):
        # Return the value bound to an optional Loop input (M or cond), or
        # None when it is omitted. ONNX encodes an omitted optional input as
        # an empty tensor name; looking '' up in tf_layers_dict would raise
        # KeyError, so it is checked before the lookup.
        if isinstance(graph_node_input, gs.Variable):
            if graph_node_input.name == '':
                return None
            value = tf_layers_dict[graph_node_input.name]['tf_node']
        else:
            value = graph_node_input
        if isinstance(value, str) and value == "":
            return None
        return value

    # M: maximum trip-count (optional)
    M = _resolve_optional_input(graph_node_input_1)
    if M is not None:
        M = tf.cast(_as_tensor(M), tf.int64)
        # The trip count is consumed as int32, so values above int32 max are
        # clamped before the narrowing cast.
        max_i32 = tf.constant(tf.int32.max, dtype=tf.int64)
        M = tf.where(tf.greater(M, max_i32), max_i32, M)
        M = tf.cast(M, tf.int32)
        if M.shape is not None and M.shape.rank not in (None, 0):
            M = tf.reshape(M, [])

    # cond: loop continuation condition (optional)
    cond = _resolve_optional_input(graph_node_input_2)
    cond_init = None if cond is None else tf.cast(_as_tensor(cond), tf.bool)
    if cond_init is not None \
        and isinstance(graph_node_input_2, gs.Variable) \
        and graph_node_input_2.shape is not None \
        and len(graph_node_input_2.shape) == 0:
        cond_init = tf.reshape(cond_init, [])
    cond_provided = cond_init is not None
    if not cond_provided:
        # A loop variable is still required even when cond is omitted.
        cond_init = tf.constant(True, dtype=tf.bool)

    # Loop-carried initial values plus the metadata (layout flags) of their
    # producers, which is copied onto the matching body inputs.
    v_init = []
    v_input_meta = []
    for graph_node_input_n in graph_node_input_n_list:
        v_val = tf_layers_dict[graph_node_input_n.name]['tf_node'] \
            if isinstance(graph_node_input_n, gs.Variable) else graph_node_input_n
        v_val = _as_tensor(v_val)
        if isinstance(graph_node_input_n, gs.Variable) \
            and graph_node_input_n.shape is not None \
            and len(graph_node_input_n.shape) == 0:
            v_val = tf.reshape(v_val, [])
        v_init.append(v_val)
        if isinstance(graph_node_input_n, gs.Variable):
            v_input_meta.append(tf_layers_dict.get(graph_node_input_n.name, {}))
        else:
            v_input_meta.append({})

    v_shapes = [_shape_invariant(v) for v in v_init]

    body: gs.Graph = graph_node.attrs["body"]

    iter_cnt_init = tf.constant(0, dtype=tf.int32)

    # Body outputs: [cond, v_1..v_N, scan_1..scan_K]
    scan_outputs_start_index = 1 + len(v_init)
    scan_outputs_init = []
    for i in range(scan_outputs_start_index, len(body.outputs)):
        elem_shape = body.outputs[i].shape
        if elem_shape is not None:
            # Symbolic (non-int) dims become unknown dims.
            elem_shape = [
                dim if isinstance(dim, int) else None for dim in elem_shape
            ]
            elem_shape = tf.TensorShape(elem_shape)
        scan_outputs_init.append(
            tf.TensorArray(
                dtype=_to_tf_dtype(body.outputs[i].dtype),
                size=0,
                dynamic_size=True,
                element_shape=elem_shape,
            )
        )
    scan_outputs_shapes = [tf.TensorShape(None) for _ in scan_outputs_init]

    graph_node_outputs = list(graph_node.outputs)
    for graph_node_output in graph_node_outputs:
        tf_layers_dict[graph_node_output.name] = {
            'optype': graph_node.op,
            'shape': graph_node_output.shape,
            'dtype': graph_node_output.dtype,
        }

    def _register_graph_output_constants(graph: gs.Graph):
        # Register constant graph outputs (recursively through nested
        # subgraph attributes) so the body can look them up by name.
        for output in graph.outputs:
            if output.name not in tf_layers_dict and isinstance(output, gs.Constant):
                tf_layers_dict[output.name] = {
                    'optype': 'Constant',
                    'shape': output.values.shape,
                    'dtype': output.values.dtype,
                }
                tf_layers_dict[output.name]['tf_node'] = \
                    tf.constant(
                        output.values,
                        dtype=_to_tf_dtype(output.values.dtype),
                    )
        for node in graph.nodes:
            for attr_val in node.attrs.values():
                if isinstance(attr_val, gs.Graph):
                    _register_graph_output_constants(attr_val)
                elif isinstance(attr_val, (list, tuple)):
                    for sub_val in attr_val:
                        if isinstance(sub_val, gs.Graph):
                            _register_graph_output_constants(sub_val)

    # Register subgraph constants outside the while_loop to avoid scope issues.
    _register_graph_output_constants(body)

    # Options forwarded to body ops; loop-invariant, so built once.
    subgraph_kwargs = dict(kwargs)
    subgraph_kwargs['suppress_log'] = True

    def run_subgraph(iter_cnt, cond, v, scan_outputs):
        # Bind loop vars to body graph inputs: [iter_num, cond, v_1..v_N]
        loop_inputs = [iter_cnt, cond] + list(v)
        for idx, (body_input, loop_val) in enumerate(zip(body.inputs, loop_inputs)):
            body_input.name = _sanitize(body_input.name)
            target_dtype = _to_tf_dtype(body_input.dtype) if body_input.dtype is not None else None
            loop_val_cast = loop_val
            if target_dtype is not None \
                and isinstance(loop_val, tf.Tensor) \
                and loop_val.dtype != target_dtype:
                loop_val_cast = tf.cast(loop_val, target_dtype)
            if body_input.shape is not None \
                and len(body_input.shape) == 0:
                loop_val_cast = tf.reshape(loop_val_cast, [])
            tf_layers_dict[body_input.name] = {
                'optype': 'Input',
                'shape': body_input.shape,
                'dtype': body_input.dtype,
                'tf_node': loop_val_cast,
                'before_op_output_shape_trans': True,
            }
            if idx >= 2:
                # Propagate layout flags of the outer loop-carried input.
                meta = v_input_meta[idx - 2]
                for key in ('before_op_output_shape_trans', 'nhwc'):
                    if key in meta:
                        tf_layers_dict[body_input.name][key] = meta[key]

        for body_node in body.nodes:
            optype = body_node.op
            try:
                op = importlib.import_module(f'onnx2tf.ops.{optype}')
            except ModuleNotFoundError:
                error(
                    f'{optype} OP is not yet implemented.'
                )
                sys.exit(1)
            body_node.name = _sanitize(body_node.name)
            op.make_node(
                graph_node=body_node,
                tf_layers_dict=tf_layers_dict,
                **subgraph_kwargs,
            )
        outputs = [tf_layers_dict[output.name]['tf_node'] for output in body.outputs]
        for i in range(scan_outputs_start_index, len(outputs)):
            s_index = i - scan_outputs_start_index
            scan_outputs[s_index] = scan_outputs[s_index].write(
                scan_outputs[s_index].size(), outputs[i]
            )
        cond_out = outputs[0]
        if isinstance(cond_out, tf.Tensor) and cond_out.dtype != tf.bool:
            cond_out = tf.cast(cond_out, tf.bool)
        iter_cnt = iter_cnt + 1
        return [iter_cnt, cond_out, outputs[1:scan_outputs_start_index], scan_outputs]

    if M is None and not cond_provided:
        error(
            f'Both M and cond in Loop are not set at the same time ' +
            f'Tensorflow.(PS. if you want to create a do-while loop ' +
            f'then please set cond to True or 1)\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    cond_true = tf.constant(True, dtype=tf.bool)
    if M is not None and not cond_provided:
        # Trip count only: termination is governed by maximum_iterations and
        # the body's cond output does not stop the loop.
        condition = lambda iter_cnt, cond, v, scan_outputs: cond_true
    else:
        condition = lambda iter_cnt, cond, v, scan_outputs: tf.reduce_all(cond)

    # A single invocation covers all three modes. maximum_iterations is always
    # passed explicitly; M is None (no limit) when the trip count is omitted.
    while_loop_layer = While_Loop_CustomLayer()
    iter_cnt_final, _, v_final, scan_outputs_final = while_loop_layer(
        cond=condition,
        body=run_subgraph,
        loop_vars=[
            iter_cnt_init,
            cond_init,
            v_init,
            scan_outputs_init,
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape(None),
            v_shapes,
            scan_outputs_shapes,
        ],
        maximum_iterations=M,
    )

    if scan_outputs_start_index == len(body.outputs):
        final_outputs = list(v_final)
    else:
        def true_fn():
            return scan_outputs_final

        def false_fn():
            # Zero iterations: fresh empty arrays with the initial dtype/shape.
            empty_scan_outputs = []
            for ta in scan_outputs_init:
                empty_scan_outputs.append(
                    tf.TensorArray(
                        dtype=ta.dtype,
                        size=0,
                        element_shape=ta.element_shape,
                    )
                )
            return empty_scan_outputs

        scan_out_final = tf.cond(tf.greater(iter_cnt_final, 0), true_fn, false_fn)
        scan_outputs_tensors = [o.stack() for o in scan_out_final]
        final_outputs = list(v_final) + scan_outputs_tensors

    if len(final_outputs) != len(graph_node_outputs):
        error(
            f'Loop output count mismatch. expected={len(graph_node_outputs)} actual={len(final_outputs)}\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    for idx, (graph_node_output, output_tensor) in enumerate(zip(graph_node_outputs, final_outputs)):
        tf_layers_dict[graph_node_output.name]['tf_node'] = output_tensor
        # Copy layout flags from the corresponding body output.
        if idx < len(v_init):
            body_output = body.outputs[1 + idx]
        else:
            body_output = body.outputs[scan_outputs_start_index + (idx - len(v_init))]
        body_meta = tf_layers_dict.get(body_output.name, {})
        for key in ('before_op_output_shape_trans', 'nhwc'):
            if key in body_meta:
                tf_layers_dict[graph_node_output.name][key] = body_meta[key]

    tf_outputs = {f"output{idx}": value for idx, value in enumerate(final_outputs)}
    for graph_node_output in graph_node_outputs:
        tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
            make_tf_node_info(
                node_info={
                    'tf_op_type': tf.while_loop,
                    'tf_inputs': {
                        'condition': condition,
                        'M': M,
                        'cond': cond_init,
                        'v_initial': v_init,
                    },
                    'tf_outputs': {
                        'output': tf_outputs,
                    },
                }
            )