onnx2tf 1.23.3__py3-none-any.whl → 1.25.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,12 @@
1
+ import sys
2
+ import copy
1
3
  import random
2
4
  random.seed(0)
3
5
  import numpy as np
4
6
  np.random.seed(0)
7
+ import itertools
5
8
  import tensorflow as tf
9
+ import tf_keras
6
10
  import onnx_graphsurgeon as gs
7
11
  from onnx2tf.utils.common_functions import (
8
12
  print_node_info,
@@ -14,7 +18,11 @@ from onnx2tf.utils.common_functions import (
14
18
  explicit_broadcast,
15
19
  pre_explicit_broadcast,
16
20
  transpose_with_flexing_deterrence,
21
+ get_tf_model_inputs,
22
+ dummy_tf_inference,
23
+ onnx_tf_tensor_validation,
17
24
  )
25
+ from typing import List, Dict, Any
18
26
 
19
27
 
20
28
  @print_node_info
@@ -85,6 +93,11 @@ def make_node(
85
93
  nhwc: bool = tf_layers_dict[X.name]['nhwc'] \
86
94
  if isinstance(X, gs.Variable) and 'nhwc' in tf_layers_dict[X.name].keys() else False
87
95
 
96
+ onnx_tensor_infos_for_validation: Dict[str: np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
97
+ test_data_nhwc: np.ndarray = kwargs['test_data_nhwc']
98
+ custom_input_op_name_np_data_path: str = kwargs['custom_input_op_name_np_data_path']
99
+ disable_strict_mode: bool = kwargs['disable_strict_mode']
100
+
88
101
  # Preserving Graph Structure (Dict)
89
102
  tf_layers_dict[Y.name] = {
90
103
  'optype': graph_node.op,
@@ -95,6 +108,7 @@ def make_node(
95
108
 
96
109
  # Generation of TF OP
97
110
  input_tensor = tf_layers_dict[X.name]['tf_node']
111
+ input_tensor_rank = len(input_tensor.shape)
98
112
 
99
113
  # Pre-process transpose
100
114
  input_tensor = pre_process_transpose(
@@ -123,42 +137,18 @@ def make_node(
123
137
  input_tensor_1=input_tensor,
124
138
  input_tensor_2=mean,
125
139
  )
126
- input_tensor, mean = explicit_broadcast(
127
- const_or_var_1=input_tensor,
128
- const_or_var_2=mean,
129
- graph_node=graph_node,
130
- tf_layers_dict= tf_layers_dict,
131
- )
132
140
  input_tensor, var = pre_explicit_broadcast(
133
141
  input_tensor_1=input_tensor,
134
142
  input_tensor_2=var,
135
143
  )
136
- input_tensor, var = explicit_broadcast(
137
- const_or_var_1=input_tensor,
138
- const_or_var_2=var,
139
- graph_node=graph_node,
140
- tf_layers_dict= tf_layers_dict,
141
- )
142
144
  input_tensor, offset = pre_explicit_broadcast(
143
145
  input_tensor_1=input_tensor,
144
146
  input_tensor_2=offset,
145
147
  )
146
- input_tensor, offset = explicit_broadcast(
147
- const_or_var_1=input_tensor,
148
- const_or_var_2=offset,
149
- graph_node=graph_node,
150
- tf_layers_dict= tf_layers_dict,
151
- )
152
148
  input_tensor, scale = pre_explicit_broadcast(
153
149
  input_tensor_1=input_tensor,
154
150
  input_tensor_2=scale,
155
151
  )
156
- input_tensor, scale = explicit_broadcast(
157
- const_or_var_1=input_tensor,
158
- const_or_var_2=scale,
159
- graph_node=graph_node,
160
- tf_layers_dict= tf_layers_dict,
161
- )
162
152
 
163
153
  try:
164
154
  tf_layers_dict[Y.name]['tf_node'] = \
@@ -290,6 +280,262 @@ def make_node(
290
280
  else:
291
281
  raise
292
282
 
283
+ # Automatic accuracy compensation
284
+ graph_node_input_1_shape = X.shape
285
+ if graph_node_input_1_shape is not None:
286
+
287
+ # Get the output tensor of one previous OP of TensorFlow only once
288
+ if not disable_strict_mode:
289
+ tf_model_inputs = get_tf_model_inputs(
290
+ tf_layers_dict=tf_layers_dict,
291
+ )
292
+ val_model = None
293
+ if not isinstance(input_tensor, np.ndarray):
294
+ val_model = tf_keras.Model(
295
+ inputs=tf_model_inputs,
296
+ outputs=[
297
+ input_tensor,
298
+ ],
299
+ )
300
+ else:
301
+ pass
302
+
303
+ # TF dummy inference
304
+ # Get the output tensor of the previous layer of MatMul
305
+ # If input.1 and input.2 are both layers, tf_pre_tensor_infos is 2 cases
306
+ # If one of input.1 or input.2 is np.ndarray, tf_pre_tensor_infos is 1 case
307
+ tf_pre_tensor_infos = {}
308
+ if not disable_strict_mode:
309
+ try:
310
+ tf_pre_tensor_infos: Dict[Any] = \
311
+ dummy_tf_inference(
312
+ model=val_model,
313
+ inputs=tf_model_inputs,
314
+ test_data_nhwc=test_data_nhwc,
315
+ custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
316
+ )
317
+ except Exception as ex:
318
+ pass
319
+ del val_model
320
+
321
+ # Get np.ndarray for validation
322
+ validation_data = None
323
+ if not disable_strict_mode:
324
+ if len(tf_pre_tensor_infos) == 1:
325
+ if not isinstance(input_tensor, np.ndarray):
326
+ validation_data = list(tf_pre_tensor_infos.values())[0]
327
+ else:
328
+ validation_data = copy.deepcopy(input_tensor)
329
+
330
+ # Get ONNX inference results
331
+ onnx_tensor_infos = None
332
+ if onnx_tensor_infos_for_validation is not None \
333
+ and onnx_tensor_infos_for_validation.get(Y.name, None) is not None:
334
+ onnx_tensor_infos = {
335
+ Y.name: onnx_tensor_infos_for_validation[Y.name]
336
+ }
337
+ del onnx_tensor_infos_for_validation
338
+
339
+ # Automatic correction of accuracy degradation
340
+ min_abs_err = sys.maxsize
341
+ min_abs_err_perm_1: List[int] = [idx for idx in range(len(mean.shape))]
342
+
343
+ if not disable_strict_mode:
344
+ if onnx_tensor_infos is not None and validation_data is not None:
345
+ tensor_1_candidate_for_transpositions = list(itertools.permutations(range(len(mean.shape))))
346
+ # Search for the axis with the smallest error
347
+ for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions:
348
+ try:
349
+ target_validation_data = validation_data
350
+ # Build TF dummy model
351
+ input = tf_keras.Input(
352
+ shape=validation_data.shape[1:],
353
+ batch_size=validation_data.shape[0] \
354
+ if isinstance(validation_data.shape[0], int) else None,
355
+ name='dummy_input',
356
+ dtype=validation_data.dtype,
357
+ )
358
+ val_model = tf_keras.Model(
359
+ inputs=[
360
+ input,
361
+ ],
362
+ outputs=[
363
+ tf.nn.batch_normalization(
364
+ x=input,
365
+ mean=\
366
+ transpose_with_flexing_deterrence(
367
+ input_tensor=mean,
368
+ perm=min_abs_err_perm_1,
369
+ output_shape=Y.shape \
370
+ if None not in Y.shape and Y.shape != [] else None,
371
+ **kwargs,
372
+ ) if not isinstance(mean, np.ndarray) else \
373
+ transpose_with_flexing_deterrence(
374
+ input_tensor=tf.convert_to_tensor(mean),
375
+ perm=min_abs_err_perm_1,
376
+ output_shape=Y.shape \
377
+ if None not in Y.shape and Y.shape != [] else None,
378
+ **kwargs,
379
+ ),
380
+ variance=\
381
+ transpose_with_flexing_deterrence(
382
+ input_tensor=var,
383
+ perm=min_abs_err_perm_1,
384
+ output_shape=Y.shape \
385
+ if None not in Y.shape and Y.shape != [] else None,
386
+ **kwargs,
387
+ ) if not isinstance(var, np.ndarray) else \
388
+ transpose_with_flexing_deterrence(
389
+ input_tensor=tf.convert_to_tensor(var),
390
+ perm=min_abs_err_perm_1,
391
+ output_shape=Y.shape \
392
+ if None not in Y.shape and Y.shape != [] else None,
393
+ **kwargs,
394
+ ),
395
+ offset=\
396
+ transpose_with_flexing_deterrence(
397
+ input_tensor=offset,
398
+ perm=min_abs_err_perm_1,
399
+ output_shape=Y.shape \
400
+ if None not in Y.shape and Y.shape != [] else None,
401
+ **kwargs,
402
+ ) if not isinstance(offset, np.ndarray) else \
403
+ transpose_with_flexing_deterrence(
404
+ input_tensor=tf.convert_to_tensor(offset),
405
+ perm=min_abs_err_perm_1,
406
+ output_shape=Y.shape \
407
+ if None not in Y.shape and Y.shape != [] else None,
408
+ **kwargs,
409
+ ),
410
+ scale=\
411
+ transpose_with_flexing_deterrence(
412
+ input_tensor=scale,
413
+ perm=min_abs_err_perm_1,
414
+ output_shape=Y.shape \
415
+ if None not in Y.shape and Y.shape != [] else None,
416
+ **kwargs,
417
+ ) if not isinstance(scale, np.ndarray) else \
418
+ transpose_with_flexing_deterrence(
419
+ input_tensor=tf.convert_to_tensor(scale),
420
+ perm=min_abs_err_perm_1,
421
+ output_shape=Y.shape \
422
+ if None not in Y.shape and Y.shape != [] else None,
423
+ **kwargs,
424
+ ),
425
+ variance_epsilon=epsilon,
426
+ )
427
+ ],
428
+ )
429
+ # TF dummy inference
430
+ tf_tensor_infos: Dict[Any] = \
431
+ dummy_tf_inference(
432
+ model=val_model,
433
+ inputs=[
434
+ input,
435
+ ],
436
+ verification_datas=[
437
+ target_validation_data,
438
+ ],
439
+ )
440
+ del input
441
+ del val_model
442
+
443
+ # Validation
444
+ onnx_tf_output_pairs = {
445
+ (oi[0], ti[0]): (oi[1], ti[1]) \
446
+ for oi, ti in zip(onnx_tensor_infos.items(), tf_tensor_infos.items())
447
+ }
448
+ """
449
+ check_results: Dict[str, List[np.ndarray, int, float|int]]
450
+ {
451
+ onnx_output_name: [
452
+ onnx_tensor,
453
+ matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
454
+ max_abs_err,
455
+ ]
456
+ }
457
+ """
458
+ check_results = \
459
+ onnx_tf_tensor_validation(
460
+ output_pairs=onnx_tf_output_pairs,
461
+ rtol=0.0,
462
+ atol=0.0,
463
+ )
464
+ result_err = sum([val[2] for val in check_results.values()])
465
+ if result_err < min_abs_err:
466
+ min_abs_err = result_err
467
+ min_abs_err_perm_1 = list(tensor_1_candidate_for_transposition)
468
+ if min_abs_err < 1e-3:
469
+ break
470
+ except Exception as ex:
471
+ pass
472
+
473
+ tf_layers_dict[Y.name]['tf_node'] = \
474
+ tf.nn.batch_normalization(
475
+ x=input_tensor,
476
+ mean=\
477
+ transpose_with_flexing_deterrence(
478
+ input_tensor=mean,
479
+ perm=min_abs_err_perm_1,
480
+ output_shape=Y.shape \
481
+ if None not in Y.shape and Y.shape != [] else None,
482
+ **kwargs,
483
+ ) if not isinstance(mean, np.ndarray) else \
484
+ transpose_with_flexing_deterrence(
485
+ input_tensor=tf.convert_to_tensor(mean),
486
+ perm=min_abs_err_perm_1,
487
+ output_shape=Y.shape \
488
+ if None not in Y.shape and Y.shape != [] else None,
489
+ **kwargs,
490
+ ),
491
+ variance=\
492
+ transpose_with_flexing_deterrence(
493
+ input_tensor=var,
494
+ perm=min_abs_err_perm_1,
495
+ output_shape=Y.shape \
496
+ if None not in Y.shape and Y.shape != [] else None,
497
+ **kwargs,
498
+ ) if not isinstance(var, np.ndarray) else \
499
+ transpose_with_flexing_deterrence(
500
+ input_tensor=tf.convert_to_tensor(var),
501
+ perm=min_abs_err_perm_1,
502
+ output_shape=Y.shape \
503
+ if None not in Y.shape and Y.shape != [] else None,
504
+ **kwargs,
505
+ ),
506
+ offset=\
507
+ transpose_with_flexing_deterrence(
508
+ input_tensor=offset,
509
+ perm=min_abs_err_perm_1,
510
+ output_shape=Y.shape \
511
+ if None not in Y.shape and Y.shape != [] else None,
512
+ **kwargs,
513
+ ) if not isinstance(offset, np.ndarray) else \
514
+ transpose_with_flexing_deterrence(
515
+ input_tensor=tf.convert_to_tensor(offset),
516
+ perm=min_abs_err_perm_1,
517
+ output_shape=Y.shape \
518
+ if None not in Y.shape and Y.shape != [] else None,
519
+ **kwargs,
520
+ ),
521
+ scale=\
522
+ transpose_with_flexing_deterrence(
523
+ input_tensor=scale,
524
+ perm=min_abs_err_perm_1,
525
+ output_shape=Y.shape \
526
+ if None not in Y.shape and Y.shape != [] else None,
527
+ **kwargs,
528
+ ) if not isinstance(scale, np.ndarray) else \
529
+ transpose_with_flexing_deterrence(
530
+ input_tensor=tf.convert_to_tensor(scale),
531
+ perm=min_abs_err_perm_1,
532
+ output_shape=Y.shape \
533
+ if None not in Y.shape and Y.shape != [] else None,
534
+ **kwargs,
535
+ ),
536
+ variance_epsilon=epsilon,
537
+ )
538
+ tf_type = tf.nn.batch_normalization
293
539
 
294
540
  # Post-process transpose
295
541
  tf_layers_dict[Y.name]['tf_node'] = post_process_transpose(
onnx2tf/ops/Concat.py CHANGED
@@ -252,13 +252,13 @@ def make_node(
252
252
  )
253
253
  tf_type = tf.slice
254
254
 
255
- elif simple_resize2 and len(values) == 2:
256
- target_input: np.ndarray = None
255
+ elif simple_resize2 and len(values) >= 2:
256
+ target_input: np.ndarray = np.array([], dtype=np.int64)
257
257
  target_spartial_size: int = 0
258
258
  for cat_value in values:
259
259
  if hasattr(cat_value, 'numpy'):
260
- target_input = cat_value.numpy()
261
- if not hasattr(cat_value, 'numpy') and cat_value.shape is not None:
260
+ target_input = np.append(target_input, cat_value.numpy())
261
+ elif not hasattr(cat_value, 'numpy') and cat_value.shape is not None:
262
262
  target_spartial_size = cat_value.shape[0] - 2
263
263
  if target_spartial_size == len(target_input):
264
264
  target_input = np.asarray([1] + [i for i in target_input] + [1])
@@ -96,6 +96,14 @@ def make_node(
96
96
  elif mode == "CRD":
97
97
  batch, channel = input_tensor_shape[0], input_tensor_shape[-1]
98
98
  height, width = input_tensor_shape[1], input_tensor_shape[2]
99
+ if batch is None:
100
+ batch = tf.shape(input_tensor)[0]
101
+ if channel is None:
102
+ channel = tf.shape(input_tensor)[-1]
103
+ if height is None:
104
+ height = tf.shape(input_tensor)[1]
105
+ if width is None:
106
+ width = tf.shape(input_tensor)[2]
99
107
  csize = channel // (blocksize**2)
100
108
 
101
109
  reshape_node = tf.reshape(
onnx2tf/ops/Div.py CHANGED
@@ -158,8 +158,37 @@ def make_node(
158
158
  is_scalar_2_rank = tf.rank(input_tensor_2) == 0
159
159
  if hasattr(is_scalar_2_rank, 'numpy'):
160
160
  is_scalar_2 = is_scalar_2_rank.numpy()
161
+
161
162
  if (is_scalar_1 or is_scalar_2) and graph_node.i().op == 'Gemm':
162
163
  pass
164
+ elif (is_scalar_1 or is_scalar_2) and graph_node.i().op != 'Gemm':
165
+ first_tensor = None
166
+ second_tensor = None
167
+ if is_scalar_1:
168
+ first_tensor = input_tensor_2
169
+ second_tensor = input_tensor_1
170
+ elif is_scalar_2:
171
+ first_tensor = input_tensor_1
172
+ second_tensor = input_tensor_2
173
+ tmp_result = tf.math.divide(first_tensor, second_tensor)
174
+ tmp_result_shape = tmp_result.shape
175
+ if first_tensor.shape == tmp_result_shape:
176
+ pass
177
+ else:
178
+ input_tensor_1, input_tensor_2 = \
179
+ pre_explicit_broadcast(
180
+ input_tensor_1=input_tensor_1,
181
+ input_tensor_2=input_tensor_2,
182
+ )
183
+
184
+ input_tensor_1, input_tensor_2 = \
185
+ explicit_broadcast(
186
+ const_or_var_1=input_tensor_1,
187
+ const_or_var_2=input_tensor_2,
188
+ graph_node=graph_node,
189
+ tf_layers_dict= tf_layers_dict,
190
+ )
191
+
163
192
  else:
164
193
  input_tensor_1, input_tensor_2 = \
165
194
  pre_explicit_broadcast(
@@ -174,6 +203,7 @@ def make_node(
174
203
  graph_node=graph_node,
175
204
  tf_layers_dict= tf_layers_dict,
176
205
  )
206
+
177
207
  except Exception as ex:
178
208
  input_tensor_1, input_tensor_2 = \
179
209
  pre_explicit_broadcast(
onnx2tf/ops/Expand.py CHANGED
@@ -1,8 +1,12 @@
1
+ import sys
2
+ import copy
1
3
  import random
2
4
  random.seed(0)
3
5
  import numpy as np
4
6
  np.random.seed(0)
7
+ import itertools
5
8
  import tensorflow as tf
9
+ import tf_keras
6
10
  import onnx_graphsurgeon as gs
7
11
  from onnx2tf.utils.common_functions import (
8
12
  get_replacement_parameter,
@@ -13,7 +17,12 @@ from onnx2tf.utils.common_functions import (
13
17
  make_tf_node_info,
14
18
  pre_process_transpose,
15
19
  post_process_transpose,
20
+ transpose_with_flexing_deterrence,
21
+ get_tf_model_inputs,
22
+ dummy_tf_inference,
23
+ onnx_tf_tensor_validation,
16
24
  )
25
+ from typing import List, Dict, Any
17
26
 
18
27
 
19
28
  @print_node_info
@@ -57,9 +66,15 @@ def make_node(
57
66
 
58
67
  input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
59
68
  if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
69
+ input_tensor_rank = len(input_tensor.shape)
60
70
  input_tensor_shape = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
61
71
  if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
62
72
 
73
+ onnx_tensor_infos_for_validation: Dict[str: np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
74
+ test_data_nhwc: np.ndarray = kwargs['test_data_nhwc']
75
+ custom_input_op_name_np_data_path: str = kwargs['custom_input_op_name_np_data_path']
76
+ disable_strict_mode: bool = kwargs['disable_strict_mode']
77
+
63
78
  # Preserving Graph Structure (Dict)
64
79
  tf_layers_dict[graph_node_output.name] = {
65
80
  'optype': graph_node.op,
@@ -119,6 +134,198 @@ def make_node(
119
134
  tf_layers_dict[graph_node_output.name]['tf_node'] = expanded_tensor
120
135
  tf_type = 'Expand'
121
136
 
137
+ if tf_type == 'Expand':
138
+ graph_node_input_1_shape = graph_node_input_1.shape
139
+ graph_node_input_2_shape = graph_node_input_2.shape
140
+
141
+ # Get the output tensor of one previous OP of TensorFlow only once
142
+ if not disable_strict_mode:
143
+ tf_model_inputs = get_tf_model_inputs(
144
+ tf_layers_dict=tf_layers_dict,
145
+ )
146
+ val_model = None
147
+ if not isinstance(input_tensor, np.ndarray):
148
+ expand_shape = []
149
+ if not isinstance(input_tensor_shape, np.ndarray):
150
+ expand_shape = [input_tensor_shape]
151
+ val_model = tf_keras.Model(
152
+ inputs=tf_model_inputs,
153
+ outputs=[
154
+ input_tensor,
155
+ ] + expand_shape,
156
+ )
157
+ else:
158
+ pass
159
+
160
+ # TF dummy inference
161
+ # Get the output tensor of the previous layer of MatMul
162
+ # If input.1 and input.2 are both layers, tf_pre_tensor_infos is 2 cases
163
+ # If one of input.1 or input.2 is np.ndarray, tf_pre_tensor_infos is 1 case
164
+ tf_pre_tensor_infos = {}
165
+ if not disable_strict_mode:
166
+ try:
167
+ tf_pre_tensor_infos: Dict[Any] = \
168
+ dummy_tf_inference(
169
+ model=val_model,
170
+ inputs=tf_model_inputs,
171
+ test_data_nhwc=test_data_nhwc,
172
+ custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
173
+ )
174
+ except Exception as ex:
175
+ pass
176
+ del val_model
177
+
178
+ # Get np.ndarray for validation
179
+ validation_data_1 = None
180
+ validation_data_2 = None
181
+
182
+ if not disable_strict_mode:
183
+ if len(tf_pre_tensor_infos) == 1:
184
+ if not isinstance(input_tensor, np.ndarray):
185
+ validation_data_1 = list(tf_pre_tensor_infos.values())[0]
186
+ else:
187
+ validation_data_1 = copy.deepcopy(input_tensor)
188
+ elif len(tf_pre_tensor_infos) == 2:
189
+ if not isinstance(input_tensor, np.ndarray):
190
+ validation_data_1 = list(tf_pre_tensor_infos.values())[0]
191
+ else:
192
+ validation_data_1 = copy.deepcopy(input_tensor)
193
+ if not isinstance(input_tensor_shape, np.ndarray):
194
+ validation_data_2 = list(tf_pre_tensor_infos.values())[1]
195
+ else:
196
+ validation_data_2 = copy.deepcopy(input_tensor_shape)
197
+
198
+ # Get ONNX inference results
199
+ onnx_tensor_infos = None
200
+ if onnx_tensor_infos_for_validation is not None \
201
+ and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
202
+ onnx_tensor_infos = {
203
+ graph_node_output.name: onnx_tensor_infos_for_validation[graph_node_output.name]
204
+ }
205
+ del onnx_tensor_infos_for_validation
206
+
207
+ # ONNX : N,C,W
208
+ # TF : N,W,C
209
+ # TF-axes: [1]
210
+ #
211
+ # ONNX: N,C,H,W
212
+ # TF : N,H,W,C
213
+ # TF-axes: [1,2]
214
+ #
215
+ # ONNX: N,C,D,H,W
216
+ # TF : N,D,H,W,C
217
+ # TF-axes: [1,2,3]
218
+
219
+ # Automatic correction of accuracy degradation
220
+ min_abs_err = sys.maxsize
221
+ min_abs_err_perm_1: List[int] = [idx for idx in range(input_tensor_rank)]
222
+ min_abs_err_perm_2: List[int] = [idx for idx, val in enumerate(input_tensor_shape)]
223
+
224
+ if not disable_strict_mode:
225
+ if onnx_tensor_infos is not None and validation_data_1 is not None and validation_data_2 is not None:
226
+ tensor_1_candidate_for_transpositions = list(itertools.permutations(range(input_tensor_rank)))
227
+ tensor_2_candidate_for_transpositions = list(itertools.permutations(range(len(min_abs_err_perm_2))))
228
+ # Search for the axis with the smallest error
229
+ for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions:
230
+ try:
231
+ for tensor_2_candidate_for_transposition in tensor_2_candidate_for_transpositions:
232
+ try:
233
+ # Build TF dummy model
234
+ input_1 = tf_keras.Input(
235
+ shape=validation_data_1.shape[1:],
236
+ batch_size=validation_data_1.shape[0] \
237
+ if isinstance(validation_data_1.shape[0], int) else None,
238
+ name='dummy_input_1',
239
+ dtype=validation_data_1.dtype,
240
+ )
241
+ expand_shape = [validation_data_2[pos] for pos in tensor_2_candidate_for_transposition]
242
+ input_2 = tf_keras.Input(
243
+ shape=[len(expand_shape)],
244
+ batch_size=1,
245
+ name='dummy_input_2',
246
+ dtype=validation_data_2.dtype,
247
+ )
248
+ a=0
249
+
250
+ ones = tf.ones(input_2[0], dtype=input_tensor.dtype)
251
+ expanded_tensor = input_1 * ones
252
+ a=0
253
+
254
+ val_model = tf_keras.Model(
255
+ inputs=[
256
+ input_1,
257
+ input_2,
258
+ ],
259
+ outputs=[
260
+ expanded_tensor,
261
+ ],
262
+ )
263
+ a=0
264
+ # TF dummy inference
265
+ tf_tensor_infos: Dict[Any] = \
266
+ dummy_tf_inference(
267
+ model=val_model,
268
+ inputs=[
269
+ input_1,
270
+ input_2,
271
+ ],
272
+ verification_datas=[
273
+ validation_data_1,
274
+ tf.convert_to_tensor([expand_shape], dtype=tf.int64) if isinstance(expand_shape, list) else tf.expand_dims(expand_shape, axis=0),
275
+ ],
276
+ )
277
+ del input_1
278
+ del input_2
279
+ del val_model
280
+
281
+ # Validation
282
+ onnx_tf_output_pairs = {
283
+ (oi[0], ti[0]): (oi[1], ti[1]) \
284
+ for oi, ti in zip(onnx_tensor_infos.items(), tf_tensor_infos.items())
285
+ }
286
+ """
287
+ check_results: Dict[str, List[np.ndarray, int, float|int]]
288
+ {
289
+ onnx_output_name: [
290
+ onnx_tensor,
291
+ matched_flg, <--- 0: Unmatched, 1: Matched, 2: Skipped (Deleted or Shape Unmatched)
292
+ max_abs_err,
293
+ ]
294
+ }
295
+ """
296
+ check_results = \
297
+ onnx_tf_tensor_validation(
298
+ output_pairs=onnx_tf_output_pairs,
299
+ rtol=0.0,
300
+ atol=0.0,
301
+ )
302
+ result_err = sum([val[2] for val in check_results.values()])
303
+ if result_err < min_abs_err:
304
+ min_abs_err = result_err
305
+ min_abs_err_perm_1 = list(tensor_1_candidate_for_transposition)
306
+ min_abs_err_perm_2 = list(tensor_2_candidate_for_transposition)
307
+ if min_abs_err < 1e-3:
308
+ break
309
+ except Exception as ex:
310
+ pass
311
+ except Exception as ex:
312
+ pass
313
+
314
+ input_tensor = \
315
+ transpose_with_flexing_deterrence(
316
+ input_tensor=input_tensor,
317
+ perm=min_abs_err_perm_1,
318
+ output_shape=input_tensor_shape \
319
+ if None not in input_tensor.shape and input_tensor.shape != [] else None,
320
+ **kwargs,
321
+ )
322
+ input_tensor_shape = [input_tensor_shape[pos] for pos in min_abs_err_perm_2]
323
+ ones = tf.ones(input_tensor_shape, dtype=input_tensor.dtype)
324
+ expanded_tensor = input_tensor * ones
325
+
326
+ tf_layers_dict[graph_node_output.name]['tf_node'] = expanded_tensor
327
+ tf_type = tf.expand_dims
328
+
122
329
  # Post-process transpose
123
330
  tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
124
331
  value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],