onnx2tf 1.29.16__py3-none-any.whl → 1.29.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2tf/ops/Scan.py ADDED
@@ -0,0 +1,438 @@
1
+ import re
2
+ import sys
3
+ import random
4
+ random.seed(0)
5
+ import numpy as np
6
+ np.random.seed(0)
7
+ import importlib
8
+ import tensorflow as tf
9
+ import tf_keras
10
+ import onnx_graphsurgeon as gs
11
+ from onnx2tf.utils.common_functions import (
12
+ get_constant_or_variable,
13
+ convert_axis,
14
+ print_node_info,
15
+ inverted_operation_enable_disable,
16
+ make_tf_node_info,
17
+ get_replacement_parameter,
18
+ pre_process_transpose,
19
+ post_process_transpose,
20
+ )
21
+ from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
22
+ from onnx2tf.utils.logging import *
23
+
24
+
25
class While_Loop_CustomLayer(tf_keras.layers.Layer):
    """Keras layer whose ``call`` forwards its arguments to ``tf.while_loop``."""

    def __init__(self):
        super().__init__()

    def call(self, cond, body, loop_vars, shape_invariants, maximum_iterations=None):
        """Invoke ``tf.while_loop`` with the supplied arguments and return its result."""
        loop_arguments = {
            'cond': cond,
            'body': body,
            'loop_vars': loop_vars,
            'shape_invariants': shape_invariants,
            'maximum_iterations': maximum_iterations,
        }
        return tf.while_loop(**loop_arguments)
37
+
38
+
39
+ def _to_tf_dtype(dtype):
40
+ return NUMPY_DTYPES_TO_TF_DTYPES[dtype] if isinstance(dtype, np.dtype) else dtype
41
+
42
+
43
+ def _as_tensor(value):
44
+ if isinstance(value, np.ndarray):
45
+ return tf.convert_to_tensor(value)
46
+ if isinstance(value, (np.generic, int, float, bool, str, bytes)):
47
+ return tf.convert_to_tensor(value)
48
+ return value
49
+
50
+
51
def _shape_invariant(value):
    """Build a ``tf.TensorShape`` invariant with every dimension unknown.

    The rank of ``value.shape`` is kept when it can be determined; otherwise
    (no readable ``shape``, ``shape`` is None, or unknown rank) a fully
    unknown ``tf.TensorShape(None)`` is returned.
    """
    try:
        shape = value.shape
    except Exception:
        return tf.TensorShape(None)

    if shape is None:
        return tf.TensorShape(None)

    if isinstance(shape, tf.TensorShape):
        rank = shape.rank
        if rank is None:
            return tf.TensorShape(None)
    else:
        rank = len(shape)

    return tf.TensorShape([None] * rank)
