onnx2tf 1.29.16__py3-none-any.whl → 1.29.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/ops/Add.py +112 -0
- onnx2tf/ops/Concat.py +169 -23
- onnx2tf/ops/ImageDecoder.py +147 -0
- onnx2tf/ops/NegativeLogLikelihoodLoss.py +237 -0
- onnx2tf/ops/RMSNormalization.py +175 -0
- onnx2tf/ops/RegexFullMatch.py +108 -0
- onnx2tf/ops/RotaryEmbedding.py +285 -0
- onnx2tf/ops/Scan.py +438 -0
- onnx2tf/ops/SoftmaxCrossEntropyLoss.py +289 -0
- onnx2tf/ops/StringConcat.py +128 -0
- onnx2tf/ops/StringNormalizer.py +54 -39
- onnx2tf/ops/StringSplit.py +156 -0
- onnx2tf/ops/TensorScatter.py +223 -0
- {onnx2tf-1.29.16.dist-info → onnx2tf-1.29.18.dist-info}/METADATA +13 -12
- {onnx2tf-1.29.16.dist-info → onnx2tf-1.29.18.dist-info}/RECORD +18 -8
- {onnx2tf-1.29.16.dist-info → onnx2tf-1.29.18.dist-info}/WHEEL +1 -1
- {onnx2tf-1.29.16.dist-info → onnx2tf-1.29.18.dist-info}/entry_points.txt +0 -0
onnx2tf/ops/Scan.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import sys
|
|
3
|
+
import random
|
|
4
|
+
random.seed(0)
|
|
5
|
+
import numpy as np
|
|
6
|
+
np.random.seed(0)
|
|
7
|
+
import importlib
|
|
8
|
+
import tensorflow as tf
|
|
9
|
+
import tf_keras
|
|
10
|
+
import onnx_graphsurgeon as gs
|
|
11
|
+
from onnx2tf.utils.common_functions import (
|
|
12
|
+
get_constant_or_variable,
|
|
13
|
+
convert_axis,
|
|
14
|
+
print_node_info,
|
|
15
|
+
inverted_operation_enable_disable,
|
|
16
|
+
make_tf_node_info,
|
|
17
|
+
get_replacement_parameter,
|
|
18
|
+
pre_process_transpose,
|
|
19
|
+
post_process_transpose,
|
|
20
|
+
)
|
|
21
|
+
from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
|
|
22
|
+
from onnx2tf.utils.logging import *
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class While_Loop_CustomLayer(tf_keras.layers.Layer):
    """Keras layer whose ``call`` simply forwards to ``tf.while_loop``.

    Presumably used so the loop can be embedded as a layer in a tf_keras
    functional model (NOTE(review): confirm the motivation).
    """

    def __init__(self):
        super().__init__()

    def call(self, cond, body, loop_vars, shape_invariants, maximum_iterations=None):
        """Run ``tf.while_loop`` with exactly the given arguments.

        Returns whatever ``tf.while_loop`` returns for ``loop_vars``.
        """
        while_loop_args = dict(
            cond=cond,
            body=body,
            loop_vars=loop_vars,
            shape_invariants=shape_invariants,
            maximum_iterations=maximum_iterations,
        )
        return tf.while_loop(**while_loop_args)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _to_tf_dtype(dtype):
|
|
40
|
+
return NUMPY_DTYPES_TO_TF_DTYPES[dtype] if isinstance(dtype, np.dtype) else dtype
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _as_tensor(value):
|
|
44
|
+
if isinstance(value, np.ndarray):
|
|
45
|
+
return tf.convert_to_tensor(value)
|
|
46
|
+
if isinstance(value, (np.generic, int, float, bool, str, bytes)):
|
|
47
|
+
return tf.convert_to_tensor(value)
|
|
48
|
+
return value
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _shape_invariant(value):
    """Build a relaxed ``tf.TensorShape`` for use as a while-loop shape invariant.

    The rank of ``value`` is kept when it can be determined and every
    dimension is set to ``None``. If ``value`` has no usable ``.shape``, the
    shape is ``None``, or the rank is unknown, a fully unknown
    ``tf.TensorShape(None)`` is returned.
    """
    try:
        static_shape = value.shape
    except Exception:
        return tf.TensorShape(None)

    if static_shape is None:
        return tf.TensorShape(None)

    # tf.TensorShape exposes .rank (possibly None); other shapes (e.g. NumPy
    # tuples) are sized with len().
    if isinstance(static_shape, tf.TensorShape):
        rank = static_shape.rank
    else:
        rank = len(static_shape)

    if rank is None:
        return tf.TensorShape(None)
    return tf.TensorShape([None] * rank)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _sanitize(name, output_signaturedefs):
