onnx2tf 1.29.15__py3-none-any.whl → 1.29.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +141 -0
- onnx2tf/ops/Add.py +112 -0
- onnx2tf/ops/Concat.py +236 -64
- onnx2tf/ops/DequantizeLinear.py +76 -34
- onnx2tf/ops/DynamicQuantizeLinear.py +18 -17
- onnx2tf/ops/QLinearConcat.py +245 -26
- onnx2tf/ops/QLinearConv.py +70 -75
- onnx2tf/ops/QLinearMatMul.py +77 -20
- onnx2tf/ops/QuantizeLinear.py +117 -44
- onnx2tf/ops/Split.py +33 -8
- {onnx2tf-1.29.15.dist-info → onnx2tf-1.29.17.dist-info}/METADATA +3 -3
- {onnx2tf-1.29.15.dist-info → onnx2tf-1.29.17.dist-info}/RECORD +15 -15
- {onnx2tf-1.29.15.dist-info → onnx2tf-1.29.17.dist-info}/WHEEL +0 -0
- {onnx2tf-1.29.15.dist-info → onnx2tf-1.29.17.dist-info}/entry_points.txt +0 -0
onnx2tf/__init__.py
CHANGED
onnx2tf/onnx2tf.py
CHANGED
|
@@ -62,6 +62,146 @@ from onnx2tf.utils.enums import (
|
|
|
62
62
|
from onnx2tf.utils.logging import *
|
|
63
63
|
from sng4onnx import generate as op_name_auto_generate
|
|
64
64
|
|
|
65
|
+
def fuse_expanded_qdq_to_qdq(
    *,
    graph: gs.Graph,
):
    """Rewrite expanded fake-quantization chains as QuantizeLinear/DequantizeLinear.

    Scans `graph` for the five-node chain

        Mul(x, inv_scale) -> Round -> Relu | Max(., 0) -> Min(., qmax) -> Mul(., scale)

    in which every intermediate tensor has exactly one consumer, all
    constants are single-element, `qmax` equals the uint8 maximum and
    `scale * inv_scale` is approximately 1. Each match is replaced by a
    per-tensor uint8 `QuantizeLinear(x, scale, 0)` followed by a
    `DequantizeLinear` that writes the chain's original output tensor.

    Chains that do not satisfy every condition are left untouched.

    Parameters
    ----------
    graph: gs.Graph
        Graph to rewrite. It is modified in place; when at least one chain
        is fused the graph is also cleaned up and topologically sorted.

    Returns
    -------
    None
    """
    # The replacement QuantizeLinear is uint8 with zero point 0, so it
    # saturates to [0, 255]. Only a Min clamp at exactly that bound is equivalent.
    uint8_max = float(np.iinfo(np.uint8).max)

    def _get_const_value(tensor):
        # numpy payload of a constant tensor (either a gs.Constant or the
        # output of a 'Constant' node), or None when the tensor is not constant.
        if isinstance(tensor, gs.Constant):
            return tensor.values
        if isinstance(tensor, gs.Variable) and len(tensor.inputs) == 1:
            producer = tensor.inputs[0]
            if producer.op == 'Constant' and 'value' in producer.attrs:
                return producer.attrs['value'].values
        return None

    def _split_const_and_var(inputs):
        # For a two-input node return (constant value, other input tensor);
        # (None, None) when the node is not binary or has no constant operand.
        if len(inputs) != 2:
            return None, None
        const_val = _get_const_value(inputs[0])
        if const_val is not None:
            return const_val, inputs[1]
        const_val = _get_const_value(inputs[1])
        if const_val is not None:
            return const_val, inputs[0]
        return None, None

    def _scalar_or_none(value):
        # float value of a single-element constant, None for anything larger.
        # Guards `.item()`, which raises ValueError when size != 1.
        arr = np.asarray(value)
        if arr.size != 1:
            return None
        return float(arr.item())

    nodes_to_remove = []
    nodes_to_add = []

    for round_node in list(graph.nodes):
        if round_node.op != 'Round' or len(round_node.inputs) < 1:
            continue

        # --- Mul(x, inv_scale) feeding the Round
        round_in = round_node.inputs[0]
        if len(round_in.inputs) != 1:
            continue
        mul1_node = round_in.inputs[0]
        if mul1_node.op != 'Mul':
            continue
        if len(mul1_node.outputs) != 1 or len(mul1_node.outputs[0].outputs) != 1:
            continue

        inv_scale, x = _split_const_and_var(mul1_node.inputs)
        if inv_scale is None or x is None:
            continue
        inv_scale_val = _scalar_or_none(inv_scale)
        if inv_scale_val is None:
            # Per-channel (non-scalar) multiplier: not a per-tensor quantizer.
            continue

        # --- Relu / Max(., 0) consuming the Round
        # Check the consumer count BEFORE indexing: a Round whose output is
        # unused or is a graph output has no consumer to look at.
        if not round_node.outputs or len(round_node.outputs[0].outputs) != 1:
            continue
        round_out = round_node.outputs[0]
        relu_node = round_out.outputs[0]
        if not relu_node.outputs:
            continue
        if relu_node.op in ['Max', 'Maximum']:
            max_const, max_var = _split_const_and_var(relu_node.inputs)
            if max_const is None or max_var != round_out:
                continue
            # Only a clamp at exactly zero behaves like Relu.
            if _scalar_or_none(max_const) != 0.0:
                continue
        elif relu_node.op != 'Relu':
            continue
        relu_out = relu_node.outputs[0]

        # --- Min(., qmax)
        if len(relu_out.outputs) != 1:
            continue
        min_node = relu_out.outputs[0]
        if min_node.op not in ['Min', 'Minimum']:
            continue

        qmax, min_var = _split_const_and_var(min_node.inputs)
        if qmax is None or min_var != relu_out:
            continue
        qmax_val = _scalar_or_none(qmax)
        if qmax_val is None or qmax_val != uint8_max:
            # A different upper clamp cannot be reproduced by uint8 saturation.
            continue

        # --- Mul(., scale)
        if len(min_node.outputs) != 1 or len(min_node.outputs[0].outputs) != 1:
            continue
        mul2_node = min_node.outputs[0].outputs[0]
        if mul2_node.op != 'Mul':
            continue

        scale, min_out = _split_const_and_var(mul2_node.inputs)
        if scale is None or min_out != min_node.outputs[0]:
            continue
        scale_val = _scalar_or_none(scale)
        if scale_val is None:
            continue

        if scale_val == 0.0 or not np.isfinite(scale_val) or not np.isfinite(inv_scale_val):
            continue
        # The two multipliers must be reciprocals of each other (within tolerance).
        if not np.isclose(scale_val * inv_scale_val, 1.0, rtol=1e-3, atol=1e-6):
            continue

        if len(mul2_node.outputs) != 1:
            continue
        output_var = mul2_node.outputs[0]

        chain_nodes = [mul1_node, round_node, relu_node, min_node, mul2_node]

        # Build QDQ
        # NOTE(review): the scale is always emitted as float32 regardless of
        # x's dtype - confirm this is acceptable for non-float32 inputs.
        scale_const = gs.Constant(
            name=f"{mul2_node.name}_scale",
            values=np.asarray(scale_val, dtype=np.float32),
        )
        zero_const = gs.Constant(
            name=f"{mul2_node.name}_zero_point",
            values=np.asarray(0, dtype=np.uint8),
        )
        quant_out = gs.Variable(
            name=f"{output_var.name}_quant",
            dtype=np.uint8,
            shape=output_var.shape,
        )
        q_node = gs.Node(
            op="QuantizeLinear",
            name=f"{mul2_node.name}_QuantizeLinear",
            inputs=[x, scale_const, zero_const],
            outputs=[quant_out],
        )
        dq_node = gs.Node(
            op="DequantizeLinear",
            name=f"{mul2_node.name}_DequantizeLinear",
            inputs=[quant_out, scale_const, zero_const],
            outputs=[output_var],
        )
        # The chain's output tensor is now produced by the DequantizeLinear only.
        output_var.inputs = [dq_node]

        nodes_to_add.extend([q_node, dq_node])
        nodes_to_remove.extend(chain_nodes)

    # Both lists are filled together, so nothing to add means nothing to do.
    if not nodes_to_add:
        return

    graph.nodes.extend(nodes_to_add)
    for n in nodes_to_remove:
        if n in graph.nodes:
            graph.nodes.remove(n)
    graph.cleanup().toposort()
|
|
204
|
+
|
|
65
205
|
def apply_nonzero_passthrough(
|
|
66
206
|
*,
|
|
67
207
|
graph: gs.Graph,
|
|
@@ -848,6 +988,7 @@ def convert(
|
|
|
848
988
|
if hasattr(onnx_graph, 'metadata_props'):
|
|
849
989
|
metadata_props = onnx_graph.metadata_props
|
|
850
990
|
graph = gs.import_onnx(onnx_graph)
|
|
991
|
+
fuse_expanded_qdq_to_qdq(graph=graph)
|
|
851
992
|
|
|
852
993
|
# Cut the ONNX graph when an input name is specified that interrupts the conversion
|
|
853
994
|
if not input_names_to_interrupt_model_conversion:
|
onnx2tf/ops/Add.py
CHANGED
|
@@ -21,6 +21,7 @@ from onnx2tf.utils.common_functions import (
|
|
|
21
21
|
disable_unnecessary_transpose,
|
|
22
22
|
shape_unmatched_special_avoidance_workaround,
|
|
23
23
|
merge_two_consecutive_identical_ops_into_one,
|
|
24
|
+
transpose_with_flexing_deterrence,
|
|
24
25
|
deterring_shape_corruption_due_to_broadcast,
|
|
25
26
|
acquisition_of_validation_data,
|
|
26
27
|
onnx_tf_tensor_validation,
|
|
@@ -297,6 +298,117 @@ def make_node(
|
|
|
297
298
|
)
|
|
298
299
|
tf_type = tf.identity
|
|
299
300
|
|
|
301
|
+
def _normalize_dim(dim):
|
|
302
|
+
return int(dim) if isinstance(dim, (int, np.integer)) else None
|
|
303
|
+
|
|
304
|
+
def _get_static_shape(tensor):
|
|
305
|
+
shape = getattr(tensor, 'shape', None)
|
|
306
|
+
if shape is None or shape == tf.TensorShape(None):
|
|
307
|
+
return None
|
|
308
|
+
return [_normalize_dim(dim) for dim in list(shape)]
|
|
309
|
+
|
|
310
|
+
def _shape_match_with_none(expected, actual):
|
|
311
|
+
if expected is None or actual is None:
|
|
312
|
+
return False
|
|
313
|
+
if len(expected) != len(actual):
|
|
314
|
+
return False
|
|
315
|
+
for e_dim, a_dim in zip(expected, actual):
|
|
316
|
+
e_dim = _normalize_dim(e_dim)
|
|
317
|
+
a_dim = _normalize_dim(a_dim)
|
|
318
|
+
if e_dim is None or a_dim is None:
|
|
319
|
+
continue
|
|
320
|
+
if e_dim != a_dim:
|
|
321
|
+
return False
|
|
322
|
+
return True
|
|
323
|
+
|
|
324
|
+
def _perm_shape(shape, perm):
|
|
325
|
+
return [shape[i] for i in perm] if shape is not None else None
|
|
326
|
+
|
|
327
|
+
def _limited_perms(rank):
|
|
328
|
+
identity = list(range(rank))
|
|
329
|
+
perms = [identity]
|
|
330
|
+
if rank == 3:
|
|
331
|
+
perms.append([0, 2, 1])
|
|
332
|
+
elif rank == 4:
|
|
333
|
+
perms.extend([[0, 2, 3, 1], [0, 3, 1, 2]])
|
|
334
|
+
elif rank == 5:
|
|
335
|
+
perms.extend([[0, 2, 3, 4, 1], [0, 4, 1, 2, 3]])
|
|
336
|
+
return perms
|
|
337
|
+
|
|
338
|
+
def _ranked_perms(perms, input_shape, onnx_shape):
|
|
339
|
+
if input_shape is None or onnx_shape is None:
|
|
340
|
+
return perms
|
|
341
|
+
scored = []
|
|
342
|
+
for perm in perms:
|
|
343
|
+
score = 0
|
|
344
|
+
for out_idx, in_idx in enumerate(perm):
|
|
345
|
+
if out_idx >= len(onnx_shape) or in_idx >= len(input_shape):
|
|
346
|
+
continue
|
|
347
|
+
o_dim = _normalize_dim(onnx_shape[out_idx])
|
|
348
|
+
i_dim = input_shape[in_idx]
|
|
349
|
+
if isinstance(o_dim, int) and isinstance(i_dim, int) and o_dim == i_dim:
|
|
350
|
+
score += o_dim
|
|
351
|
+
scored.append((score, 1 if perm == list(range(len(perm))) else 0, perm))
|
|
352
|
+
scored.sort(key=lambda x: (x[0], x[1]), reverse=True)
|
|
353
|
+
return [p for _, _, p in scored]
|
|
354
|
+
|
|
355
|
+
# Rescue guard for unexpected broadcasted shapes
|
|
356
|
+
if not enable_gelu:
|
|
357
|
+
expected_shape = None
|
|
358
|
+
if graph_node_output_shape is not None:
|
|
359
|
+
expected_shape = [_normalize_dim(dim) for dim in list(graph_node_output_shape)]
|
|
360
|
+
output_shape = _get_static_shape(tf_layers_dict[graph_node_output.name]['tf_node'])
|
|
361
|
+
input_shape_1 = _get_static_shape(input_tensor_1)
|
|
362
|
+
input_shape_2 = _get_static_shape(input_tensor_2)
|
|
363
|
+
if expected_shape is not None \
|
|
364
|
+
and output_shape is not None \
|
|
365
|
+
and not _shape_match_with_none(expected_shape, output_shape) \
|
|
366
|
+
and input_shape_1 is not None \
|
|
367
|
+
and input_shape_2 is not None \
|
|
368
|
+
and len(input_shape_1) == len(expected_shape) \
|
|
369
|
+
and len(input_shape_2) == len(expected_shape):
|
|
370
|
+
|
|
371
|
+
rank = len(expected_shape)
|
|
372
|
+
perms = _limited_perms(rank)
|
|
373
|
+
perm_list_1 = _ranked_perms(perms, input_shape_1, expected_shape)
|
|
374
|
+
perm_list_2 = _ranked_perms(perms, input_shape_2, expected_shape)
|
|
375
|
+
rescue_done = False
|
|
376
|
+
for perm_1 in perm_list_1:
|
|
377
|
+
for perm_2 in perm_list_2:
|
|
378
|
+
try_input_1 = transpose_with_flexing_deterrence(
|
|
379
|
+
input_tensor=input_tensor_1,
|
|
380
|
+
perm=perm_1,
|
|
381
|
+
**kwargs,
|
|
382
|
+
)
|
|
383
|
+
try_input_2 = transpose_with_flexing_deterrence(
|
|
384
|
+
input_tensor=input_tensor_2,
|
|
385
|
+
perm=perm_2,
|
|
386
|
+
**kwargs,
|
|
387
|
+
)
|
|
388
|
+
try:
|
|
389
|
+
rescue_tensor = tf.math.add(
|
|
390
|
+
x=try_input_1 \
|
|
391
|
+
if not isinstance(try_input_1, np.ndarray) \
|
|
392
|
+
else tf.convert_to_tensor(try_input_1),
|
|
393
|
+
y=try_input_2 \
|
|
394
|
+
if not isinstance(try_input_2, np.ndarray) \
|
|
395
|
+
else tf.convert_to_tensor(try_input_2),
|
|
396
|
+
name=graph_node.name,
|
|
397
|
+
)
|
|
398
|
+
except Exception as ex:
|
|
399
|
+
continue
|
|
400
|
+
|
|
401
|
+
rescue_shape = _get_static_shape(rescue_tensor)
|
|
402
|
+
if _shape_match_with_none(expected_shape, rescue_shape):
|
|
403
|
+
input_tensor_1 = try_input_1
|
|
404
|
+
input_tensor_2 = try_input_2
|
|
405
|
+
tf_layers_dict[graph_node_output.name]['tf_node'] = rescue_tensor
|
|
406
|
+
tf_type = tf.math.add
|
|
407
|
+
rescue_done = True
|
|
408
|
+
break
|
|
409
|
+
if rescue_done:
|
|
410
|
+
break
|
|
411
|
+
|
|
300
412
|
# Post-process transpose
|
|
301
413
|
tf_layers_dict[graph_node_output.name]['tf_node'] = \
|
|
302
414
|
post_process_transpose(
|
onnx2tf/ops/Concat.py
CHANGED
|
@@ -234,6 +234,31 @@ def make_node(
|
|
|
234
234
|
and len(value.shape) > 0 else tf.reshape(value, [1]) for value in values
|
|
235
235
|
]
|
|
236
236
|
|
|
237
|
+
def _infer_concat_axis_runtime(values, fallback_axis):
|
|
238
|
+
if not values:
|
|
239
|
+
return fallback_axis
|
|
240
|
+
shapes = [tf.shape(v) for v in values]
|
|
241
|
+
shapes = tf.stack(shapes)
|
|
242
|
+
equal_mask = tf.reduce_all(tf.equal(shapes, shapes[0]), axis=0)
|
|
243
|
+
diff_mask = tf.cast(tf.logical_not(equal_mask), tf.int32)
|
|
244
|
+
candidate_count = tf.reduce_sum(diff_mask)
|
|
245
|
+
axis_from_diff = tf.argmax(diff_mask, axis=0, output_type=tf.int32)
|
|
246
|
+
fallback_axis_tensor = tf.cast(fallback_axis, tf.int32)
|
|
247
|
+
is_single = tf.cast(tf.equal(candidate_count, 1), tf.int32)
|
|
248
|
+
return axis_from_diff * is_single + fallback_axis_tensor * (1 - is_single)
|
|
249
|
+
|
|
250
|
+
axis_is_dynamic = False
|
|
251
|
+
if len(values) > 0:
|
|
252
|
+
all_none = True
|
|
253
|
+
for value in values:
|
|
254
|
+
if value.shape is not None and value.shape != tf.TensorShape(None):
|
|
255
|
+
if not all([s is None for s in value.shape]):
|
|
256
|
+
all_none = False
|
|
257
|
+
break
|
|
258
|
+
if all_none:
|
|
259
|
+
axis_is_dynamic = True
|
|
260
|
+
axis_for_concat = _infer_concat_axis_runtime(values, axis) if axis_is_dynamic else axis
|
|
261
|
+
|
|
237
262
|
# Generation of TF OP
|
|
238
263
|
tf_type = None
|
|
239
264
|
if simple_resize:
|
|
@@ -266,96 +291,243 @@ def make_node(
|
|
|
266
291
|
tf_type = tf.constant
|
|
267
292
|
|
|
268
293
|
else:
|
|
294
|
+
def _normalize_dim(dim):
|
|
295
|
+
return int(dim) if isinstance(dim, (int, np.integer)) else None
|
|
296
|
+
|
|
297
|
+
def _get_static_shape(tensor):
|
|
298
|
+
shape = getattr(tensor, 'shape', None)
|
|
299
|
+
if shape is None or shape == tf.TensorShape(None):
|
|
300
|
+
return None
|
|
301
|
+
return [_normalize_dim(dim) for dim in list(shape)]
|
|
302
|
+
|
|
303
|
+
def _shape_match_with_none(onnx_shape, tf_shape):
|
|
304
|
+
if onnx_shape is None or tf_shape is None:
|
|
305
|
+
return False
|
|
306
|
+
if len(onnx_shape) != len(tf_shape):
|
|
307
|
+
return False
|
|
308
|
+
for o_dim, t_dim in zip(onnx_shape, tf_shape):
|
|
309
|
+
o_dim = _normalize_dim(o_dim)
|
|
310
|
+
t_dim = _normalize_dim(t_dim)
|
|
311
|
+
if o_dim is None or t_dim is None:
|
|
312
|
+
continue
|
|
313
|
+
if o_dim != t_dim:
|
|
314
|
+
return False
|
|
315
|
+
return True
|
|
316
|
+
|
|
317
|
+
def _can_concat_shapes(shapes, axis):
|
|
318
|
+
if shapes is None or any(s is None for s in shapes):
|
|
319
|
+
return True
|
|
320
|
+
rank = len(shapes[0])
|
|
321
|
+
for idx in range(rank):
|
|
322
|
+
if idx == axis:
|
|
323
|
+
continue
|
|
324
|
+
dims = [s[idx] for s in shapes]
|
|
325
|
+
known = [d for d in dims if isinstance(d, int)]
|
|
326
|
+
if len(known) >= 2 and len(set(known)) != 1:
|
|
327
|
+
return False
|
|
328
|
+
return True
|
|
329
|
+
|
|
330
|
+
def _perm_shape(shape, perm):
|
|
331
|
+
return [shape[i] for i in perm] if shape is not None else None
|
|
332
|
+
|
|
333
|
+
def _limited_perms(rank):
|
|
334
|
+
identity = list(range(rank))
|
|
335
|
+
perms = [identity]
|
|
336
|
+
if rank == 3:
|
|
337
|
+
perms.append([0, 2, 1])
|
|
338
|
+
elif rank == 4:
|
|
339
|
+
perms.extend([[0, 2, 3, 1], [0, 3, 1, 2]])
|
|
340
|
+
elif rank == 5:
|
|
341
|
+
perms.extend([[0, 2, 3, 4, 1], [0, 4, 1, 2, 3]])
|
|
342
|
+
return perms
|
|
343
|
+
|
|
344
|
+
def _base_perms(rank):
|
|
345
|
+
if rank <= 1:
|
|
346
|
+
return [list(range(rank))]
|
|
347
|
+
return [list(p) for p in itertools.permutations(range(rank))]
|
|
348
|
+
|
|
349
|
+
def _ranked_perms(perms, input_shape, axis, onnx_shape):
|
|
350
|
+
identity = list(range(len(perms[0]))) if perms else []
|
|
351
|
+
scored = []
|
|
352
|
+
for perm in perms:
|
|
353
|
+
score = 0
|
|
354
|
+
if input_shape is not None and onnx_shape is not None:
|
|
355
|
+
for out_idx, in_idx in enumerate(perm):
|
|
356
|
+
if out_idx == axis:
|
|
357
|
+
continue
|
|
358
|
+
o_dim = _normalize_dim(onnx_shape[out_idx]) if out_idx < len(onnx_shape) else None
|
|
359
|
+
i_dim = input_shape[in_idx] if in_idx < len(input_shape) else None
|
|
360
|
+
if isinstance(o_dim, int) and isinstance(i_dim, int) and o_dim == i_dim:
|
|
361
|
+
score += o_dim
|
|
362
|
+
scored.append((score, 1 if perm == identity else 0, perm))
|
|
363
|
+
scored.sort(key=lambda x: (x[0], x[1]), reverse=True)
|
|
364
|
+
return [p for _, _, p in scored]
|
|
365
|
+
|
|
269
366
|
try:
|
|
270
367
|
# normal concat attempt
|
|
271
368
|
tf_layers_dict[graph_node_output.name]['tf_node'] = \
|
|
272
369
|
tf.concat(
|
|
273
370
|
values=values,
|
|
274
|
-
axis=
|
|
371
|
+
axis=axis_for_concat,
|
|
275
372
|
name=graph_node.name,
|
|
276
373
|
)
|
|
277
374
|
except:
|
|
278
375
|
# Workaround to reduce error rate when merging tensors with undefined dimensions
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
376
|
+
original_values = values
|
|
377
|
+
original_shapes = [_get_static_shape(v) for v in original_values]
|
|
378
|
+
value_rank = getattr(original_values[0].shape, 'rank', None)
|
|
379
|
+
if value_rank is None:
|
|
380
|
+
value_rank = len(original_values[0].shape)
|
|
381
|
+
|
|
382
|
+
onnx_shape_list = None
|
|
383
|
+
if onnx_output_shape is not None:
|
|
384
|
+
onnx_shape_list = [_normalize_dim(dim) for dim in list(onnx_output_shape)]
|
|
385
|
+
|
|
386
|
+
onnx_axis = int(graph_node.attrs.get('axis', 0))
|
|
387
|
+
onnx_axis = onnx_axis + value_rank if onnx_axis < 0 else onnx_axis
|
|
388
|
+
|
|
389
|
+
def _axis_score(axis_idx):
    """Rank a candidate concat axis: larger means tried earlier.

    Uses the ONNX output size at `axis_idx` when it is a known integer;
    otherwise the sum of the inputs' known integer sizes on that axis.
    Reads `onnx_shape_list` and `original_shapes` from the enclosing scope.
    """
    declared_dim = None
    if onnx_shape_list is not None and axis_idx < len(onnx_shape_list):
        declared_dim = onnx_shape_list[axis_idx]
    if isinstance(declared_dim, int):
        return declared_dim

    # Fall back to what the inputs themselves say about this axis.
    return sum(
        input_shape[axis_idx]
        for input_shape in original_shapes
        if input_shape is not None
        and axis_idx < len(input_shape)
        and isinstance(input_shape[axis_idx], int)
    )
|
|
402
|
+
|
|
403
|
+
axis_candidates = list(range(value_rank))
|
|
404
|
+
axis_candidates.sort(
|
|
405
|
+
key=lambda a: (_axis_score(a), 1 if a == onnx_axis else 0),
|
|
406
|
+
reverse=True,
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
base_perms = _base_perms(value_rank)
|
|
410
|
+
max_combo = 20000
|
|
411
|
+
if len(base_perms) ** len(original_values) > max_combo:
|
|
412
|
+
base_perms = _limited_perms(value_rank)
|
|
413
|
+
|
|
414
|
+
succeed = False
|
|
415
|
+
matched = False
|
|
416
|
+
chosen_axis = None
|
|
417
|
+
chosen_values = None
|
|
418
|
+
chosen_tensor = None
|
|
419
|
+
|
|
420
|
+
for axis_idx in axis_candidates:
|
|
421
|
+
perm_lists = [
|
|
422
|
+
_ranked_perms(base_perms, shape, axis_idx, onnx_shape_list)
|
|
423
|
+
for shape in original_shapes
|
|
424
|
+
]
|
|
425
|
+
for perm_combo in itertools.product(*perm_lists):
|
|
426
|
+
permuted_shapes = [
|
|
427
|
+
_perm_shape(shape, perm) for shape, perm in zip(original_shapes, perm_combo)
|
|
428
|
+
]
|
|
429
|
+
if not _can_concat_shapes(permuted_shapes, axis_idx):
|
|
430
|
+
continue
|
|
431
|
+
try_values = [
|
|
432
|
+
value if perm == list(range(value_rank)) else
|
|
433
|
+
transpose_with_flexing_deterrence(
|
|
434
|
+
input_tensor=value,
|
|
435
|
+
perm=perm,
|
|
436
|
+
**kwargs,
|
|
437
|
+
)
|
|
438
|
+
for value, perm in zip(original_values, perm_combo)
|
|
439
|
+
]
|
|
294
440
|
try:
|
|
295
|
-
|
|
441
|
+
concat_tensor = \
|
|
296
442
|
tf.concat(
|
|
297
|
-
values=
|
|
298
|
-
axis=
|
|
443
|
+
values=try_values,
|
|
444
|
+
axis=axis_idx,
|
|
299
445
|
name=graph_node.name,
|
|
300
446
|
)
|
|
301
|
-
|
|
447
|
+
except:
|
|
448
|
+
continue
|
|
449
|
+
|
|
450
|
+
if not succeed:
|
|
302
451
|
succeed = True
|
|
452
|
+
chosen_axis = axis_idx
|
|
453
|
+
chosen_values = try_values
|
|
454
|
+
chosen_tensor = concat_tensor
|
|
455
|
+
|
|
456
|
+
if onnx_shape_list is None:
|
|
457
|
+
matched = True
|
|
458
|
+
chosen_axis = axis_idx
|
|
459
|
+
chosen_values = try_values
|
|
460
|
+
chosen_tensor = concat_tensor
|
|
303
461
|
break
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
462
|
+
|
|
463
|
+
output_shape = _get_static_shape(concat_tensor)
|
|
464
|
+
if _shape_match_with_none(onnx_shape_list, output_shape):
|
|
465
|
+
matched = True
|
|
466
|
+
chosen_axis = axis_idx
|
|
467
|
+
chosen_values = try_values
|
|
468
|
+
chosen_tensor = concat_tensor
|
|
469
|
+
break
|
|
470
|
+
if matched:
|
|
471
|
+
break
|
|
472
|
+
|
|
473
|
+
if succeed:
|
|
474
|
+
tf_layers_dict[graph_node_output.name]['tf_node'] = chosen_tensor
|
|
475
|
+
axis = chosen_axis
|
|
476
|
+
values = chosen_values
|
|
477
|
+
else:
|
|
478
|
+
raise
|
|
308
479
|
|
|
309
480
|
# Attempts to force axis correction when the number of axes in the combined tensor do not exactly match.
|
|
310
481
|
# However, if more than 2 patterns of correct answers exist, give up the correction.
|
|
311
482
|
# This workaround is useful when automatic axis correction is practically difficult,
|
|
312
483
|
# such as when all tensors to be combined originate from Transpose or Reshape.
|
|
313
484
|
# https://github.com/PINTO0309/onnx2tf/issues/473
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
tf.concat(
|
|
342
|
-
values=values,
|
|
343
|
-
axis=matched_axes[0],
|
|
344
|
-
name=graph_node.name,
|
|
345
|
-
)
|
|
346
|
-
axis = matched_axes[0]
|
|
347
|
-
elif not nhwc_judge:
|
|
348
|
-
onnx_axis = int(graph_node.attrs.get('axis', 0))
|
|
349
|
-
onnx_axis = output_tensor_rank - 1 if onnx_axis == -1 else onnx_axis
|
|
350
|
-
if onnx_axis == output_tensor_rank - 1 \
|
|
351
|
-
and onnx_axis in matched_axes:
|
|
485
|
+
if not axis_is_dynamic:
|
|
486
|
+
output_tensor_shape = tf_layers_dict[graph_node_output.name]['tf_node'].shape
|
|
487
|
+
if output_tensor_shape != tf.TensorShape(None):
|
|
488
|
+
output_tensor_rank = len(output_tensor_shape)
|
|
489
|
+
if graph_node.outputs[0].shape is not None \
|
|
490
|
+
and axis != 0 \
|
|
491
|
+
and output_tensor_rank >= 2 \
|
|
492
|
+
and before_axis == axis:
|
|
493
|
+
|
|
494
|
+
# Search for valid Concat patterns
|
|
495
|
+
if not shape_is_equal_ignore_order(list(graph_node.outputs[0].shape), list(output_tensor_shape)):
|
|
496
|
+
matched_axes = []
|
|
497
|
+
for dummy_axis in range(1, output_tensor_rank):
|
|
498
|
+
try:
|
|
499
|
+
dummy_concat_tensor = \
|
|
500
|
+
tf.concat(
|
|
501
|
+
values=values,
|
|
502
|
+
axis=dummy_axis,
|
|
503
|
+
name=graph_node.name,
|
|
504
|
+
)
|
|
505
|
+
dummy_output_shape = dummy_concat_tensor.shape
|
|
506
|
+
if shape_is_equal_ignore_order(list(graph_node.outputs[0].shape), list(dummy_output_shape)):
|
|
507
|
+
matched_axes.append(dummy_axis)
|
|
508
|
+
except:
|
|
509
|
+
pass
|
|
510
|
+
# Review Concat axes only if there is one valid join pattern
|
|
511
|
+
if len(matched_axes) == 1:
|
|
352
512
|
tf_layers_dict[graph_node_output.name]['tf_node'] = \
|
|
353
513
|
tf.concat(
|
|
354
514
|
values=values,
|
|
355
|
-
axis=
|
|
515
|
+
axis=matched_axes[0],
|
|
356
516
|
name=graph_node.name,
|
|
357
517
|
)
|
|
358
|
-
axis =
|
|
518
|
+
axis = matched_axes[0]
|
|
519
|
+
elif not nhwc_judge:
|
|
520
|
+
onnx_axis = int(graph_node.attrs.get('axis', 0))
|
|
521
|
+
onnx_axis = output_tensor_rank - 1 if onnx_axis == -1 else onnx_axis
|
|
522
|
+
if onnx_axis == output_tensor_rank - 1 \
|
|
523
|
+
and onnx_axis in matched_axes:
|
|
524
|
+
tf_layers_dict[graph_node_output.name]['tf_node'] = \
|
|
525
|
+
tf.concat(
|
|
526
|
+
values=values,
|
|
527
|
+
axis=onnx_axis,
|
|
528
|
+
name=graph_node.name,
|
|
529
|
+
)
|
|
530
|
+
axis = onnx_axis
|
|
359
531
|
|
|
360
532
|
# Workaround for post-concat accuracy degradation issues
|
|
361
533
|
# Process only in the case of a Concat of two tensors because the process is too redundant.
|