63
+
64
+
65
+ def _sanitize(name, output_signaturedefs):
66
+ name = name.replace(':', '__')
67
+ if output_signaturedefs:
68
+ name = re.sub('^/', 'wa/', name)
69
+ return name
70
+
71
+
72
@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """Scan

    Builds a ``tf.while_loop`` (through ``While_Loop_CustomLayer``) that runs
    the node's ``body`` subgraph once per element of the scan inputs and
    stores the resulting tensors in ``tf_layers_dict``.

    The node inputs are split into state inputs (the leading
    ``len(inputs) - num_scan_inputs`` entries) followed by scan inputs. The
    iteration count is the size of the first scan input along its converted
    scan axis. The function has no return statement; its results are the
    ``tf_node`` / ``tf_node_info`` entries written for every node output.

    On any validation failure the function calls ``error(...)`` and then
    ``sys.exit(1)``.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph

    kwargs: dict
        Forwarded to ``pre_process_transpose`` / ``post_process_transpose``
        and (with ``suppress_log`` set to True) to the ``make_node`` of every
        body node. ``output_signaturedefs`` is read from it when sanitizing
        names.
    """
    # --- Validate the number of scan inputs and split the node inputs -------
    num_scan_inputs = int(graph_node.attrs.get('num_scan_inputs', 0))
    if num_scan_inputs <= 0:
        error(
            f'num_scan_inputs must be > 0 for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    total_inputs = len(graph_node.inputs)
    num_state_vars = total_inputs - num_scan_inputs
    if num_state_vars < 0:
        error(
            f'Invalid num_scan_inputs for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    # State inputs come first, scan inputs occupy the tail of the input list.
    state_inputs = list(graph_node.inputs[:num_state_vars])
    scan_inputs = list(graph_node.inputs[num_state_vars:])

    # --- Resolve the initial state values ----------------------------------
    state_values = []
    # Per-state copy of the producer's tf_layers_dict entry ({} for constants);
    # used later to propagate layout flags onto the body inputs.
    state_input_meta = []
    for state_input in state_inputs:
        before_op_output_shape_trans = \
            tf_layers_dict.get(state_input.name, {}).get('before_op_output_shape_trans', True)
        state_node = get_constant_or_variable(
            state_input,
            before_op_output_shape_trans,
        )
        # Variables are looked up in tf_layers_dict; anything else is used as-is.
        state_val = tf_layers_dict[state_node.name]['tf_node'] \
            if isinstance(state_node, gs.Variable) else state_node
        state_val = _as_tensor(state_val)
        # A variable declared with an empty shape is forced to a rank-0 tensor.
        if isinstance(state_node, gs.Variable) \
            and state_node.shape is not None \
            and len(state_node.shape) == 0:
            state_val = tf.reshape(state_val, [])
        state_values.append(state_val)
        if isinstance(state_node, gs.Variable):
            state_input_meta.append(tf_layers_dict.get(state_node.name, {}))
        else:
            state_input_meta.append({})

    # --- Resolve the scan inputs, their axes and directions ----------------
    scan_input_tensors = []
    scan_input_meta = []
    # Missing axis/direction attributes default to 0 for every scan input.
    scan_input_axes = graph_node.attrs.get('scan_input_axes', None)
    if scan_input_axes is None:
        scan_input_axes = [0 for _ in range(num_scan_inputs)]
    scan_input_directions = graph_node.attrs.get('scan_input_directions', None)
    if scan_input_directions is None:
        scan_input_directions = [0 for _ in range(num_scan_inputs)]
    if len(scan_input_axes) != num_scan_inputs or len(scan_input_directions) != num_scan_inputs:
        error(
            f'Invalid scan_input_axes or scan_input_directions for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    # Scan axes after convert_axis(); these index the TensorFlow-side tensors.
    converted_scan_input_axes = []
    for idx, scan_input in enumerate(scan_inputs):
        before_op_output_shape_trans = \
            tf_layers_dict.get(scan_input.name, {}).get('before_op_output_shape_trans', True)
        scan_node = get_constant_or_variable(
            scan_input,
            before_op_output_shape_trans,
        )
        scan_tensor = tf_layers_dict[scan_node.name]['tf_node'] \
            if isinstance(scan_node, gs.Variable) else scan_node
        scan_tensor = pre_process_transpose(
            value_before_transpose=scan_tensor,
            param_target='inputs',
            param_name=scan_input.name,
            **kwargs,
        )
        scan_tensor = _as_tensor(scan_tensor)
        scan_input_tensors.append(scan_tensor)
        if isinstance(scan_node, gs.Variable):
            scan_input_meta.append(tf_layers_dict.get(scan_node.name, {}))
        else:
            scan_input_meta.append({})
        # The rank is needed to convert the axis; fall back to the ONNX shape
        # when the TensorFlow tensor does not expose a static rank.
        scan_rank = scan_tensor.shape.rank
        if scan_rank is None and scan_input.shape is not None:
            scan_rank = len(scan_input.shape)
        if scan_rank is None:
            error(
                f'Scan input rank must be known.\n' +
                f'graph_node.name: {graph_node.name}'
            )
            sys.exit(1)
        converted_scan_input_axes.append(
            convert_axis(
                axis=int(scan_input_axes[idx]),
                tensor_rank=scan_rank,
                before_op_output_shape_trans=before_op_output_shape_trans,
            )
        )

    # --- Resolve scan output axes and directions ---------------------------
    # Node outputs follow the same layout as the inputs: state outputs first.
    scan_outputs = list(graph_node.outputs[num_state_vars:])
    num_scan_outputs = len(scan_outputs)
    scan_output_axes = graph_node.attrs.get('scan_output_axes', None)
    if scan_output_axes is None:
        scan_output_axes = [0 for _ in range(num_scan_outputs)]
    scan_output_directions = graph_node.attrs.get('scan_output_directions', None)
    if scan_output_directions is None:
        scan_output_directions = [0 for _ in range(num_scan_outputs)]
    if len(scan_output_axes) != num_scan_outputs or len(scan_output_directions) != num_scan_outputs:
        error(
            f'Invalid scan_output_axes or scan_output_directions for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    # Pre-register every node output; 'tf_node' is filled in after the loop.
    graph_node_outputs = list(graph_node.outputs)
    for graph_node_output in graph_node_outputs:
        tf_layers_dict[graph_node_output.name] = {
            'optype': graph_node.op,
            'shape': graph_node_output.shape,
            'dtype': graph_node_output.dtype,
        }

    body: gs.Graph = graph_node.attrs["body"]

    def _register_graph_output_constants(graph: gs.Graph):
        # Register gs.Constant graph outputs that are not yet in
        # tf_layers_dict, then recurse into every subgraph found in node
        # attributes (either directly or inside a list/tuple attribute).
        for output in graph.outputs:
            if output.name not in tf_layers_dict and isinstance(output, gs.Constant):
                tf_layers_dict[output.name] = {
                    'optype': 'Constant',
                    'shape': output.values.shape,
                    'dtype': output.values.dtype,
                }
                tf_layers_dict[output.name]['tf_node'] = \
                    tf.constant(
                        output.values,
                        dtype=_to_tf_dtype(output.values.dtype),
                    )
        for node in graph.nodes:
            for attr_val in node.attrs.values():
                if isinstance(attr_val, gs.Graph):
                    _register_graph_output_constants(attr_val)
                elif isinstance(attr_val, (list, tuple)):
                    for sub_val in attr_val:
                        if isinstance(sub_val, gs.Graph):
                            _register_graph_output_constants(sub_val)

    _register_graph_output_constants(body)

    # --- Check the body signature against the node signature ---------------
    if len(body.inputs) != (num_state_vars + num_scan_inputs):
        error(
            f'Body input count mismatch for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)
    if len(body.outputs) < num_state_vars:
        error(
            f'Body output count mismatch for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    # Index of the first scan output among the body outputs.
    scan_out_start_index = num_state_vars
    if len(body.outputs) != (num_state_vars + num_scan_outputs):
        error(
            f'Body output count mismatch for Scan.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    # Determine sequence length from the first scan input
    sequence_length = tf.shape(scan_input_tensors[0])[converted_scan_input_axes[0]]
    sequence_length = tf.cast(sequence_length, tf.int32)

    # One fixed-size TensorArray per scan output, sized to sequence_length.
    scan_outputs_init = []
    for i in range(scan_out_start_index, len(body.outputs)):
        elem_shape = body.outputs[i].shape
        if elem_shape is not None:
            # Non-int dimensions are treated as unknown.
            elem_shape = [
                dim if isinstance(dim, int) else None for dim in elem_shape
            ]
            elem_shape = tf.TensorShape(elem_shape)
        scan_outputs_init.append(
            tf.TensorArray(
                dtype=_to_tf_dtype(body.outputs[i].dtype),
                size=sequence_length,
                dynamic_size=False,
                element_shape=elem_shape,
            )
        )
    scan_outputs_shapes = [tf.TensorShape(None) for _ in scan_outputs_init]

    state_shapes = [_shape_invariant(v) for v in state_values]
    iter_cnt_init = tf.constant(0, dtype=tf.int32)

    def run_subgraph(iter_cnt, state_vals, scan_outputs_vals):
        # Loop body: slice the scan inputs for this iteration, convert the
        # body subgraph, and write the scan outputs into their TensorArrays.
        scan_elems = []
        for i, scan_tensor in enumerate(scan_input_tensors):
            axis = converted_scan_input_axes[i]
            direction = int(scan_input_directions[i])
            # Direction 1 reads the input from the last element to the first.
            if direction == 1:
                idx = sequence_length - 1 - iter_cnt
            else:
                idx = iter_cnt
            scan_elems.append(
                tf.gather(
                    params=scan_tensor,
                    indices=idx,
                    axis=axis,
                )
            )

        # Body inputs are ordered as state values followed by scan elements.
        loop_inputs = list(state_vals) + scan_elems
        for idx, (body_input, loop_val) in enumerate(zip(body.inputs, loop_inputs)):
            # Note: this renames the body graph's input in place.
            body_input.name = _sanitize(body_input.name, kwargs.get('output_signaturedefs', False))
            target_dtype = _to_tf_dtype(body_input.dtype) if body_input.dtype is not None else None
            loop_val_cast = loop_val
            # Cast only tf.Tensor values whose dtype differs from the declared one.
            if target_dtype is not None \
                and isinstance(loop_val, tf.Tensor) \
                and loop_val.dtype != target_dtype:
                loop_val_cast = tf.cast(loop_val, target_dtype)
            if body_input.shape is not None \
                and len(body_input.shape) == 0:
                loop_val_cast = tf.reshape(loop_val_cast, [])
            tf_layers_dict[body_input.name] = {
                'optype': 'Input',
                'shape': body_input.shape,
                'dtype': body_input.dtype,
                'tf_node': loop_val_cast,
                'before_op_output_shape_trans': True,
            }
            # Override the layout flags with those of the outer tensor feeding
            # this body input, when that tensor's entry defines them.
            meta = None
            if idx < num_state_vars:
                meta = state_input_meta[idx]
            else:
                meta = scan_input_meta[idx - num_state_vars]
            for key in ('before_op_output_shape_trans', 'nhwc'):
                if key in meta:
                    tf_layers_dict[body_input.name][key] = meta[key]

        subgraph_kwargs = dict(kwargs)
        subgraph_kwargs['suppress_log'] = True
        # Convert each body node with the op module named after its op type.
        for body_node in body.nodes:
            optype = body_node.op
            try:
                op = importlib.import_module(f'onnx2tf.ops.{optype}')
            except ModuleNotFoundError:
                error(
                    f'{optype} OP is not yet implemented.'
                )
                sys.exit(1)
            body_node.name = _sanitize(body_node.name, kwargs.get('output_signaturedefs', False))
            op.make_node(
                graph_node=body_node,
                tf_layers_dict=tf_layers_dict,
                **subgraph_kwargs,
            )

        outputs = [tf_layers_dict[output.name]['tf_node'] for output in body.outputs]
        new_state_vals = outputs[:num_state_vars]
        scan_out_elems = outputs[scan_out_start_index:]

        updated_scan_outputs = []
        for i, ta in enumerate(scan_outputs_vals):
            direction = int(scan_output_directions[i])
            # Direction 1 fills the output from the last slot to the first.
            if direction == 1:
                write_idx = sequence_length - 1 - iter_cnt
            else:
                write_idx = iter_cnt
            updated_scan_outputs.append(
                ta.write(write_idx, scan_out_elems[i])
            )

        return [iter_cnt + 1, new_state_vals, updated_scan_outputs]

    def condition(iter_cnt, state_vals, scan_outputs_vals):
        # Keep looping while the counter is below the sequence length.
        return tf.less(iter_cnt, sequence_length)

    # --- Run the loop -------------------------------------------------------
    while_loop_layer = While_Loop_CustomLayer()
    iter_cnt_final, state_vals_final, scan_outputs_final = while_loop_layer(
        cond=condition,
        body=run_subgraph,
        loop_vars=[
            iter_cnt_init,
            state_values,
            scan_outputs_init,
        ],
        shape_invariants=[
            tf.TensorShape([]),
            state_shapes,
            scan_outputs_shapes,
        ],
        maximum_iterations=sequence_length,
    )

    # --- Stack the scan outputs and move the scan axis into place ----------
    scan_output_tensors = []
    for i, ta in enumerate(scan_outputs_final):
        # stack() puts the iteration dimension at axis 0.
        out_tensor = ta.stack()
        out_rank = out_tensor.shape.rank
        if out_rank is None and scan_outputs[i].shape is not None:
            out_rank = len(scan_outputs[i].shape)
        axis = int(scan_output_axes[i])
        # Normalize a negative axis when the rank is known.
        if out_rank is not None:
            axis = axis if axis >= 0 else axis + out_rank
        if out_rank is not None and axis != 0:
            # Permutation that moves dimension 0 to position `axis`.
            perm = list(range(1, out_rank))
            perm.insert(axis, 0)
            out_tensor = tf.transpose(out_tensor, perm=perm)
        scan_output_tensors.append(out_tensor)

    final_outputs = list(state_vals_final) + scan_output_tensors
    if len(final_outputs) != len(graph_node_outputs):
        error(
            f'Scan output count mismatch. expected={len(graph_node_outputs)} actual={len(final_outputs)}\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    # --- Publish the results into tf_layers_dict ----------------------------
    for idx, (graph_node_output, output_tensor) in enumerate(zip(graph_node_outputs, final_outputs)):
        tf_layers_dict[graph_node_output.name]['tf_node'] = output_tensor
        # Copy the layout flags from the matching body output, when present.
        body_output = body.outputs[idx]
        body_meta = tf_layers_dict.get(body_output.name, {})
        for key in ('before_op_output_shape_trans', 'nhwc'):
            if key in body_meta:
                tf_layers_dict[graph_node_output.name][key] = body_meta[key]

        tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
            value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
            param_target='outputs',
            param_name=graph_node_output.name,
            **kwargs,
        )

    # Note: tf_outputs holds the tensors as they were before
    # post_process_transpose was applied above.
    tf_outputs = {f"output{idx}": value for idx, value in enumerate(final_outputs)}
    for graph_node_output in graph_node_outputs:
        tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
            make_tf_node_info(
                node_info={
                    'tf_op_type': tf.while_loop,
                    'tf_inputs': {
                        'scan_input_axes': scan_input_axes,
                        'scan_input_directions': scan_input_directions,
                        'scan_output_axes': scan_output_axes,
                        'scan_output_directions': scan_output_directions,
                        'state_inputs': state_values,
                        'scan_inputs': scan_input_tensors,
                    },
                    'tf_outputs': {
                        'output': tf_outputs,
                    },
                }
            )