|
|
66
|
+
name = name.replace(':', '__')
|
|
67
|
+
if output_signaturedefs:
|
|
68
|
+
name = re.sub('^/', 'wa/', name)
|
|
69
|
+
return name
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _cast_if_needed(value, dtype):
    """Cast ``value`` to ``dtype`` when it is a ``tf.Tensor`` of another dtype.

    Non-tensor values, and calls with ``dtype=None``, are returned unchanged.
    Scan uses this on both sides of its loop body:
    - loop values are cast to the dtypes declared by the ONNX body inputs;
    - body outputs are cast back to the dtype of the loop-carried value they
      replace, because ``tf.while_loop`` and ``TensorArray.write`` reject
      dtype changes.
    """
    if dtype is not None and isinstance(value, tf.Tensor) and value.dtype != dtype:
        return tf.cast(value, dtype)
    return value


@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """Scan

    Converts an ONNX ``Scan`` node into a ``tf.while_loop``.

    - The node inputs are ``[state_0 .. state_{N-1}, scan_0 .. scan_{M-1}]``,
      where ``M = num_scan_inputs``.
    - Each iteration slices one element from every scan input along its scan
      axis and feeds the states plus those elements into the ``body`` subgraph.
    - The body's first ``N`` outputs become the next states. Its remaining
      outputs are collected per iteration in ``tf.TensorArray`` objects, then
      stacked along ``scan_output_axes``.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph

    Raises
    ------
    SystemExit
        Via ``sys.exit(1)`` after logging, when the node attributes or body
        signature are inconsistent, or when a body op is not implemented.
    """
    num_scan_inputs = int(graph_node.attrs.get('num_scan_inputs', 0))
    if num_scan_inputs <= 0:
        error(
            f'num_scan_inputs must be > 0 for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    total_inputs = len(graph_node.inputs)
    num_state_vars = total_inputs - num_scan_inputs
    if num_state_vars < 0:
        error(
            f'Invalid num_scan_inputs for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    state_inputs = list(graph_node.inputs[:num_state_vars])
    scan_inputs = list(graph_node.inputs[num_state_vars:])

    # ---- Initial values of the loop-carried state variables ----
    state_values = []
    state_input_meta = []
    for state_input in state_inputs:
        before_op_output_shape_trans = \
            tf_layers_dict.get(state_input.name, {}).get('before_op_output_shape_trans', True)
        state_node = get_constant_or_variable(
            state_input,
            before_op_output_shape_trans,
        )
        state_val = tf_layers_dict[state_node.name]['tf_node'] \
            if isinstance(state_node, gs.Variable) else state_node
        state_val = _as_tensor(state_val)
        # A variable declared as rank-0 in ONNX is explicitly reshaped to a scalar.
        if isinstance(state_node, gs.Variable) \
            and state_node.shape is not None \
            and len(state_node.shape) == 0:
            state_val = tf.reshape(state_val, [])
        state_values.append(state_val)
        # Layout metadata is propagated to the matching body input later.
        if isinstance(state_node, gs.Variable):
            state_input_meta.append(tf_layers_dict.get(state_node.name, {}))
        else:
            state_input_meta.append({})

    # ---- Scan inputs and their (layout-converted) scan axes ----
    scan_input_tensors = []
    scan_input_meta = []
    scan_input_axes = graph_node.attrs.get('scan_input_axes', None)
    if scan_input_axes is None:
        scan_input_axes = [0 for _ in range(num_scan_inputs)]
    scan_input_directions = graph_node.attrs.get('scan_input_directions', None)
    if scan_input_directions is None:
        scan_input_directions = [0 for _ in range(num_scan_inputs)]
    if len(scan_input_axes) != num_scan_inputs or len(scan_input_directions) != num_scan_inputs:
        error(
            f'Invalid scan_input_axes or scan_input_directions for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    converted_scan_input_axes = []
    for idx, scan_input in enumerate(scan_inputs):
        before_op_output_shape_trans = \
            tf_layers_dict.get(scan_input.name, {}).get('before_op_output_shape_trans', True)
        scan_node = get_constant_or_variable(
            scan_input,
            before_op_output_shape_trans,
        )
        scan_tensor = tf_layers_dict[scan_node.name]['tf_node'] \
            if isinstance(scan_node, gs.Variable) else scan_node
        scan_tensor = pre_process_transpose(
            value_before_transpose=scan_tensor,
            param_target='inputs',
            param_name=scan_input.name,
            **kwargs,
        )
        scan_tensor = _as_tensor(scan_tensor)
        scan_input_tensors.append(scan_tensor)
        if isinstance(scan_node, gs.Variable):
            scan_input_meta.append(tf_layers_dict.get(scan_node.name, {}))
        else:
            scan_input_meta.append({})
        # The rank is needed by convert_axis. Fall back to the ONNX-declared shape
        # when the TF rank is not statically known.
        scan_rank = scan_tensor.shape.rank
        if scan_rank is None and scan_input.shape is not None:
            scan_rank = len(scan_input.shape)
        if scan_rank is None:
            error(
                f'Scan input rank must be known.\n' +
                f'graph_node.name: {graph_node.name}'
            )
            sys.exit(1)
        converted_scan_input_axes.append(
            convert_axis(
                axis=int(scan_input_axes[idx]),
                tensor_rank=scan_rank,
                before_op_output_shape_trans=before_op_output_shape_trans,
            )
        )

    # ---- Scan output attributes ----
    scan_outputs = list(graph_node.outputs[num_state_vars:])
    num_scan_outputs = len(scan_outputs)
    scan_output_axes = graph_node.attrs.get('scan_output_axes', None)
    if scan_output_axes is None:
        scan_output_axes = [0 for _ in range(num_scan_outputs)]
    scan_output_directions = graph_node.attrs.get('scan_output_directions', None)
    if scan_output_directions is None:
        scan_output_directions = [0 for _ in range(num_scan_outputs)]
    if len(scan_output_axes) != num_scan_outputs or len(scan_output_directions) != num_scan_outputs:
        error(
            f'Invalid scan_output_axes or scan_output_directions for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    graph_node_outputs = list(graph_node.outputs)
    for graph_node_output in graph_node_outputs:
        tf_layers_dict[graph_node_output.name] = {
            'optype': graph_node.op,
            'shape': graph_node_output.shape,
            'dtype': graph_node_output.dtype,
        }

    body: gs.Graph = graph_node.attrs["body"]

    def _register_graph_output_constants(graph: gs.Graph):
        """Register constant outputs of ``graph`` and nested subgraphs as tf.constant.

        Constant outputs are not produced by any node, so they would otherwise
        be missing from ``tf_layers_dict`` when the body outputs are collected.
        """
        for output in graph.outputs:
            if output.name not in tf_layers_dict and isinstance(output, gs.Constant):
                tf_layers_dict[output.name] = {
                    'optype': 'Constant',
                    'shape': output.values.shape,
                    'dtype': output.values.dtype,
                }
                tf_layers_dict[output.name]['tf_node'] = \
                    tf.constant(
                        output.values,
                        dtype=_to_tf_dtype(output.values.dtype),
                    )
        for node in graph.nodes:
            for attr_val in node.attrs.values():
                if isinstance(attr_val, gs.Graph):
                    _register_graph_output_constants(attr_val)
                elif isinstance(attr_val, (list, tuple)):
                    for sub_val in attr_val:
                        if isinstance(sub_val, gs.Graph):
                            _register_graph_output_constants(sub_val)

    _register_graph_output_constants(body)

    if len(body.inputs) != (num_state_vars + num_scan_inputs):
        error(
            f'Body input count mismatch for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    # The body must emit exactly N state outputs followed by K scan outputs.
    # This single check also covers "fewer outputs than state variables".
    scan_out_start_index = num_state_vars
    if len(body.outputs) != (num_state_vars + num_scan_outputs):
        error(
            f'Body output count mismatch for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    # Determine sequence length from the first scan input
    sequence_length = tf.shape(scan_input_tensors[0])[converted_scan_input_axes[0]]
    sequence_length = tf.cast(sequence_length, tf.int32)

    # One fixed-size TensorArray per scan output, filled once per iteration.
    scan_outputs_init = []
    for i in range(scan_out_start_index, len(body.outputs)):
        elem_shape = body.outputs[i].shape
        if elem_shape is not None:
            # Symbolic (non-int) ONNX dims become unknown dims.
            elem_shape = [
                dim if isinstance(dim, int) else None for dim in elem_shape
            ]
            elem_shape = tf.TensorShape(elem_shape)
        scan_outputs_init.append(
            tf.TensorArray(
                dtype=_to_tf_dtype(body.outputs[i].dtype),
                size=sequence_length,
                dynamic_size=False,
                element_shape=elem_shape,
            )
        )
    scan_outputs_shapes = [tf.TensorShape(None) for _ in scan_outputs_init]

    state_shapes = [_shape_invariant(v) for v in state_values]
    iter_cnt_init = tf.constant(0, dtype=tf.int32)

    def run_subgraph(iter_cnt, state_vals, scan_outputs_vals):
        """Loop body: convert the ONNX ``body`` graph for one scan step.

        Side effect: every body tensor is (re)registered in ``tf_layers_dict``,
        and body input / node names are sanitized in place.
        """
        # Slice the current element out of every scan input (direction 1 = reverse).
        scan_elems = []
        for i, scan_tensor in enumerate(scan_input_tensors):
            axis = converted_scan_input_axes[i]
            direction = int(scan_input_directions[i])
            if direction == 1:
                idx = sequence_length - 1 - iter_cnt
            else:
                idx = iter_cnt
            scan_elems.append(
                tf.gather(
                    params=scan_tensor,
                    indices=idx,
                    axis=axis,
                )
            )

        # Bind states + scan elements to the body inputs.
        loop_inputs = list(state_vals) + scan_elems
        for idx, (body_input, loop_val) in enumerate(zip(body.inputs, loop_inputs)):
            body_input.name = _sanitize(body_input.name, kwargs.get('output_signaturedefs', False))
            target_dtype = _to_tf_dtype(body_input.dtype) if body_input.dtype is not None else None
            loop_val_cast = _cast_if_needed(loop_val, target_dtype)
            if body_input.shape is not None \
                and len(body_input.shape) == 0:
                loop_val_cast = tf.reshape(loop_val_cast, [])
            tf_layers_dict[body_input.name] = {
                'optype': 'Input',
                'shape': body_input.shape,
                'dtype': body_input.dtype,
                'tf_node': loop_val_cast,
                'before_op_output_shape_trans': True,
            }
            if idx < num_state_vars:
                meta = state_input_meta[idx]
            else:
                meta = scan_input_meta[idx - num_state_vars]
            for key in ('before_op_output_shape_trans', 'nhwc'):
                if key in meta:
                    tf_layers_dict[body_input.name][key] = meta[key]

        # Convert every body node with the regular per-op converters.
        subgraph_kwargs = dict(kwargs)
        subgraph_kwargs['suppress_log'] = True
        for body_node in body.nodes:
            optype = body_node.op
            try:
                op = importlib.import_module(f'onnx2tf.ops.{optype}')
            except ModuleNotFoundError:
                error(
                    f'{optype} OP is not yet implemented.'
                )
                sys.exit(1)
            body_node.name = _sanitize(body_node.name, kwargs.get('output_signaturedefs', False))
            op.make_node(
                graph_node=body_node,
                tf_layers_dict=tf_layers_dict,
                **subgraph_kwargs,
            )

        outputs = [tf_layers_dict[output.name]['tf_node'] for output in body.outputs]
        # tf.while_loop requires each state to keep the dtype it entered the loop
        # with. The body computed in the ONNX-declared input dtypes, so cast back.
        new_state_vals = [
            _cast_if_needed(new_val, getattr(prev_val, 'dtype', None))
            for new_val, prev_val in zip(outputs[:num_state_vars], state_vals)
        ]
        scan_out_elems = outputs[scan_out_start_index:]

        # Write this step's scan outputs (direction 1 = fill from the end).
        updated_scan_outputs = []
        for i, ta in enumerate(scan_outputs_vals):
            direction = int(scan_output_directions[i])
            if direction == 1:
                write_idx = sequence_length - 1 - iter_cnt
            else:
                write_idx = iter_cnt
            updated_scan_outputs.append(
                # TensorArray.write rejects values whose dtype differs from ta.dtype.
                ta.write(write_idx, _cast_if_needed(scan_out_elems[i], ta.dtype))
            )

        return [iter_cnt + 1, new_state_vals, updated_scan_outputs]

    def condition(iter_cnt, state_vals, scan_outputs_vals):
        return tf.less(iter_cnt, sequence_length)

    while_loop_layer = While_Loop_CustomLayer()
    _, state_vals_final, scan_outputs_final = while_loop_layer(
        cond=condition,
        body=run_subgraph,
        loop_vars=[
            iter_cnt_init,
            state_values,
            scan_outputs_init,
        ],
        shape_invariants=[
            tf.TensorShape([]),
            state_shapes,
            scan_outputs_shapes,
        ],
        maximum_iterations=sequence_length,
    )

    # Stack each TensorArray (sequence on axis 0), then move the sequence
    # dimension to scan_output_axes[i].
    # NOTE(review): unlike the scan input axes, scan_output_axes is applied
    # without convert_axis; confirm this is correct for NHWC-transposed outputs.
    scan_output_tensors = []
    for i, ta in enumerate(scan_outputs_final):
        out_tensor = ta.stack()
        out_rank = out_tensor.shape.rank
        if out_rank is None and scan_outputs[i].shape is not None:
            out_rank = len(scan_outputs[i].shape)
        axis = int(scan_output_axes[i])
        if out_rank is not None:
            axis = axis if axis >= 0 else axis + out_rank
        if out_rank is not None and axis != 0:
            perm = list(range(1, out_rank))
            perm.insert(axis, 0)
            out_tensor = tf.transpose(out_tensor, perm=perm)
        scan_output_tensors.append(out_tensor)

    final_outputs = list(state_vals_final) + scan_output_tensors
    if len(final_outputs) != len(graph_node_outputs):
        error(
            f'Scan output count mismatch. expected={len(graph_node_outputs)} actual={len(final_outputs)}\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    # Publish the results, inheriting layout metadata from the matching body output.
    for idx, (graph_node_output, output_tensor) in enumerate(zip(graph_node_outputs, final_outputs)):
        tf_layers_dict[graph_node_output.name]['tf_node'] = output_tensor
        body_output = body.outputs[idx]
        body_meta = tf_layers_dict.get(body_output.name, {})
        for key in ('before_op_output_shape_trans', 'nhwc'):
            if key in body_meta:
                tf_layers_dict[graph_node_output.name][key] = body_meta[key]

        tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
            value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
            param_target='outputs',
            param_name=graph_node_output.name,
            **kwargs,
        )

    tf_outputs = {f"output{idx}": value for idx, value in enumerate(final_outputs)}
    for graph_node_output in graph_node_outputs:
        tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
            make_tf_node_info(
                node_info={
                    'tf_op_type': tf.while_loop,
                    'tf_inputs': {
                        'scan_input_axes': scan_input_axes,
                        'scan_input_directions': scan_input_directions,
                        'scan_output_axes': scan_output_axes,
                        'scan_output_directions': scan_output_directions,
                        'state_inputs': state_values,
                        'scan_inputs': scan_input_tensors,
                    },
                    'tf_outputs': {
                        'output': tf_outputs,
                    },
                }
            )
|