onnx2tf 1.23.3__py3-none-any.whl → 1.25.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +181 -30
- onnx2tf/ops/Add.py +29 -0
- onnx2tf/ops/AveragePool.py +20 -10
- onnx2tf/ops/BatchNormalization.py +270 -24
- onnx2tf/ops/Concat.py +4 -4
- onnx2tf/ops/DepthToSpace.py +8 -0
- onnx2tf/ops/Div.py +30 -0
- onnx2tf/ops/Expand.py +207 -0
- onnx2tf/ops/Gather.py +67 -18
- onnx2tf/ops/Mod.py +29 -0
- onnx2tf/ops/Mul.py +30 -0
- onnx2tf/ops/ReduceL1.py +3 -0
- onnx2tf/ops/ReduceL2.py +3 -0
- onnx2tf/ops/ReduceLogSum.py +3 -0
- onnx2tf/ops/ReduceLogSumExp.py +3 -0
- onnx2tf/ops/ReduceMax.py +3 -0
- onnx2tf/ops/ReduceMean.py +3 -0
- onnx2tf/ops/ReduceMin.py +3 -0
- onnx2tf/ops/ReduceProd.py +3 -0
- onnx2tf/ops/ReduceSum.py +3 -0
- onnx2tf/ops/ReduceSumSquare.py +3 -0
- onnx2tf/ops/Shape.py +2 -0
- onnx2tf/ops/Sub.py +29 -0
- onnx2tf/ops/Transpose.py +14 -0
- onnx2tf/utils/common_functions.py +2 -2
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/METADATA +269 -28
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/RECORD +33 -33
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/WHEEL +1 -1
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/LICENSE +0 -0
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/LICENSE_onnx-tensorflow +0 -0
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/entry_points.txt +0 -0
- {onnx2tf-1.23.3.dist-info → onnx2tf-1.25.8.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,12 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import copy
|
|
1
3
|
import random
|
|
2
4
|
random.seed(0)
|
|
3
5
|
import numpy as np
|
|
4
6
|
np.random.seed(0)
|
|
7
|
+
import itertools
|
|
5
8
|
import tensorflow as tf
|
|
9
|
+
import tf_keras
|
|
6
10
|
import onnx_graphsurgeon as gs
|
|
7
11
|
from onnx2tf.utils.common_functions import (
|
|
8
12
|
print_node_info,
|
|
@@ -14,7 +18,11 @@ from onnx2tf.utils.common_functions import (
|
|
|
14
18
|
explicit_broadcast,
|
|
15
19
|
pre_explicit_broadcast,
|
|
16
20
|
transpose_with_flexing_deterrence,
|
|
21
|
+
get_tf_model_inputs,
|
|
22
|
+
dummy_tf_inference,
|
|
23
|
+
onnx_tf_tensor_validation,
|
|
17
24
|
)
|
|
25
|
+
from typing import List, Dict, Any
|
|
18
26
|
|
|
19
27
|
|
|
20
28
|
@print_node_info
|
|
@@ -85,6 +93,11 @@ def make_node(
|
|
|
85
93
|
nhwc: bool = tf_layers_dict[X.name]['nhwc'] \
|
|
86
94
|
if isinstance(X, gs.Variable) and 'nhwc' in tf_layers_dict[X.name].keys() else False
|
|
87
95
|
|
|
96
|
+
onnx_tensor_infos_for_validation: Dict[str: np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
|
97
|
+
test_data_nhwc: np.ndarray = kwargs['test_data_nhwc']
|
|
98
|
+
custom_input_op_name_np_data_path: str = kwargs['custom_input_op_name_np_data_path']
|
|
99
|
+
disable_strict_mode: bool = kwargs['disable_strict_mode']
|
|
100
|
+
|
|
88
101
|
# Preserving Graph Structure (Dict)
|
|
89
102
|
tf_layers_dict[Y.name] = {
|
|
90
103
|
'optype': graph_node.op,
|
|
@@ -95,6 +108,7 @@ def make_node(
|
|
|
95
108
|
|
|
96
109
|
# Generation of TF OP
|
|
97
110
|
input_tensor = tf_layers_dict[X.name]['tf_node']
|
|
111
|
+
input_tensor_rank = len(input_tensor.shape)
|
|
98
112
|
|
|
99
113
|
# Pre-process transpose
|
|
100
114
|
input_tensor = pre_process_transpose(
|
|
@@ -123,42 +137,18 @@ def make_node(
|
|
|
123
137
|
input_tensor_1=input_tensor,
|
|
124
138
|
input_tensor_2=mean,
|
|
125
139
|
)
|
|
126
|
-
input_tensor, mean = explicit_broadcast(
|
|
127
|
-
const_or_var_1=input_tensor,
|
|
128
|
-
const_or_var_2=mean,
|
|
129
|
-
graph_node=graph_node,
|
|
130
|
-
tf_layers_dict= tf_layers_dict,
|
|
131
|
-
)
|
|
132
140
|
input_tensor, var = pre_explicit_broadcast(
|
|
133
141
|
input_tensor_1=input_tensor,
|
|
134
142
|
input_tensor_2=var,
|
|
135
143
|
)
|
|
136
|
-
input_tensor, var = explicit_broadcast(
|
|
137
|
-
const_or_var_1=input_tensor,
|
|
138
|
-
const_or_var_2=var,
|
|
139
|
-
graph_node=graph_node,
|
|
140
|
-
tf_layers_dict= tf_layers_dict,
|
|
141
|
-
)
|
|
142
144
|
input_tensor, offset = pre_explicit_broadcast(
|
|
143
145
|
input_tensor_1=input_tensor,
|
|
144
146
|
input_tensor_2=offset,
|
|
145
147
|
)
|
|
146
|
-
input_tensor, offset = explicit_broadcast(
|
|
147
|
-
const_or_var_1=input_tensor,
|
|
148
|
-
const_or_var_2=offset,
|
|
149
|
-
graph_node=graph_node,
|
|
150
|
-
tf_layers_dict= tf_layers_dict,
|
|
151
|
-
)
|
|
152
148
|
input_tensor, scale = pre_explicit_broadcast(
|
|
153
149
|
input_tensor_1=input_tensor,
|
|
154
150
|
input_tensor_2=scale,
|
|
155
151
|
)
|
|
156
|
-
input_tensor, scale = explicit_broadcast(
|
|
157
|
-
const_or_var_1=input_tensor,
|
|
158
|
-
const_or_var_2=scale,
|
|
159
|
-
graph_node=graph_node,
|
|
160
|
-
tf_layers_dict= tf_layers_dict,
|
|
161
|
-
)
|
|
162
152
|
|
|
163
153
|
try:
|
|
164
154
|
tf_layers_dict[Y.name]['tf_node'] = \
|
|
@@ -290,6 +280,262 @@ def make_node(
|
|
|
290
280
|
else:
|
|
291
281
|
raise
|
|
292
282
|
|
|
283
|
+
# Automatic accuracy compensation
|
|
284
|
+
graph_node_input_1_shape = X.shape
|
|
285
|
+
if graph_node_input_1_shape is not None:
|
|
286
|
+
|
|
287
|
+
# Get the output tensor of one previous OP of TensorFlow only once
|
|
288
|
+
if not disable_strict_mode:
|
|
289
|
+
tf_model_inputs = get_tf_model_inputs(
|
|
290
|
+
tf_layers_dict=tf_layers_dict,
|
|
291
|
+
)
|
|
292
|
+
val_model = None
|
|
293
|
+
if not isinstance(input_tensor, np.ndarray):
|
|
294
|
+
val_model = tf_keras.Model(
|
|
295
|
+
inputs=tf_model_inputs,
|
|
296
|
+
outputs=[
|
|
297
|
+
input_tensor,
|
|
298
|
+
],
|
|
299
|
+
)
|
|
300
|
+
else:
|
|
301
|
+
pass
|
|
302
|
+
|
|
303
|
+
# TF dummy inference
|
|
304
|
+
# Get the output tensor of the previous layer of MatMul
|
|
305
|
+
# If input.1 and input.2 are both layers, tf_pre_tensor_infos is 2 cases
|
|
306
|
+
# If one of input.1 or input.2 is np.ndarray, tf_pre_tensor_infos is 1 case
|
|
307
|
+
tf_pre_tensor_infos = {}
|
|
308
|
+
if not disable_strict_mode:
|
|
309
|
+
try:
|
|
310
|
+
tf_pre_tensor_infos: Dict[Any] = \
|
|
311
|
+
dummy_tf_inference(
|
|
312
|
+
model=val_model,
|
|
313
|
+
inputs=tf_model_inputs,
|
|
314
|
+
test_data_nhwc=test_data_nhwc,
|
|
315
|
+
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
|
|
316
|
+
)
|
|
317
|
+
except Exception as ex:
|
|
318
|
+
pass
|
|
319
|
+
del val_model
|
|
320
|
+
|
|
321
|
+
# Get np.ndarray for validation
|
|
322
|
+
validation_data = None
|
|
323
|
+
if not disable_strict_mode:
|
|
324
|
+
if len(tf_pre_tensor_infos) == 1:
|
|
325
|
+
if not isinstance(input_tensor, np.ndarray):
|
|
326
|
+
validation_data = list(tf_pre_tensor_infos.values())[0]
|
|
327
|
+
else:
|
|
328
|
+
validation_data = copy.deepcopy(input_tensor)
|
|
329
|
+
|
|
330
|
+
# Get ONNX inference results
|
|
331
|
+
onnx_tensor_infos = None
|
|
332
|
+
if onnx_tensor_infos_for_validation is not None \
|
|
333
|
+
and onnx_tensor_infos_for_validation.get(Y.name, None) is not None:
|
|
334
|
+
onnx_tensor_infos = {
|
|
335
|
+
Y.name: onnx_tensor_infos_for_validation[Y.name]
|
|
336
|
+
}
|
|
337
|
+
del onnx_tensor_infos_for_validation
|
|
338
|
+
|
|
339
|
+
# Automatic correction of accuracy degradation
|
|
340
|
+
min_abs_err = sys.maxsize
|
|
341
|
+
min_abs_err_perm_1: List[int] = [idx for idx in range(len(mean.shape))]
|
|
342
|
+
|
|
343
|
+
if not disable_strict_mode:
|
|
344
|
+
if onnx_tensor_infos is not None and validation_data is not None:
|
|
345
|
+
tensor_1_candidate_for_transpositions = list(itertools.permutations(range(len(mean.shape))))
|
|
346
|
+
# Search for the axis with the smallest error
|
|
347
|
+
for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions:
|
|
348
|
+
try:
|
|
349
|
+
target_validation_data = validation_data
|
|
350
|
+
# Build TF dummy model
|
|
351
|
+
input = tf_keras.Input(
|
|
352
|
+
shape=validation_data.shape[1:],
|
|
353
|
+
batch_size=validation_data.shape[0] \
|
|
354
|
+
if isinstance(validation_data.shape[0], int) else None,
|
|
355
|
+
name='dummy_input',
|
|
356
|
+
dtype=validation_data.dtype,
|
|
357
|
+
)
|
|
358
|
+
val_model = tf_keras.Model(
|
|
359
|
+
inputs=[
|
|
360
|
+
input,
|
|
361
|
+
],
|
|
362
|
+
outputs=[
|
|
363
|
+
tf.nn.batch_normalization(
|
|
364
|
+
x=input,
|
|
365
|
+
mean=\
|
|
366
|
+
transpose_with_flexing_deterrence(
|
|
367
|
+
input_tensor=mean,
|
|
368
|
+
perm=min_abs_err_perm_1,
|
|
369
|
+
output_shape=Y.shape \
|
|
370
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
371
|
+
**kwargs,
|
|
372
|
+
) if not isinstance(mean, np.ndarray) else \
|
|
373
|
+
transpose_with_flexing_deterrence(
|
|
374
|
+
input_tensor=tf.convert_to_tensor(mean),
|
|
375
|
+
perm=min_abs_err_perm_1,
|
|
376
|
+
output_shape=Y.shape \
|
|
377
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
378
|
+
**kwargs,
|
|
379
|
+
),
|
|
380
|
+
variance=\
|
|
381
|
+
transpose_with_flexing_deterrence(
|
|
382
|
+
input_tensor=var,
|
|
383
|
+
perm=min_abs_err_perm_1,
|
|
384
|
+
output_shape=Y.shape \
|
|
385
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
386
|
+
**kwargs,
|
|
387
|
+
) if not isinstance(var, np.ndarray) else \
|
|
388
|
+
transpose_with_flexing_deterrence(
|
|
389
|
+
input_tensor=tf.convert_to_tensor(var),
|
|
390
|
+
perm=min_abs_err_perm_1,
|
|
391
|
+
output_shape=Y.shape \
|
|
392
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
393
|
+
**kwargs,
|
|
394
|
+
),
|
|
395
|
+
offset=\
|
|
396
|
+
transpose_with_flexing_deterrence(
|
|
397
|
+
input_tensor=offset,
|
|
398
|
+
perm=min_abs_err_perm_1,
|
|
399
|
+
output_shape=Y.shape \
|
|
400
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
401
|
+
**kwargs,
|
|
402
|
+
) if not isinstance(offset, np.ndarray) else \
|
|
403
|
+
transpose_with_flexing_deterrence(
|
|
404
|
+
input_tensor=tf.convert_to_tensor(offset),
|
|
405
|
+
perm=min_abs_err_perm_1,
|
|
406
|
+
output_shape=Y.shape \
|
|
407
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
408
|
+
**kwargs,
|
|
409
|
+
),
|
|
410
|
+
scale=\
|
|
411
|
+
transpose_with_flexing_deterrence(
|
|
412
|
+
input_tensor=scale,
|
|
413
|
+
perm=min_abs_err_perm_1,
|
|
414
|
+
output_shape=Y.shape \
|
|
415
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
416
|
+
**kwargs,
|
|
417
|
+
) if not isinstance(scale, np.ndarray) else \
|
|
418
|
+
transpose_with_flexing_deterrence(
|
|
419
|
+
input_tensor=tf.convert_to_tensor(scale),
|
|
420
|
+
perm=min_abs_err_perm_1,
|
|
421
|
+
output_shape=Y.shape \
|
|
422
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
423
|
+
**kwargs,
|
|
424
|
+
),
|
|
425
|
+
variance_epsilon=epsilon,
|
|
426
|
+
)
|
|
427
|
+
],
|
|
428
|
+
)
|
|
429
|
+
# TF dummy inference
|
|
430
|
+
tf_tensor_infos: Dict[Any] = \
|
|
431
|
+
dummy_tf_inference(
|
|
432
|
+
model=val_model,
|
|
433
|
+
inputs=[
|
|
434
|
+
input,
|
|
435
|
+
],
|
|
436
|
+
verification_datas=[
|
|
437
|
+
target_validation_data,
|
|
438
|
+
],
|
|
439
|
+
)
|
|
440
|
+
del input
|
|
441
|
+
del val_model
|
|
442
|
+
|
|
443
|
+
# Validation
|
|
444
|
+
onnx_tf_output_pairs = {
|
|
445
|
+
(oi[0], ti[0]): (oi[1], ti[1]) \
|
|
446
|
+
for oi, ti in zip(onnx_tensor_infos.items(), tf_tensor_infos.items())
|
|
447
|
+
}
|
|
448
|
+
"""
|
|
449
|
+
check_results: Dict[str, List[np.ndarray, int, float|int]]
|
|
450
|
+
{
|
|
451
|
+
onnx_output_name: [
|
|
452
|
+
onnx_tensor,
|
|
453
|
+
matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
|
|
454
|
+
max_abs_err,
|
|
455
|
+
]
|
|
456
|
+
}
|
|
457
|
+
"""
|
|
458
|
+
check_results = \
|
|
459
|
+
onnx_tf_tensor_validation(
|
|
460
|
+
output_pairs=onnx_tf_output_pairs,
|
|
461
|
+
rtol=0.0,
|
|
462
|
+
atol=0.0,
|
|
463
|
+
)
|
|
464
|
+
result_err = sum([val[2] for val in check_results.values()])
|
|
465
|
+
if result_err < min_abs_err:
|
|
466
|
+
min_abs_err = result_err
|
|
467
|
+
min_abs_err_perm_1 = list(tensor_1_candidate_for_transposition)
|
|
468
|
+
if min_abs_err < 1e-3:
|
|
469
|
+
break
|
|
470
|
+
except Exception as ex:
|
|
471
|
+
pass
|
|
472
|
+
|
|
473
|
+
tf_layers_dict[Y.name]['tf_node'] = \
|
|
474
|
+
tf.nn.batch_normalization(
|
|
475
|
+
x=input_tensor,
|
|
476
|
+
mean=\
|
|
477
|
+
transpose_with_flexing_deterrence(
|
|
478
|
+
input_tensor=mean,
|
|
479
|
+
perm=min_abs_err_perm_1,
|
|
480
|
+
output_shape=Y.shape \
|
|
481
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
482
|
+
**kwargs,
|
|
483
|
+
) if not isinstance(mean, np.ndarray) else \
|
|
484
|
+
transpose_with_flexing_deterrence(
|
|
485
|
+
input_tensor=tf.convert_to_tensor(mean),
|
|
486
|
+
perm=min_abs_err_perm_1,
|
|
487
|
+
output_shape=Y.shape \
|
|
488
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
489
|
+
**kwargs,
|
|
490
|
+
),
|
|
491
|
+
variance=\
|
|
492
|
+
transpose_with_flexing_deterrence(
|
|
493
|
+
input_tensor=var,
|
|
494
|
+
perm=min_abs_err_perm_1,
|
|
495
|
+
output_shape=Y.shape \
|
|
496
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
497
|
+
**kwargs,
|
|
498
|
+
) if not isinstance(var, np.ndarray) else \
|
|
499
|
+
transpose_with_flexing_deterrence(
|
|
500
|
+
input_tensor=tf.convert_to_tensor(var),
|
|
501
|
+
perm=min_abs_err_perm_1,
|
|
502
|
+
output_shape=Y.shape \
|
|
503
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
504
|
+
**kwargs,
|
|
505
|
+
),
|
|
506
|
+
offset=\
|
|
507
|
+
transpose_with_flexing_deterrence(
|
|
508
|
+
input_tensor=offset,
|
|
509
|
+
perm=min_abs_err_perm_1,
|
|
510
|
+
output_shape=Y.shape \
|
|
511
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
512
|
+
**kwargs,
|
|
513
|
+
) if not isinstance(offset, np.ndarray) else \
|
|
514
|
+
transpose_with_flexing_deterrence(
|
|
515
|
+
input_tensor=tf.convert_to_tensor(offset),
|
|
516
|
+
perm=min_abs_err_perm_1,
|
|
517
|
+
output_shape=Y.shape \
|
|
518
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
519
|
+
**kwargs,
|
|
520
|
+
),
|
|
521
|
+
scale=\
|
|
522
|
+
transpose_with_flexing_deterrence(
|
|
523
|
+
input_tensor=scale,
|
|
524
|
+
perm=min_abs_err_perm_1,
|
|
525
|
+
output_shape=Y.shape \
|
|
526
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
527
|
+
**kwargs,
|
|
528
|
+
) if not isinstance(scale, np.ndarray) else \
|
|
529
|
+
transpose_with_flexing_deterrence(
|
|
530
|
+
input_tensor=tf.convert_to_tensor(scale),
|
|
531
|
+
perm=min_abs_err_perm_1,
|
|
532
|
+
output_shape=Y.shape \
|
|
533
|
+
if None not in Y.shape and Y.shape != [] else None,
|
|
534
|
+
**kwargs,
|
|
535
|
+
),
|
|
536
|
+
variance_epsilon=epsilon,
|
|
537
|
+
)
|
|
538
|
+
tf_type = tf.nn.batch_normalization
|
|
293
539
|
|
|
294
540
|
# Post-process transpose
|
|
295
541
|
tf_layers_dict[Y.name]['tf_node'] = post_process_transpose(
|
onnx2tf/ops/Concat.py
CHANGED
|
@@ -252,13 +252,13 @@ def make_node(
|
|
|
252
252
|
)
|
|
253
253
|
tf_type = tf.slice
|
|
254
254
|
|
|
255
|
-
elif simple_resize2 and len(values)
|
|
256
|
-
target_input: np.ndarray =
|
|
255
|
+
elif simple_resize2 and len(values) >= 2:
|
|
256
|
+
target_input: np.ndarray = np.array([], dtype=np.int64)
|
|
257
257
|
target_spartial_size: int = 0
|
|
258
258
|
for cat_value in values:
|
|
259
259
|
if hasattr(cat_value, 'numpy'):
|
|
260
|
-
target_input = cat_value.numpy()
|
|
261
|
-
|
|
260
|
+
target_input = np.append(target_input, cat_value.numpy())
|
|
261
|
+
elif not hasattr(cat_value, 'numpy') and cat_value.shape is not None:
|
|
262
262
|
target_spartial_size = cat_value.shape[0] - 2
|
|
263
263
|
if target_spartial_size == len(target_input):
|
|
264
264
|
target_input = np.asarray([1] + [i for i in target_input] + [1])
|
onnx2tf/ops/DepthToSpace.py
CHANGED
|
@@ -96,6 +96,14 @@ def make_node(
|
|
|
96
96
|
elif mode == "CRD":
|
|
97
97
|
batch, channel = input_tensor_shape[0], input_tensor_shape[-1]
|
|
98
98
|
height, width = input_tensor_shape[1], input_tensor_shape[2]
|
|
99
|
+
if batch is None:
|
|
100
|
+
batch = tf.shape(input_tensor)[0]
|
|
101
|
+
if channel is None:
|
|
102
|
+
channel = tf.shape(input_tensor)[-1]
|
|
103
|
+
if height is None:
|
|
104
|
+
height = tf.shape(input_tensor)[1]
|
|
105
|
+
if width is None:
|
|
106
|
+
width = tf.shape(input_tensor)[2]
|
|
99
107
|
csize = channel // (blocksize**2)
|
|
100
108
|
|
|
101
109
|
reshape_node = tf.reshape(
|
onnx2tf/ops/Div.py
CHANGED
|
@@ -158,8 +158,37 @@ def make_node(
|
|
|
158
158
|
is_scalar_2_rank = tf.rank(input_tensor_2) == 0
|
|
159
159
|
if hasattr(is_scalar_2_rank, 'numpy'):
|
|
160
160
|
is_scalar_2 = is_scalar_2_rank.numpy()
|
|
161
|
+
|
|
161
162
|
if (is_scalar_1 or is_scalar_2) and graph_node.i().op == 'Gemm':
|
|
162
163
|
pass
|
|
164
|
+
elif (is_scalar_1 or is_scalar_2) and graph_node.i().op != 'Gemm':
|
|
165
|
+
first_tensor = None
|
|
166
|
+
second_tensor = None
|
|
167
|
+
if is_scalar_1:
|
|
168
|
+
first_tensor = input_tensor_2
|
|
169
|
+
second_tensor = input_tensor_1
|
|
170
|
+
elif is_scalar_2:
|
|
171
|
+
first_tensor = input_tensor_1
|
|
172
|
+
second_tensor = input_tensor_2
|
|
173
|
+
tmp_result = tf.math.divide(first_tensor, second_tensor)
|
|
174
|
+
tmp_result_shape = tmp_result.shape
|
|
175
|
+
if first_tensor.shape == tmp_result_shape:
|
|
176
|
+
pass
|
|
177
|
+
else:
|
|
178
|
+
input_tensor_1, input_tensor_2 = \
|
|
179
|
+
pre_explicit_broadcast(
|
|
180
|
+
input_tensor_1=input_tensor_1,
|
|
181
|
+
input_tensor_2=input_tensor_2,
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
input_tensor_1, input_tensor_2 = \
|
|
185
|
+
explicit_broadcast(
|
|
186
|
+
const_or_var_1=input_tensor_1,
|
|
187
|
+
const_or_var_2=input_tensor_2,
|
|
188
|
+
graph_node=graph_node,
|
|
189
|
+
tf_layers_dict= tf_layers_dict,
|
|
190
|
+
)
|
|
191
|
+
|
|
163
192
|
else:
|
|
164
193
|
input_tensor_1, input_tensor_2 = \
|
|
165
194
|
pre_explicit_broadcast(
|
|
@@ -174,6 +203,7 @@ def make_node(
|
|
|
174
203
|
graph_node=graph_node,
|
|
175
204
|
tf_layers_dict= tf_layers_dict,
|
|
176
205
|
)
|
|
206
|
+
|
|
177
207
|
except Exception as ex:
|
|
178
208
|
input_tensor_1, input_tensor_2 = \
|
|
179
209
|
pre_explicit_broadcast(
|
onnx2tf/ops/Expand.py
CHANGED
|
@@ -1,8 +1,12 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import copy
|
|
1
3
|
import random
|
|
2
4
|
random.seed(0)
|
|
3
5
|
import numpy as np
|
|
4
6
|
np.random.seed(0)
|
|
7
|
+
import itertools
|
|
5
8
|
import tensorflow as tf
|
|
9
|
+
import tf_keras
|
|
6
10
|
import onnx_graphsurgeon as gs
|
|
7
11
|
from onnx2tf.utils.common_functions import (
|
|
8
12
|
get_replacement_parameter,
|
|
@@ -13,7 +17,12 @@ from onnx2tf.utils.common_functions import (
|
|
|
13
17
|
make_tf_node_info,
|
|
14
18
|
pre_process_transpose,
|
|
15
19
|
post_process_transpose,
|
|
20
|
+
transpose_with_flexing_deterrence,
|
|
21
|
+
get_tf_model_inputs,
|
|
22
|
+
dummy_tf_inference,
|
|
23
|
+
onnx_tf_tensor_validation,
|
|
16
24
|
)
|
|
25
|
+
from typing import List, Dict, Any
|
|
17
26
|
|
|
18
27
|
|
|
19
28
|
@print_node_info
|
|
@@ -57,9 +66,15 @@ def make_node(
|
|
|
57
66
|
|
|
58
67
|
input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
|
|
59
68
|
if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
|
|
69
|
+
input_tensor_rank = len(input_tensor.shape)
|
|
60
70
|
input_tensor_shape = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
|
|
61
71
|
if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
|
|
62
72
|
|
|
73
|
+
onnx_tensor_infos_for_validation: Dict[str: np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
|
|
74
|
+
test_data_nhwc: np.ndarray = kwargs['test_data_nhwc']
|
|
75
|
+
custom_input_op_name_np_data_path: str = kwargs['custom_input_op_name_np_data_path']
|
|
76
|
+
disable_strict_mode: bool = kwargs['disable_strict_mode']
|
|
77
|
+
|
|
63
78
|
# Preserving Graph Structure (Dict)
|
|
64
79
|
tf_layers_dict[graph_node_output.name] = {
|
|
65
80
|
'optype': graph_node.op,
|
|
@@ -119,6 +134,198 @@ def make_node(
|
|
|
119
134
|
tf_layers_dict[graph_node_output.name]['tf_node'] = expanded_tensor
|
|
120
135
|
tf_type = 'Expand'
|
|
121
136
|
|
|
137
|
+
if tf_type == 'Expand':
|
|
138
|
+
graph_node_input_1_shape = graph_node_input_1.shape
|
|
139
|
+
graph_node_input_2_shape = graph_node_input_2.shape
|
|
140
|
+
|
|
141
|
+
# Get the output tensor of one previous OP of TensorFlow only once
|
|
142
|
+
if not disable_strict_mode:
|
|
143
|
+
tf_model_inputs = get_tf_model_inputs(
|
|
144
|
+
tf_layers_dict=tf_layers_dict,
|
|
145
|
+
)
|
|
146
|
+
val_model = None
|
|
147
|
+
if not isinstance(input_tensor, np.ndarray):
|
|
148
|
+
expand_shape = []
|
|
149
|
+
if not isinstance(input_tensor_shape, np.ndarray):
|
|
150
|
+
expand_shape = [input_tensor_shape]
|
|
151
|
+
val_model = tf_keras.Model(
|
|
152
|
+
inputs=tf_model_inputs,
|
|
153
|
+
outputs=[
|
|
154
|
+
input_tensor,
|
|
155
|
+
] + expand_shape,
|
|
156
|
+
)
|
|
157
|
+
else:
|
|
158
|
+
pass
|
|
159
|
+
|
|
160
|
+
# TF dummy inference
|
|
161
|
+
# Get the output tensor of the previous layer of MatMul
|
|
162
|
+
# If input.1 and input.2 are both layers, tf_pre_tensor_infos is 2 cases
|
|
163
|
+
# If one of input.1 or input.2 is np.ndarray, tf_pre_tensor_infos is 1 case
|
|
164
|
+
tf_pre_tensor_infos = {}
|
|
165
|
+
if not disable_strict_mode:
|
|
166
|
+
try:
|
|
167
|
+
tf_pre_tensor_infos: Dict[Any] = \
|
|
168
|
+
dummy_tf_inference(
|
|
169
|
+
model=val_model,
|
|
170
|
+
inputs=tf_model_inputs,
|
|
171
|
+
test_data_nhwc=test_data_nhwc,
|
|
172
|
+
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
|
|
173
|
+
)
|
|
174
|
+
except Exception as ex:
|
|
175
|
+
pass
|
|
176
|
+
del val_model
|
|
177
|
+
|
|
178
|
+
# Get np.ndarray for validation
|
|
179
|
+
validation_data_1 = None
|
|
180
|
+
validation_data_2 = None
|
|
181
|
+
|
|
182
|
+
if not disable_strict_mode:
|
|
183
|
+
if len(tf_pre_tensor_infos) == 1:
|
|
184
|
+
if not isinstance(input_tensor, np.ndarray):
|
|
185
|
+
validation_data_1 = list(tf_pre_tensor_infos.values())[0]
|
|
186
|
+
else:
|
|
187
|
+
validation_data_1 = copy.deepcopy(input_tensor)
|
|
188
|
+
elif len(tf_pre_tensor_infos) == 2:
|
|
189
|
+
if not isinstance(input_tensor, np.ndarray):
|
|
190
|
+
validation_data_1 = list(tf_pre_tensor_infos.values())[0]
|
|
191
|
+
else:
|
|
192
|
+
validation_data_1 = copy.deepcopy(input_tensor)
|
|
193
|
+
if not isinstance(input_tensor_shape, np.ndarray):
|
|
194
|
+
validation_data_2 = list(tf_pre_tensor_infos.values())[1]
|
|
195
|
+
else:
|
|
196
|
+
validation_data_2 = copy.deepcopy(input_tensor_shape)
|
|
197
|
+
|
|
198
|
+
# Get ONNX inference results
|
|
199
|
+
onnx_tensor_infos = None
|
|
200
|
+
if onnx_tensor_infos_for_validation is not None \
|
|
201
|
+
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
|
|
202
|
+
onnx_tensor_infos = {
|
|
203
|
+
graph_node_output.name: onnx_tensor_infos_for_validation[graph_node_output.name]
|
|
204
|
+
}
|
|
205
|
+
del onnx_tensor_infos_for_validation
|
|
206
|
+
|
|
207
|
+
# ONNX : N,C,W
|
|
208
|
+
# TF : N,W,C
|
|
209
|
+
# TF-axes: [1]
|
|
210
|
+
#
|
|
211
|
+
# ONNX: N,C,H,W
|
|
212
|
+
# TF : N,H,W,C
|
|
213
|
+
# TF-axes: [1,2]
|
|
214
|
+
#
|
|
215
|
+
# ONNX: N,C,D,H,W
|
|
216
|
+
# TF : N,D,H,W,C
|
|
217
|
+
# TF-axes: [1,2,3]
|
|
218
|
+
|
|
219
|
+
# Automatic correction of accuracy degradation
|
|
220
|
+
min_abs_err = sys.maxsize
|
|
221
|
+
min_abs_err_perm_1: List[int] = [idx for idx in range(input_tensor_rank)]
|
|
222
|
+
min_abs_err_perm_2: List[int] = [idx for idx, val in enumerate(input_tensor_shape)]
|
|
223
|
+
|
|
224
|
+
if not disable_strict_mode:
|
|
225
|
+
if onnx_tensor_infos is not None and validation_data_1 is not None and validation_data_2 is not None:
|
|
226
|
+
tensor_1_candidate_for_transpositions = list(itertools.permutations(range(input_tensor_rank)))
|
|
227
|
+
tensor_2_candidate_for_transpositions = list(itertools.permutations(range(len(min_abs_err_perm_2))))
|
|
228
|
+
# Search for the axis with the smallest error
|
|
229
|
+
for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions:
|
|
230
|
+
try:
|
|
231
|
+
for tensor_2_candidate_for_transposition in tensor_2_candidate_for_transpositions:
|
|
232
|
+
try:
|
|
233
|
+
# Build TF dummy model
|
|
234
|
+
input_1 = tf_keras.Input(
|
|
235
|
+
shape=validation_data_1.shape[1:],
|
|
236
|
+
batch_size=validation_data_1.shape[0] \
|
|
237
|
+
if isinstance(validation_data_1.shape[0], int) else None,
|
|
238
|
+
name='dummy_input_1',
|
|
239
|
+
dtype=validation_data_1.dtype,
|
|
240
|
+
)
|
|
241
|
+
expand_shape = [validation_data_2[pos] for pos in tensor_2_candidate_for_transposition]
|
|
242
|
+
input_2 = tf_keras.Input(
|
|
243
|
+
shape=[len(expand_shape)],
|
|
244
|
+
batch_size=1,
|
|
245
|
+
name='dummy_input_2',
|
|
246
|
+
dtype=validation_data_2.dtype,
|
|
247
|
+
)
|
|
248
|
+
a=0
|
|
249
|
+
|
|
250
|
+
ones = tf.ones(input_2[0], dtype=input_tensor.dtype)
|
|
251
|
+
expanded_tensor = input_1 * ones
|
|
252
|
+
a=0
|
|
253
|
+
|
|
254
|
+
val_model = tf_keras.Model(
|
|
255
|
+
inputs=[
|
|
256
|
+
input_1,
|
|
257
|
+
input_2,
|
|
258
|
+
],
|
|
259
|
+
outputs=[
|
|
260
|
+
expanded_tensor,
|
|
261
|
+
],
|
|
262
|
+
)
|
|
263
|
+
a=0
|
|
264
|
+
# TF dummy inference
|
|
265
|
+
tf_tensor_infos: Dict[Any] = \
|
|
266
|
+
dummy_tf_inference(
|
|
267
|
+
model=val_model,
|
|
268
|
+
inputs=[
|
|
269
|
+
input_1,
|
|
270
|
+
input_2,
|
|
271
|
+
],
|
|
272
|
+
verification_datas=[
|
|
273
|
+
validation_data_1,
|
|
274
|
+
tf.convert_to_tensor([expand_shape], dtype=tf.int64) if isinstance(expand_shape, list) else tf.expand_dims(expand_shape, axis=0),
|
|
275
|
+
],
|
|
276
|
+
)
|
|
277
|
+
del input_1
|
|
278
|
+
del input_2
|
|
279
|
+
del val_model
|
|
280
|
+
|
|
281
|
+
# Validation
|
|
282
|
+
onnx_tf_output_pairs = {
|
|
283
|
+
(oi[0], ti[0]): (oi[1], ti[1]) \
|
|
284
|
+
for oi, ti in zip(onnx_tensor_infos.items(), tf_tensor_infos.items())
|
|
285
|
+
}
|
|
286
|
+
"""
|
|
287
|
+
check_results: Dict[str, List[np.ndarray, int, float|int]]
|
|
288
|
+
{
|
|
289
|
+
onnx_output_name: [
|
|
290
|
+
onnx_tensor,
|
|
291
|
+
matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
|
|
292
|
+
max_abs_err,
|
|
293
|
+
]
|
|
294
|
+
}
|
|
295
|
+
"""
|
|
296
|
+
check_results = \
|
|
297
|
+
onnx_tf_tensor_validation(
|
|
298
|
+
output_pairs=onnx_tf_output_pairs,
|
|
299
|
+
rtol=0.0,
|
|
300
|
+
atol=0.0,
|
|
301
|
+
)
|
|
302
|
+
result_err = sum([val[2] for val in check_results.values()])
|
|
303
|
+
if result_err < min_abs_err:
|
|
304
|
+
min_abs_err = result_err
|
|
305
|
+
min_abs_err_perm_1 = list(tensor_1_candidate_for_transposition)
|
|
306
|
+
min_abs_err_perm_2 = list(tensor_2_candidate_for_transposition)
|
|
307
|
+
if min_abs_err < 1e-3:
|
|
308
|
+
break
|
|
309
|
+
except Exception as ex:
|
|
310
|
+
pass
|
|
311
|
+
except Exception as ex:
|
|
312
|
+
pass
|
|
313
|
+
|
|
314
|
+
input_tensor = \
|
|
315
|
+
transpose_with_flexing_deterrence(
|
|
316
|
+
input_tensor=input_tensor,
|
|
317
|
+
perm=min_abs_err_perm_1,
|
|
318
|
+
output_shape=input_tensor_shape \
|
|
319
|
+
if None not in input_tensor.shape and input_tensor.shape != [] else None,
|
|
320
|
+
**kwargs,
|
|
321
|
+
)
|
|
322
|
+
input_tensor_shape = [input_tensor_shape[pos] for pos in min_abs_err_perm_2]
|
|
323
|
+
ones = tf.ones(input_tensor_shape, dtype=input_tensor.dtype)
|
|
324
|
+
expanded_tensor = input_tensor * ones
|
|
325
|
+
|
|
326
|
+
tf_layers_dict[graph_node_output.name]['tf_node'] = expanded_tensor
|
|
327
|
+
tf_type = tf.expand_dims
|
|
328
|
+
|
|
122
329
|
# Post-process transpose
|
|
123
330
|
tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
|
|
124
331
|
value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
|