onnx2tf 1.29.12__py3-none-any.whl → 1.29.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +107 -0
- onnx2tf/ops/AveragePool.py +49 -0
- onnx2tf/ops/Expand.py +12 -1
- onnx2tf/ops/Flatten.py +106 -24
- onnx2tf/ops/Slice.py +34 -2
- onnx2tf/utils/common_functions.py +223 -0
- {onnx2tf-1.29.12.dist-info → onnx2tf-1.29.14.dist-info}/METADATA +4 -3
- {onnx2tf-1.29.12.dist-info → onnx2tf-1.29.14.dist-info}/RECORD +11 -11
- {onnx2tf-1.29.12.dist-info → onnx2tf-1.29.14.dist-info}/WHEEL +0 -0
- {onnx2tf-1.29.12.dist-info → onnx2tf-1.29.14.dist-info}/entry_points.txt +0 -0
onnx2tf/__init__.py
CHANGED
onnx2tf/onnx2tf.py
CHANGED
|
@@ -62,6 +62,73 @@ from onnx2tf.utils.enums import (
|
|
|
62
62
|
from onnx2tf.utils.logging import *
|
|
63
63
|
from sng4onnx import generate as op_name_auto_generate
|
|
64
64
|
|
|
65
|
+
def apply_nonzero_passthrough(
    *,
    graph: gs.Graph,
    onnx_tensor_infos: Optional[Dict[str, np.ndarray]],
    onnx_input_datas_for_validation: Optional[Dict[str, np.ndarray]] = None,
    update_graph_shape: bool = False,
) -> None:
    """Register the input tensor of every NonZero node under its output name.

    For each node of the graph whose op is 'NonZero' and which has at least
    one input and one output, the tensor found for the first input is stored
    in onnx_tensor_infos under the name of the first output.
    onnx_tensor_infos is modified in place.

    Parameters
    ----------
    graph: gs.Graph
        Graph whose nodes are scanned.

    onnx_tensor_infos: Optional[Dict[str, np.ndarray]]
        Mapping of tensor name to tensor data. Consulted first and updated in place.
        If None, nothing is done.

    onnx_input_datas_for_validation: Optional[Dict[str, np.ndarray]]
        Mapping consulted when the input name is not a key of onnx_tensor_infos.
        Ignored when None or empty.

    update_graph_shape: bool
        When True and the selected tensor has a `shape` attribute,
        the `shape` of the NonZero output is overwritten with that shape as a list.

    Returns
    ----------
    None
    """
    if onnx_tensor_infos is None:
        return

    # An absent or empty secondary mapping simply never matches.
    secondary_source = onnx_input_datas_for_validation or {}

    for node in graph.nodes:
        if node.op != 'NonZero':
            continue
        if not (len(node.inputs) and len(node.outputs)):
            continue

        source = node.inputs[0]
        target = node.outputs[0]
        source_name = source.name

        # Lookup priority: known tensor infos, then the inference input data,
        # then whatever `values` the input object itself carries.
        if source_name in onnx_tensor_infos:
            tensor = onnx_tensor_infos[source_name]
        elif source_name in secondary_source:
            tensor = secondary_source[source_name]
        else:
            tensor = getattr(source, 'values', None)

        if tensor is None:
            continue

        onnx_tensor_infos[target.name] = tensor
        if update_graph_shape and hasattr(tensor, 'shape'):
            target.shape = list(tensor.shape)
|
|
95
|
+
|
|
96
|
+
def apply_nonzero_passthrough_tf(
    *,
    graph: gs.Graph,
    tf_layers_dict: Dict[str, Any],
    tf_tensor_infos: Optional[Dict[str, np.ndarray]],
    tf_input_datas_for_validation: Optional[Dict[str, np.ndarray]] = None,
) -> None:
    """Register the TF-side input tensor of every NonZero node under its TF output name.

    For each node of the graph whose op is 'NonZero', the `tf_node` entries of its
    first input and first output are looked up in tf_layers_dict.
    The tensor found under the input tf_node's name is then stored in
    tf_tensor_infos under the output tf_node's name.
    tf_tensor_infos is modified in place.

    Nodes are skipped, never raised on, when any piece of information is missing:
    no inputs/outputs, no tf_layers_dict entry, no `tf_node`,
    or a `tf_node` that exposes no `name`.

    Parameters
    ----------
    graph: gs.Graph
        Graph whose nodes are scanned.

    tf_layers_dict: Dict[str, Any]
        Mapping of ONNX tensor name to a dict that may contain a 'tf_node' entry.

    tf_tensor_infos: Optional[Dict[str, np.ndarray]]
        Mapping of TF tensor name to tensor data. Consulted first and updated in place.
        If None, nothing is done.

    tf_input_datas_for_validation: Optional[Dict[str, np.ndarray]]
        Mapping consulted when the input TF name is not a key of tf_tensor_infos.
        Ignored when None or empty.

    Returns
    ----------
    None
    """
    if tf_tensor_infos is None:
        return

    # An absent or empty secondary mapping simply never matches.
    secondary_source = tf_input_datas_for_validation or {}

    for graph_node in graph.nodes:
        if graph_node.op != 'NonZero':
            continue
        if len(graph_node.inputs) == 0 or len(graph_node.outputs) == 0:
            continue

        input_info = tf_layers_dict.get(graph_node.inputs[0].name)
        output_info = tf_layers_dict.get(graph_node.outputs[0].name)
        if input_info is None or output_info is None:
            continue

        # A tf_node is not guaranteed to expose `name` (a NumPy array does not;
        # NOTE(review): eager tensors reportedly raise AttributeError on `.name` - confirm).
        # getattr with a default turns both the missing tf_node and the missing
        # attribute into a skip instead of aborting the whole pass.
        input_tf_name = getattr(input_info.get('tf_node'), 'name', None)
        output_tf_name = getattr(output_info.get('tf_node'), 'name', None)
        if input_tf_name is None or output_tf_name is None:
            continue

        if input_tf_name in tf_tensor_infos:
            passthrough_tensor = tf_tensor_infos[input_tf_name]
        elif input_tf_name in secondary_source:
            passthrough_tensor = secondary_source[input_tf_name]
        else:
            continue

        if passthrough_tensor is not None:
            tf_tensor_infos[output_tf_name] = passthrough_tensor
|
|
131
|
+
|
|
65
132
|
def convert(
|
|
66
133
|
input_onnx_file_path: Optional[str] = '',
|
|
67
134
|
onnx_graph: Optional[onnx.ModelProto] = None,
|
|
@@ -1113,6 +1180,7 @@ def convert(
|
|
|
1113
1180
|
# Used to verify the output error of each OP in the TensorFlow model.
|
|
1114
1181
|
full_ops_output_names = []
|
|
1115
1182
|
onnx_tensor_infos_for_validation = None
|
|
1183
|
+
onnx_input_datas_for_validation = {}
|
|
1116
1184
|
for graph_node in graph.nodes:
|
|
1117
1185
|
full_ops_output_names_sub = []
|
|
1118
1186
|
for graph_node_output in graph_node.outputs:
|
|
@@ -1132,6 +1200,7 @@ def convert(
|
|
|
1132
1200
|
enable_ort_output_memmap=onnxruntime_output_memmap,
|
|
1133
1201
|
ort_output_memmap_dir=onnxruntime_output_memmap_dir,
|
|
1134
1202
|
shape_hints=shape_hints if (check_onnx_tf_outputs_elementwise_close or check_onnx_tf_outputs_elementwise_close_full) else None,
|
|
1203
|
+
input_datas_for_validation=onnx_input_datas_for_validation,
|
|
1135
1204
|
)
|
|
1136
1205
|
"""
|
|
1137
1206
|
onnx_tensor_infos_for_validation:
|
|
@@ -1148,12 +1217,20 @@ def convert(
|
|
|
1148
1217
|
in zip(full_ops_output_names, onnx_outputs_for_validation)
|
|
1149
1218
|
}
|
|
1150
1219
|
del onnx_outputs_for_validation
|
|
1220
|
+
|
|
1221
|
+
apply_nonzero_passthrough(
|
|
1222
|
+
graph=graph,
|
|
1223
|
+
onnx_tensor_infos=onnx_tensor_infos_for_validation,
|
|
1224
|
+
onnx_input_datas_for_validation=onnx_input_datas_for_validation,
|
|
1225
|
+
update_graph_shape=True,
|
|
1226
|
+
)
|
|
1151
1227
|
except Exception as ex:
|
|
1152
1228
|
warn(
|
|
1153
1229
|
f'The optimization process for shape estimation is skipped ' +
|
|
1154
1230
|
f'because it contains OPs that cannot be inferred by the standard onnxruntime.'
|
|
1155
1231
|
)
|
|
1156
1232
|
warn(f'{ex}')
|
|
1233
|
+
onnx_input_datas_for_validation = None
|
|
1157
1234
|
additional_parameters['onnx_tensor_infos_for_validation'] = onnx_tensor_infos_for_validation
|
|
1158
1235
|
additional_parameters['test_data_nhwc'] = test_data_nhwc
|
|
1159
1236
|
additional_parameters['custom_input_op_name_np_data_path'] = custom_input_op_name_np_data_path
|
|
@@ -2061,6 +2138,7 @@ def convert(
|
|
|
2061
2138
|
dummy_onnx_outputs = None
|
|
2062
2139
|
try:
|
|
2063
2140
|
# ONNX dummy inference
|
|
2141
|
+
onnx_input_datas_for_validation = {}
|
|
2064
2142
|
dummy_onnx_outputs: List[np.ndarray] = \
|
|
2065
2143
|
dummy_onnx_inference(
|
|
2066
2144
|
onnx_graph=onnx_graph,
|
|
@@ -2072,6 +2150,7 @@ def convert(
|
|
|
2072
2150
|
enable_ort_output_memmap=onnxruntime_output_memmap,
|
|
2073
2151
|
ort_output_memmap_dir=onnxruntime_output_memmap_dir,
|
|
2074
2152
|
shape_hints=shape_hints,
|
|
2153
|
+
input_datas_for_validation=onnx_input_datas_for_validation,
|
|
2075
2154
|
)
|
|
2076
2155
|
except Exception as ex:
|
|
2077
2156
|
warn(
|
|
@@ -2081,6 +2160,7 @@ def convert(
|
|
|
2081
2160
|
warn(f'{ex}')
|
|
2082
2161
|
else:
|
|
2083
2162
|
# TF dummy inference
|
|
2163
|
+
tf_input_datas_for_validation = {}
|
|
2084
2164
|
tf_tensor_infos: Dict[Any] = \
|
|
2085
2165
|
dummy_tf_inference(
|
|
2086
2166
|
model=model,
|
|
@@ -2088,6 +2168,7 @@ def convert(
|
|
|
2088
2168
|
test_data_nhwc=test_data_nhwc,
|
|
2089
2169
|
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
|
|
2090
2170
|
shape_hints=shape_hints,
|
|
2171
|
+
input_datas_for_validation=tf_input_datas_for_validation,
|
|
2091
2172
|
keep_shape_absolutely_input_names=keep_shape_absolutely_input_names,
|
|
2092
2173
|
keep_ncw_or_nchw_or_ncdhw_input_names=keep_ncw_or_nchw_or_ncdhw_input_names,
|
|
2093
2174
|
keep_nwc_or_nhwc_or_ndhwc_input_names=keep_nwc_or_nhwc_or_ndhwc_input_names,
|
|
@@ -2097,6 +2178,17 @@ def convert(
|
|
|
2097
2178
|
output_name: dummy_onnx_output \
|
|
2098
2179
|
for output_name, dummy_onnx_output in zip(ops_output_names, dummy_onnx_outputs)
|
|
2099
2180
|
}
|
|
2181
|
+
apply_nonzero_passthrough(
|
|
2182
|
+
graph=graph,
|
|
2183
|
+
onnx_tensor_infos=onnx_tensor_infos,
|
|
2184
|
+
onnx_input_datas_for_validation=onnx_input_datas_for_validation,
|
|
2185
|
+
)
|
|
2186
|
+
apply_nonzero_passthrough_tf(
|
|
2187
|
+
graph=graph,
|
|
2188
|
+
tf_layers_dict=tf_layers_dict,
|
|
2189
|
+
tf_tensor_infos=tf_tensor_infos,
|
|
2190
|
+
tf_input_datas_for_validation=tf_input_datas_for_validation,
|
|
2191
|
+
)
|
|
2100
2192
|
"""
|
|
2101
2193
|
np.allclose(
|
|
2102
2194
|
dummy_onnx_outputs,
|
|
@@ -2326,6 +2418,7 @@ def convert(
|
|
|
2326
2418
|
# Initial accuracy check
|
|
2327
2419
|
try:
|
|
2328
2420
|
# ONNX dummy inference
|
|
2421
|
+
onnx_input_datas_for_validation = {}
|
|
2329
2422
|
dummy_onnx_outputs: List[np.ndarray] = \
|
|
2330
2423
|
dummy_onnx_inference(
|
|
2331
2424
|
onnx_graph=onnx_graph,
|
|
@@ -2337,9 +2430,11 @@ def convert(
|
|
|
2337
2430
|
enable_ort_output_memmap=onnxruntime_output_memmap,
|
|
2338
2431
|
ort_output_memmap_dir=onnxruntime_output_memmap_dir,
|
|
2339
2432
|
shape_hints=shape_hints,
|
|
2433
|
+
input_datas_for_validation=onnx_input_datas_for_validation,
|
|
2340
2434
|
)
|
|
2341
2435
|
|
|
2342
2436
|
# TF dummy inference
|
|
2437
|
+
tf_input_datas_for_validation = {}
|
|
2343
2438
|
tf_tensor_infos: Dict[Any] = \
|
|
2344
2439
|
dummy_tf_inference(
|
|
2345
2440
|
model=validation_model,
|
|
@@ -2347,6 +2442,7 @@ def convert(
|
|
|
2347
2442
|
test_data_nhwc=test_data_nhwc,
|
|
2348
2443
|
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
|
|
2349
2444
|
shape_hints=shape_hints,
|
|
2445
|
+
input_datas_for_validation=tf_input_datas_for_validation,
|
|
2350
2446
|
keep_shape_absolutely_input_names=keep_shape_absolutely_input_names,
|
|
2351
2447
|
keep_ncw_or_nchw_or_ncdhw_input_names=keep_ncw_or_nchw_or_ncdhw_input_names,
|
|
2352
2448
|
keep_nwc_or_nhwc_or_ndhwc_input_names=keep_nwc_or_nhwc_or_ndhwc_input_names,
|
|
@@ -2357,6 +2453,17 @@ def convert(
|
|
|
2357
2453
|
output_name: dummy_onnx_output \
|
|
2358
2454
|
for output_name, dummy_onnx_output in zip(ops_output_names, dummy_onnx_outputs)
|
|
2359
2455
|
}
|
|
2456
|
+
apply_nonzero_passthrough(
|
|
2457
|
+
graph=graph,
|
|
2458
|
+
onnx_tensor_infos=onnx_tensor_infos,
|
|
2459
|
+
onnx_input_datas_for_validation=onnx_input_datas_for_validation,
|
|
2460
|
+
)
|
|
2461
|
+
apply_nonzero_passthrough_tf(
|
|
2462
|
+
graph=graph,
|
|
2463
|
+
tf_layers_dict=tf_layers_dict,
|
|
2464
|
+
tf_tensor_infos=tf_tensor_infos,
|
|
2465
|
+
tf_input_datas_for_validation=tf_input_datas_for_validation,
|
|
2466
|
+
)
|
|
2360
2467
|
|
|
2361
2468
|
input_names = [k.name for k in inputs]
|
|
2362
2469
|
for k, v in tf_layers_dict.items():
|
onnx2tf/ops/AveragePool.py
CHANGED
|
@@ -370,6 +370,12 @@ def make_node(
|
|
|
370
370
|
paddings=tf_pads,
|
|
371
371
|
mode='CONSTANT',
|
|
372
372
|
)
|
|
373
|
+
if input_tensor_shape is not None and len(input_tensor_shape) == spatial_size + 2:
|
|
374
|
+
# Preserve known batch/channel dims since dynamic paddings erase shape info.
|
|
375
|
+
padded_tensor = tf.ensure_shape(
|
|
376
|
+
padded_tensor,
|
|
377
|
+
[input_tensor_shape[0]] + [None] * spatial_size + [input_tensor_shape[-1]],
|
|
378
|
+
)
|
|
373
379
|
else:
|
|
374
380
|
if auto_pad == 'SAME_LOWER':
|
|
375
381
|
# switch the order of pads
|
|
@@ -468,6 +474,49 @@ def make_node(
|
|
|
468
474
|
print(error_msg)
|
|
469
475
|
raise AssertionError(error_msg)
|
|
470
476
|
|
|
477
|
+
# Dynamic shape compensation for count_include_pad=False with explicit padding.
|
|
478
|
+
# Use pooled mask to compute valid element counts per window.
|
|
479
|
+
if not is_known_shape and is_explicit_padding and not count_include_pad:
|
|
480
|
+
mask = tf.ones_like(input_tensor, dtype=pooled_tensor.dtype)
|
|
481
|
+
if tf_pads is not None:
|
|
482
|
+
if tf.is_tensor(tf_pads):
|
|
483
|
+
mask = tf.pad(
|
|
484
|
+
tensor=mask,
|
|
485
|
+
paddings=tf_pads,
|
|
486
|
+
mode='CONSTANT',
|
|
487
|
+
)
|
|
488
|
+
elif tf_pads != [0] * spatial_size * 2:
|
|
489
|
+
mask = tf.pad(
|
|
490
|
+
tensor=mask,
|
|
491
|
+
paddings=tf_pads,
|
|
492
|
+
mode='CONSTANT',
|
|
493
|
+
)
|
|
494
|
+
if len(kernel_shape) == 1:
|
|
495
|
+
mask_pooled = AveragePooling1D(
|
|
496
|
+
pool_size=kernel_shape,
|
|
497
|
+
strides=strides,
|
|
498
|
+
padding=tf_pad_mode.upper(),
|
|
499
|
+
)(mask)
|
|
500
|
+
elif len(kernel_shape) == 2:
|
|
501
|
+
mask_pooled = AveragePooling2D(
|
|
502
|
+
pool_size=kernel_shape,
|
|
503
|
+
strides=strides,
|
|
504
|
+
padding=tf_pad_mode.upper(),
|
|
505
|
+
)(mask)
|
|
506
|
+
else:
|
|
507
|
+
mask_pooled = AveragePooling3D(
|
|
508
|
+
pool_size=kernel_shape,
|
|
509
|
+
strides=strides,
|
|
510
|
+
padding=tf_pad_mode.upper(),
|
|
511
|
+
)(mask)
|
|
512
|
+
kernel_volume = float(np.prod(kernel_shape))
|
|
513
|
+
count_valid = mask_pooled * tf.cast(kernel_volume, dtype=mask_pooled.dtype)
|
|
514
|
+
multiplier = tf.math.divide_no_nan(
|
|
515
|
+
tf.cast(kernel_volume, dtype=mask_pooled.dtype),
|
|
516
|
+
count_valid,
|
|
517
|
+
)
|
|
518
|
+
pooled_tensor = pooled_tensor * multiplier
|
|
519
|
+
|
|
471
520
|
# tensorflow average pooling needs extra process to get same output with onnx
|
|
472
521
|
# https://github.com/PINTO0309/onnx2tf/issues/124
|
|
473
522
|
if average_multiplier is not None:
|
onnx2tf/ops/Expand.py
CHANGED
|
@@ -48,6 +48,7 @@ def make_node(
|
|
|
48
48
|
tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
|
|
49
49
|
before_op_output_shape_trans_2 = \
|
|
50
50
|
tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
|
|
51
|
+
# Data layout follows input[0]; shape vector (input[1]) should align to it.
|
|
51
52
|
before_op_output_shape_trans = \
|
|
52
53
|
before_op_output_shape_trans_1 \
|
|
53
54
|
and before_op_output_shape_trans_2
|
|
@@ -58,7 +59,7 @@ def make_node(
|
|
|
58
59
|
)
|
|
59
60
|
graph_node_input_2 = get_constant_or_variable(
|
|
60
61
|
graph_node.inputs[1],
|
|
61
|
-
|
|
62
|
+
before_op_output_shape_trans_1,
|
|
62
63
|
)
|
|
63
64
|
graph_node_output: gs.Variable = graph_node.outputs[0]
|
|
64
65
|
shape = graph_node_output.shape
|
|
@@ -106,6 +107,16 @@ def make_node(
|
|
|
106
107
|
**kwargs,
|
|
107
108
|
)
|
|
108
109
|
|
|
110
|
+
# If shape is dynamic (Tensor) and input was transposed to NHWC/NWC/NDHWC,
|
|
111
|
+
# align the shape vector order to TensorFlow's layout.
|
|
112
|
+
if before_op_output_shape_trans_1 \
|
|
113
|
+
and tf.is_tensor(input_tensor_shape) \
|
|
114
|
+
and input_tensor_rank > 2:
|
|
115
|
+
shape_rank = input_tensor_shape.shape.rank
|
|
116
|
+
if shape_rank == 1 or shape_rank is None:
|
|
117
|
+
perm = [0] + list(range(2, input_tensor_rank)) + [1]
|
|
118
|
+
input_tensor_shape = tf.gather(input_tensor_shape, perm)
|
|
119
|
+
|
|
109
120
|
tf_type = None
|
|
110
121
|
if \
|
|
111
122
|
(
|
onnx2tf/ops/Flatten.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import random
|
|
2
2
|
random.seed(0)
|
|
3
3
|
import numpy as np
|
|
4
|
+
import itertools
|
|
4
5
|
np.random.seed(0)
|
|
5
6
|
import tensorflow as tf
|
|
6
7
|
import tf_keras
|
|
@@ -13,6 +14,8 @@ from onnx2tf.utils.common_functions import (
|
|
|
13
14
|
print_node_info,
|
|
14
15
|
inverted_operation_enable_disable,
|
|
15
16
|
make_tf_node_info,
|
|
17
|
+
dummy_tf_inference,
|
|
18
|
+
get_tf_model_inputs,
|
|
16
19
|
pre_process_transpose,
|
|
17
20
|
post_process_transpose,
|
|
18
21
|
transpose_with_flexing_deterrence,
|
|
@@ -84,6 +87,109 @@ def make_node(
|
|
|
84
87
|
**kwargs,
|
|
85
88
|
)
|
|
86
89
|
|
|
90
|
+
# Param replacement
|
|
91
|
+
input_tensor = replace_parameter(
|
|
92
|
+
value_before_replacement=input_tensor,
|
|
93
|
+
param_target='inputs',
|
|
94
|
+
param_name=graph_node.inputs[0].name,
|
|
95
|
+
**kwargs,
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
# Pre-process transpose
|
|
99
|
+
input_tensor = pre_process_transpose(
|
|
100
|
+
value_before_transpose=input_tensor,
|
|
101
|
+
param_target='inputs',
|
|
102
|
+
param_name=graph_node.inputs[0].name,
|
|
103
|
+
**kwargs,
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
perm = [
|
|
107
|
+
convert_axis(
|
|
108
|
+
axis=idx,
|
|
109
|
+
tensor_rank=input_tensor_rank,
|
|
110
|
+
before_op_output_shape_trans=before_op_output_shape_trans,
|
|
111
|
+
) for idx in range(input_tensor_rank)
|
|
112
|
+
]
|
|
113
|
+
|
|
114
|
+
# Brute-force transpose to match ONNX dummy inference outputs when available.
|
|
115
|
+
onnx_tensor_infos_for_validation = kwargs.get('onnx_tensor_infos_for_validation', None)
|
|
116
|
+
test_data_nhwc: np.ndarray = kwargs.get('test_data_nhwc', None)
|
|
117
|
+
custom_input_op_name_np_data_path: str = kwargs.get('custom_input_op_name_np_data_path', None)
|
|
118
|
+
disable_strict_mode: bool = kwargs.get('disable_strict_mode', False)
|
|
119
|
+
if not disable_strict_mode \
|
|
120
|
+
and onnx_tensor_infos_for_validation is not None \
|
|
121
|
+
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
|
|
122
|
+
validation_input = None
|
|
123
|
+
if isinstance(input_tensor, np.ndarray):
|
|
124
|
+
validation_input = input_tensor
|
|
125
|
+
elif hasattr(input_tensor, 'numpy'):
|
|
126
|
+
try:
|
|
127
|
+
validation_input = input_tensor.numpy()
|
|
128
|
+
except Exception:
|
|
129
|
+
validation_input = None
|
|
130
|
+
else:
|
|
131
|
+
try:
|
|
132
|
+
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
|
|
133
|
+
val_model = tf_keras.Model(
|
|
134
|
+
inputs=tf_model_inputs,
|
|
135
|
+
outputs=[input_tensor],
|
|
136
|
+
)
|
|
137
|
+
tf_pre_tensor_infos = dummy_tf_inference(
|
|
138
|
+
model=val_model,
|
|
139
|
+
inputs=tf_model_inputs,
|
|
140
|
+
test_data_nhwc=test_data_nhwc,
|
|
141
|
+
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
|
|
142
|
+
)
|
|
143
|
+
if len(tf_pre_tensor_infos) >= 1:
|
|
144
|
+
validation_input = list(tf_pre_tensor_infos.values())[0]
|
|
145
|
+
del val_model
|
|
146
|
+
except Exception:
|
|
147
|
+
validation_input = None
|
|
148
|
+
if validation_input is None:
|
|
149
|
+
onnx_input_name = graph_node.inputs[0].name
|
|
150
|
+
if onnx_tensor_infos_for_validation.get(onnx_input_name, None) is not None:
|
|
151
|
+
validation_input = onnx_tensor_infos_for_validation[onnx_input_name]
|
|
152
|
+
|
|
153
|
+
onnx_output = onnx_tensor_infos_for_validation.get(graph_node_output.name, None)
|
|
154
|
+
if validation_input is not None and onnx_output is not None:
|
|
155
|
+
rank = len(validation_input.shape)
|
|
156
|
+
if rank <= 6:
|
|
157
|
+
perm_candidates = itertools.permutations(range(rank))
|
|
158
|
+
else:
|
|
159
|
+
perm_candidates = [perm]
|
|
160
|
+
|
|
161
|
+
def _flatten_np(arr, axis):
    """Reshape arr into a 2D array split at axis.

    axis == 0 yields a single row, axis >= arr.ndim yields a single column,
    and any other axis yields (product of dims before axis, product of dims from axis on).
    """
    if axis == 0:
        target_shape = (1, -1)
    elif axis >= arr.ndim:
        target_shape = (-1, 1)
    else:
        dims = arr.shape
        target_shape = (
            int(np.prod(dims[:axis])),
            int(np.prod(dims[axis:])),
        )
    return arr.reshape(target_shape)
|
|
170
|
+
|
|
171
|
+
matched_perm = None
|
|
172
|
+
matched_axis = None
|
|
173
|
+
for cand in perm_candidates:
|
|
174
|
+
try:
|
|
175
|
+
cand_arr = np.transpose(validation_input, cand)
|
|
176
|
+
for axis_candidate in range(0, rank + 1):
|
|
177
|
+
cand_flat = _flatten_np(cand_arr, axis_candidate)
|
|
178
|
+
if cand_flat.shape != onnx_output.shape:
|
|
179
|
+
continue
|
|
180
|
+
if np.allclose(cand_flat, onnx_output, rtol=0.0, atol=0.0, equal_nan=True):
|
|
181
|
+
matched_perm = list(cand)
|
|
182
|
+
matched_axis = axis_candidate
|
|
183
|
+
break
|
|
184
|
+
if matched_perm is not None:
|
|
185
|
+
break
|
|
186
|
+
except Exception:
|
|
187
|
+
continue
|
|
188
|
+
if matched_perm is not None:
|
|
189
|
+
perm = matched_perm
|
|
190
|
+
if matched_axis is not None:
|
|
191
|
+
axis = matched_axis
|
|
192
|
+
|
|
87
193
|
# Generation of TF OP
|
|
88
194
|
cal_shape = None
|
|
89
195
|
if axis == 0:
|
|
@@ -134,30 +240,6 @@ def make_node(
|
|
|
134
240
|
has_str_outputshape = True in [True for dim in output_shape if isinstance(dim, str)]
|
|
135
241
|
has_undefined_outputshape = has_none_outputshape or has_str_outputshape
|
|
136
242
|
cal_shape = cal_shape if has_undefined_outputshape else output_shape
|
|
137
|
-
|
|
138
|
-
# Param replacement
|
|
139
|
-
input_tensor = replace_parameter(
|
|
140
|
-
value_before_replacement=input_tensor,
|
|
141
|
-
param_target='inputs',
|
|
142
|
-
param_name=graph_node.inputs[0].name,
|
|
143
|
-
**kwargs,
|
|
144
|
-
)
|
|
145
|
-
|
|
146
|
-
# Pre-process transpose
|
|
147
|
-
input_tensor = pre_process_transpose(
|
|
148
|
-
value_before_transpose=input_tensor,
|
|
149
|
-
param_target='inputs',
|
|
150
|
-
param_name=graph_node.inputs[0].name,
|
|
151
|
-
**kwargs,
|
|
152
|
-
)
|
|
153
|
-
|
|
154
|
-
perm = [
|
|
155
|
-
convert_axis(
|
|
156
|
-
axis=idx,
|
|
157
|
-
tensor_rank=input_tensor_rank,
|
|
158
|
-
before_op_output_shape_trans=before_op_output_shape_trans,
|
|
159
|
-
) for idx in range(input_tensor_rank)
|
|
160
|
-
]
|
|
161
243
|
input_tensor = transpose_with_flexing_deterrence(
|
|
162
244
|
input_tensor=input_tensor,
|
|
163
245
|
perm=list(perm) if perm is not None else None,
|
onnx2tf/ops/Slice.py
CHANGED
|
@@ -434,7 +434,23 @@ def make_node(
|
|
|
434
434
|
dtype=tf.int32,
|
|
435
435
|
)
|
|
436
436
|
if hasattr(begin_mask_, '_inferred_value') and begin_mask_._inferred_value == [None]:
|
|
437
|
-
|
|
437
|
+
axes_list = None
|
|
438
|
+
if axes is not None:
|
|
439
|
+
if isinstance(axes, (list, tuple)):
|
|
440
|
+
axes_list = list(axes)
|
|
441
|
+
elif isinstance(axes, np.ndarray):
|
|
442
|
+
axes_list = axes.tolist() if axes.ndim > 0 else [int(axes)]
|
|
443
|
+
elif tf.is_tensor(axes):
|
|
444
|
+
if hasattr(axes, 'numpy'):
|
|
445
|
+
axes_list = axes.numpy().tolist()
|
|
446
|
+
elif hasattr(axes, '_inferred_value') and axes._inferred_value not in (None, [None]):
|
|
447
|
+
axes_list = list(axes._inferred_value)
|
|
448
|
+
if axes_list is not None:
|
|
449
|
+
begin_mask_ = sum(
|
|
450
|
+
1 << axis for axis in range(input_tensor_rank) if axis not in axes_list
|
|
451
|
+
)
|
|
452
|
+
else:
|
|
453
|
+
begin_mask_ = 0
|
|
438
454
|
|
|
439
455
|
##### end_mask
|
|
440
456
|
end_bit_mask = tf.constant([2**idx for idx in range(input_tensor_rank)], dtype=tf.int32)
|
|
@@ -446,7 +462,23 @@ def make_node(
|
|
|
446
462
|
dtype=tf.int32,
|
|
447
463
|
)
|
|
448
464
|
if hasattr(end_mask_, '_inferred_value') and end_mask_._inferred_value == [None]:
|
|
449
|
-
|
|
465
|
+
axes_list = None
|
|
466
|
+
if axes is not None:
|
|
467
|
+
if isinstance(axes, (list, tuple)):
|
|
468
|
+
axes_list = list(axes)
|
|
469
|
+
elif isinstance(axes, np.ndarray):
|
|
470
|
+
axes_list = axes.tolist() if axes.ndim > 0 else [int(axes)]
|
|
471
|
+
elif tf.is_tensor(axes):
|
|
472
|
+
if hasattr(axes, 'numpy'):
|
|
473
|
+
axes_list = axes.numpy().tolist()
|
|
474
|
+
elif hasattr(axes, '_inferred_value') and axes._inferred_value not in (None, [None]):
|
|
475
|
+
axes_list = list(axes._inferred_value)
|
|
476
|
+
if axes_list is not None:
|
|
477
|
+
end_mask_ = sum(
|
|
478
|
+
1 << axis for axis in range(input_tensor_rank) if axis not in axes_list
|
|
479
|
+
)
|
|
480
|
+
else:
|
|
481
|
+
end_mask_ = 0
|
|
450
482
|
|
|
451
483
|
# strided_slice
|
|
452
484
|
tf_layers_dict[graph_node_output.name]['tf_node'] = \
|
|
@@ -896,6 +896,19 @@ def explicit_broadcast(
|
|
|
896
896
|
const_or_var_2: Any
|
|
897
897
|
gs.Variable or np.ndarray
|
|
898
898
|
"""
|
|
899
|
+
def _tf_broadcastable(shape_a, shape_b):
    """Return True if two shapes of equal length are pairwise compatible.

    A pair of dims is compatible when either dim is None, the dims are equal,
    or one of them is 1. A None shape or a length mismatch returns False.
    """
    if shape_a is None or shape_b is None or len(shape_a) != len(shape_b):
        return False
    return all(
        left is None
        or right is None
        or left == right
        or left == 1
        or right == 1
        for left, right in zip(shape_a, shape_b)
    )
|
|
911
|
+
|
|
899
912
|
graph_node_input_name1 = None
|
|
900
913
|
graph_node_input_name2 = None
|
|
901
914
|
graph_node_input_shape1 = []
|
|
@@ -928,6 +941,29 @@ def explicit_broadcast(
|
|
|
928
941
|
if graph_node_input_shape1 is None or graph_node_input_shape2 is None:
|
|
929
942
|
return const_or_var_1, const_or_var_2
|
|
930
943
|
|
|
944
|
+
# If one operand is 1D and matches the last dimension of the other operand,
|
|
945
|
+
# align it to the last axis to avoid unintended transpose.
|
|
946
|
+
if len(const_or_var_1.shape) == 1 and len(const_or_var_2.shape) > 1:
|
|
947
|
+
dim_1 = const_or_var_1.shape[0]
|
|
948
|
+
dim_2_last = const_or_var_2.shape[-1]
|
|
949
|
+
if isinstance(dim_1, int) and isinstance(dim_2_last, int) and dim_1 == dim_2_last:
|
|
950
|
+
target_shape = [1] * (len(const_or_var_2.shape) - 1) + [dim_1]
|
|
951
|
+
if isinstance(const_or_var_1, np.ndarray):
|
|
952
|
+
const_or_var_1 = const_or_var_1.reshape(target_shape)
|
|
953
|
+
else:
|
|
954
|
+
const_or_var_1 = tf.reshape(const_or_var_1, target_shape)
|
|
955
|
+
return const_or_var_1, const_or_var_2
|
|
956
|
+
if len(const_or_var_2.shape) == 1 and len(const_or_var_1.shape) > 1:
|
|
957
|
+
dim_2 = const_or_var_2.shape[0]
|
|
958
|
+
dim_1_last = const_or_var_1.shape[-1]
|
|
959
|
+
if isinstance(dim_2, int) and isinstance(dim_1_last, int) and dim_2 == dim_1_last:
|
|
960
|
+
target_shape = [1] * (len(const_or_var_1.shape) - 1) + [dim_2]
|
|
961
|
+
if isinstance(const_or_var_2, np.ndarray):
|
|
962
|
+
const_or_var_2 = const_or_var_2.reshape(target_shape)
|
|
963
|
+
else:
|
|
964
|
+
const_or_var_2 = tf.reshape(const_or_var_2, target_shape)
|
|
965
|
+
return const_or_var_1, const_or_var_2
|
|
966
|
+
|
|
931
967
|
# If either operand has a shape of all 1's, do not broadcast and return as is
|
|
932
968
|
shape_for_judging_skip_processing_1 = [
|
|
933
969
|
i if i is not None else INF_INDEX_VALUE for i in const_or_var_1.shape
|
|
@@ -2403,6 +2439,179 @@ def shape_unmatched_special_avoidance_workaround(
|
|
|
2403
2439
|
return input_tensor_1, input_tensor_2
|
|
2404
2440
|
except:
|
|
2405
2441
|
pass
|
|
2442
|
+
|
|
2443
|
+
def _normalize_shape(shape):
    """Return shape as a list in which every non-int dim is replaced by None.

    A None shape is returned as None.
    """
    if shape is not None:
        return [size if isinstance(size, int) else None for size in shape]
    return None
|
|
2447
|
+
|
|
2448
|
+
def _broadcastable(shape_a, shape_b):
    """Return True if two shapes of equal length are pairwise compatible.

    Dims are compared from the last axis to the first. A pair is compatible
    when either dim is None, the dims are equal, or one of them is 1.
    A None shape or a length mismatch returns False.
    """
    if shape_a is None or shape_b is None:
        return False
    if len(shape_a) != len(shape_b):
        return False
    return all(
        left is None
        or right is None
        or left == right
        or left == 1
        or right == 1
        for left, right in zip(shape_a[::-1], shape_b[::-1])
    )
|
|
2459
|
+
|
|
2460
|
+
def _match_score(shape_a, shape_b):
    """Count the positions at which both dims are not None and are equal."""
    return sum(
        1
        for left, right in zip(shape_a, shape_b)
        if left is not None and right is not None and left == right
    )
|
|
2468
|
+
|
|
2469
|
+
def _shape_matches(shape_a, shape_b):
    """Return True if two shapes of equal length agree on every known dim.

    A position where either dim is None is treated as matching.
    A None shape or a length mismatch returns False.
    """
    if shape_a is None or shape_b is None or len(shape_a) != len(shape_b):
        return False
    return all(
        left is None or right is None or left == right
        for left, right in zip(shape_a, shape_b)
    )
|
|
2480
|
+
|
|
2481
|
+
# Generic layout-alignment for channel-first/last in 3D/4D/5D.
|
|
2482
|
+
# Try a small set of canonical perms and apply the best one if it makes broadcasting possible.
|
|
2483
|
+
try:
|
|
2484
|
+
if hasattr(input_tensor_1, "shape") and hasattr(input_tensor_2, "shape"):
|
|
2485
|
+
input_shape_1 = _normalize_shape(input_tensor_1.shape)
|
|
2486
|
+
input_shape_2 = _normalize_shape(input_tensor_2.shape)
|
|
2487
|
+
if input_shape_1 is not None and input_shape_2 is not None \
|
|
2488
|
+
and len(input_shape_1) == len(input_shape_2) \
|
|
2489
|
+
and len(input_shape_1) in (3, 4, 5):
|
|
2490
|
+
if not _broadcastable(input_shape_1, input_shape_2):
|
|
2491
|
+
rank = len(input_shape_1)
|
|
2492
|
+
perm_cf2cl = [0] + list(range(2, rank)) + [1]
|
|
2493
|
+
perm_cl2cf = [0, rank - 1] + list(range(1, rank - 1))
|
|
2494
|
+
perms = []
|
|
2495
|
+
if perm_cf2cl != list(range(rank)):
|
|
2496
|
+
perms.append(perm_cf2cl)
|
|
2497
|
+
if perm_cl2cf != list(range(rank)) and perm_cl2cf != perm_cf2cl:
|
|
2498
|
+
perms.append(perm_cl2cf)
|
|
2499
|
+
|
|
2500
|
+
onnx_shape_1 = _normalize_shape(
|
|
2501
|
+
graph_node_input_1.shape if hasattr(graph_node_input_1, "shape") else None
|
|
2502
|
+
)
|
|
2503
|
+
onnx_shape_2 = _normalize_shape(
|
|
2504
|
+
graph_node_input_2.shape if hasattr(graph_node_input_2, "shape") else None
|
|
2505
|
+
)
|
|
2506
|
+
|
|
2507
|
+
candidates = []
|
|
2508
|
+
for idx, (shape, other_shape) in enumerate(
|
|
2509
|
+
[(input_shape_1, input_shape_2), (input_shape_2, input_shape_1)]
|
|
2510
|
+
):
|
|
2511
|
+
for perm in perms:
|
|
2512
|
+
permuted = [shape[p] for p in perm]
|
|
2513
|
+
if _broadcastable(permuted, other_shape):
|
|
2514
|
+
score = _match_score(permuted, other_shape)
|
|
2515
|
+
# Prefer transposing the input whose ONNX shape matches current layout.
|
|
2516
|
+
if idx == 0 and _shape_matches(onnx_shape_1, shape):
|
|
2517
|
+
score += 2
|
|
2518
|
+
if idx == 1 and _shape_matches(onnx_shape_2, shape):
|
|
2519
|
+
score += 2
|
|
2520
|
+
candidates.append((score, idx, perm))
|
|
2521
|
+
|
|
2522
|
+
if candidates:
|
|
2523
|
+
candidates.sort(reverse=True)
|
|
2524
|
+
best_score, best_idx, best_perm = candidates[0]
|
|
2525
|
+
# Avoid ambiguous ties.
|
|
2526
|
+
if len(candidates) == 1 or best_score > candidates[1][0]:
|
|
2527
|
+
if best_idx == 0:
|
|
2528
|
+
input_tensor_1 = \
|
|
2529
|
+
transpose_with_flexing_deterrence(
|
|
2530
|
+
input_tensor=input_tensor_1,
|
|
2531
|
+
perm=best_perm,
|
|
2532
|
+
**kwargs,
|
|
2533
|
+
)
|
|
2534
|
+
else:
|
|
2535
|
+
input_tensor_2 = \
|
|
2536
|
+
transpose_with_flexing_deterrence(
|
|
2537
|
+
input_tensor=input_tensor_2,
|
|
2538
|
+
perm=best_perm,
|
|
2539
|
+
**kwargs,
|
|
2540
|
+
)
|
|
2541
|
+
except Exception:
|
|
2542
|
+
pass
|
|
2543
|
+
|
|
2544
|
+
# Heuristic for 3D tensors where one input is (N,1,C) and the other is (N,C,W).
|
|
2545
|
+
# Align by transposing the (N,C,W) tensor to (N,W,C).
|
|
2546
|
+
try:
|
|
2547
|
+
if hasattr(input_tensor_1, "shape") and hasattr(input_tensor_2, "shape"):
|
|
2548
|
+
s1 = list(input_tensor_1.shape)
|
|
2549
|
+
s2 = list(input_tensor_2.shape)
|
|
2550
|
+
if len(s1) == len(s2) == 3:
|
|
2551
|
+
# Normalize unknown dims to None
|
|
2552
|
+
s1 = [dim if isinstance(dim, int) else None for dim in s1]
|
|
2553
|
+
s2 = [dim if isinstance(dim, int) else None for dim in s2]
|
|
2554
|
+
if s1[1] == 1 and s1[2] is not None and s2[1] == s1[2]:
|
|
2555
|
+
input_tensor_2 = \
|
|
2556
|
+
transpose_with_flexing_deterrence(
|
|
2557
|
+
input_tensor=input_tensor_2,
|
|
2558
|
+
perm=[0, 2, 1],
|
|
2559
|
+
**kwargs,
|
|
2560
|
+
)
|
|
2561
|
+
elif s2[1] == 1 and s2[2] is not None and s1[1] == s2[2]:
|
|
2562
|
+
input_tensor_1 = \
|
|
2563
|
+
transpose_with_flexing_deterrence(
|
|
2564
|
+
input_tensor=input_tensor_1,
|
|
2565
|
+
perm=[0, 2, 1],
|
|
2566
|
+
**kwargs,
|
|
2567
|
+
)
|
|
2568
|
+
except Exception:
|
|
2569
|
+
pass
|
|
2570
|
+
|
|
2571
|
+
# Layout mismatch mitigation based on ONNX shapes:
|
|
2572
|
+
# If one input matches ONNX layout and the other matches the transposed layout,
|
|
2573
|
+
# transpose the ONNX-layout input to align with the transposed one.
|
|
2574
|
+
try:
|
|
2575
|
+
if hasattr(input_tensor_1, "shape") and hasattr(input_tensor_2, "shape"):
|
|
2576
|
+
input_shape_1 = list(input_tensor_1.shape)
|
|
2577
|
+
input_shape_2 = list(input_tensor_2.shape)
|
|
2578
|
+
if len(input_shape_1) == len(input_shape_2) and len(input_shape_1) in (3, 4, 5):
|
|
2579
|
+
onnx_shape_1 = None
|
|
2580
|
+
onnx_shape_2 = None
|
|
2581
|
+
if hasattr(graph_node_input_1, "shape") and graph_node_input_1.shape is not None:
|
|
2582
|
+
onnx_shape_1 = [
|
|
2583
|
+
dim if not isinstance(dim, str) else None for dim in graph_node_input_1.shape
|
|
2584
|
+
]
|
|
2585
|
+
if hasattr(graph_node_input_2, "shape") and graph_node_input_2.shape is not None:
|
|
2586
|
+
onnx_shape_2 = [
|
|
2587
|
+
dim if not isinstance(dim, str) else None for dim in graph_node_input_2.shape
|
|
2588
|
+
]
|
|
2589
|
+
if onnx_shape_1 is not None and onnx_shape_2 is not None:
|
|
2590
|
+
perm = [0] + list(range(2, len(input_shape_1))) + [1]
|
|
2591
|
+
permuted_onnx_shape_1 = [onnx_shape_1[p] for p in perm]
|
|
2592
|
+
permuted_onnx_shape_2 = [onnx_shape_2[p] for p in perm]
|
|
2593
|
+
|
|
2594
|
+
in1_matches_onnx = _shape_matches(input_shape_1, onnx_shape_1)
|
|
2595
|
+
in1_matches_perm = _shape_matches(input_shape_1, permuted_onnx_shape_1)
|
|
2596
|
+
in2_matches_onnx = _shape_matches(input_shape_2, onnx_shape_2)
|
|
2597
|
+
in2_matches_perm = _shape_matches(input_shape_2, permuted_onnx_shape_2)
|
|
2598
|
+
|
|
2599
|
+
if in1_matches_perm and in2_matches_onnx and not in2_matches_perm:
|
|
2600
|
+
input_tensor_2 = \
|
|
2601
|
+
transpose_with_flexing_deterrence(
|
|
2602
|
+
input_tensor=input_tensor_2,
|
|
2603
|
+
perm=perm,
|
|
2604
|
+
**kwargs,
|
|
2605
|
+
)
|
|
2606
|
+
elif in2_matches_perm and in1_matches_onnx and not in1_matches_perm:
|
|
2607
|
+
input_tensor_1 = \
|
|
2608
|
+
transpose_with_flexing_deterrence(
|
|
2609
|
+
input_tensor=input_tensor_1,
|
|
2610
|
+
perm=perm,
|
|
2611
|
+
**kwargs,
|
|
2612
|
+
)
|
|
2613
|
+
except Exception:
|
|
2614
|
+
pass
|
|
2406
2615
|
# At least one True value for same_input_shape_as_onnx
|
|
2407
2616
|
# At least one True value in nhwc_flags
|
|
2408
2617
|
# same_input_shape_as_onnx == True and nhwc_flags == False and 3D or 4D or 5D tensor is NHWC transposed
|
|
@@ -3642,6 +3851,7 @@ def dummy_onnx_inference(
|
|
|
3642
3851
|
enable_ort_output_memmap: bool = False,
|
|
3643
3852
|
ort_output_memmap_dir: Optional[str] = None,
|
|
3644
3853
|
shape_hints: Optional[List[str]] = None,
|
|
3854
|
+
input_datas_for_validation: Optional[Dict[str, np.ndarray]] = None,
|
|
3645
3855
|
) -> List[np.ndarray]:
|
|
3646
3856
|
"""Perform inference on ONNX subgraphs with an all-1 dummy tensor.
|
|
3647
3857
|
|
|
@@ -3678,6 +3888,9 @@ def dummy_onnx_inference(
|
|
|
3678
3888
|
Directory to store memmap files. If not specified, a temporary
|
|
3679
3889
|
directory is created and removed on exit.
|
|
3680
3890
|
|
|
3891
|
+
input_datas_for_validation: Optional[Dict[str, np.ndarray]]
|
|
3892
|
+
Optional dict to be filled with the input tensors used for inference.
|
|
3893
|
+
|
|
3681
3894
|
Returns
|
|
3682
3895
|
----------
|
|
3683
3896
|
outputs: List[np.ndarray]
|
|
@@ -3873,6 +4086,9 @@ def dummy_onnx_inference(
|
|
|
3873
4086
|
perm=[0,3,1,2],
|
|
3874
4087
|
).numpy().astype(input_dtype)
|
|
3875
4088
|
|
|
4089
|
+
if input_datas_for_validation is not None:
|
|
4090
|
+
input_datas_for_validation.update(input_datas)
|
|
4091
|
+
|
|
3876
4092
|
dtype_sizes = {
|
|
3877
4093
|
np.dtype('float16'): 2,
|
|
3878
4094
|
np.dtype('float32'): 4,
|
|
@@ -4014,6 +4230,7 @@ def dummy_tf_inference(
|
|
|
4014
4230
|
verification_datas: Optional[List[np.ndarray]] = None,
|
|
4015
4231
|
custom_input_op_name_np_data_path: Optional[str] = None,
|
|
4016
4232
|
shape_hints: Optional[List[str]] = None,
|
|
4233
|
+
input_datas_for_validation: Optional[Dict[str, np.ndarray]] = None,
|
|
4017
4234
|
keep_shape_absolutely_input_names: Optional[List[str]] = None,
|
|
4018
4235
|
keep_ncw_or_nchw_or_ncdhw_input_names: Optional[List[str]] = None,
|
|
4019
4236
|
keep_nwc_or_nhwc_or_ndhwc_input_names: Optional[List[str]] = None,
|
|
@@ -4036,6 +4253,8 @@ def dummy_tf_inference(
|
|
|
4036
4253
|
|
|
4037
4254
|
custom_input_op_name_np_data_path
|
|
4038
4255
|
Path to Numpy file for custom data used for dummy inference
|
|
4256
|
+
input_datas_for_validation: Optional[Dict[str, np.ndarray]]
|
|
4257
|
+
Optional dict to be filled with the input tensors used for inference.
|
|
4039
4258
|
|
|
4040
4259
|
Returns
|
|
4041
4260
|
----------
|
|
@@ -4174,6 +4393,10 @@ def dummy_tf_inference(
|
|
|
4174
4393
|
input_size,
|
|
4175
4394
|
dtype=TF_DTYPES_TO_NUMPY_DTYPES[input_dtype],
|
|
4176
4395
|
)
|
|
4396
|
+
|
|
4397
|
+
if input_datas_for_validation is not None:
|
|
4398
|
+
input_datas_for_validation.update(input_datas)
|
|
4399
|
+
|
|
4177
4400
|
outputs = model(
|
|
4178
4401
|
inputs={
|
|
4179
4402
|
input.name: input_datas[input.name] for input in inputs
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnx2tf
|
|
3
|
-
Version: 1.29.12
|
|
3
|
+
Version: 1.29.14
|
|
4
4
|
Summary: Self-Created Tools to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC). The purpose of this tool is to solve the massive Transpose extrapolation problem in onnx-tensorflow (onnx-tf).
|
|
5
5
|
Keywords: onnx,tensorflow,tflite,keras,deep-learning,machine-learning
|
|
6
6
|
Author: Katsuya Hyodo
|
|
@@ -13,6 +13,7 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
13
13
|
Classifier: Operating System :: POSIX :: Linux
|
|
14
14
|
Classifier: Operating System :: Unix
|
|
15
15
|
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
16
17
|
Classifier: Programming Language :: Python :: 3.11
|
|
17
18
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
19
|
Requires-Dist: requests==2.32.5
|
|
@@ -363,7 +364,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
|
|
|
363
364
|
docker run --rm -it \
|
|
364
365
|
-v `pwd`:/workdir \
|
|
365
366
|
-w /workdir \
|
|
366
|
-
ghcr.io/pinto0309/onnx2tf:1.29.12
|
|
367
|
+
ghcr.io/pinto0309/onnx2tf:1.29.14
|
|
367
368
|
|
|
368
369
|
or
|
|
369
370
|
|
|
@@ -371,7 +372,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
|
|
|
371
372
|
docker run --rm -it \
|
|
372
373
|
-v `pwd`:/workdir \
|
|
373
374
|
-w /workdir \
|
|
374
|
-
docker.io/pinto0309/onnx2tf:1.29.12
|
|
375
|
+
docker.io/pinto0309/onnx2tf:1.29.14
|
|
375
376
|
|
|
376
377
|
or
|
|
377
378
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
onnx2tf/__init__.py,sha256=
|
|
1
|
+
onnx2tf/__init__.py,sha256=8dbSscURHL1ncvasA8yz2hU36oshMnPpVm9IcPYu_Vc,67
|
|
2
2
|
onnx2tf/__main__.py,sha256=2RSCQ7d4lc6CwD-rlGn9UicPFg-P5du7ZD_yh-kuBEU,57
|
|
3
|
-
onnx2tf/onnx2tf.py,sha256=
|
|
3
|
+
onnx2tf/onnx2tf.py,sha256=O3B_ME8omswggw4xtjxxnC8_uaPHH3Ly8dwSv7w75no,157060
|
|
4
4
|
onnx2tf/ops/Abs.py,sha256=V7btmCG_ZvK_qJovUsguq0ZMJ349mhNQ4FHSgzP_Yuo,4029
|
|
5
5
|
onnx2tf/ops/Acos.py,sha256=Fo8YkFKuWq8Fi2xUrBdKcAH1yJ8r5pjSD0wgLttTNdk,4003
|
|
6
6
|
onnx2tf/ops/Acosh.py,sha256=ATQj2cT5JS_mTfXi0kXqJ1yzSZu5J0zHA5VjV3j7uKY,3588
|
|
@@ -14,7 +14,7 @@ onnx2tf/ops/Asinh.py,sha256=74ZzTEkpxZY4CGfJT2JJU-SHXYL5KZeUkWY2v7hsMMw,3588
|
|
|
14
14
|
onnx2tf/ops/Atan.py,sha256=D24XDMxEwXFtJheQAr3V3IWOUOc6Q5M0-b_83bmGGMM,3981
|
|
15
15
|
onnx2tf/ops/Atanh.py,sha256=VsUYopBWWPoo4gta1_aqvUL6NrVXuVkGid4SqDqYJ9Q,3588
|
|
16
16
|
onnx2tf/ops/Attention.py,sha256=7TMOdPztVLtNKSzeozvaRxhUFVhACci8wvhn7ONKWrQ,21006
|
|
17
|
-
onnx2tf/ops/AveragePool.py,sha256=
|
|
17
|
+
onnx2tf/ops/AveragePool.py,sha256=3pf-DKS76aU1BR8jafOBbfpzkNWop9cHQSZVQjbecdY,22144
|
|
18
18
|
onnx2tf/ops/BatchNormalization.py,sha256=_hlf2-5-j3MCJHEoE2oMNQ8YhCm7ad9h2fwPpTo3i7g,26624
|
|
19
19
|
onnx2tf/ops/Bernoulli.py,sha256=PM0xS0n1q4bnT_9PnbcKW8_Qj8dJYYBQR8kb2X-wIp4,3670
|
|
20
20
|
onnx2tf/ops/BitShift.py,sha256=a28_E9hwA8yfjvtsrSKCZCeeMPB5RBQbjB3cmaNGN6k,3861
|
|
@@ -51,9 +51,9 @@ onnx2tf/ops/Elu.py,sha256=VDd5cKc1h-8nd0bVwWR_CkgfomrBl4NMbjRtAvkoNks,4025
|
|
|
51
51
|
onnx2tf/ops/Equal.py,sha256=ni0gf7nJex8S-oG61bnHc_xn8LuMits3gM6IzGNT65w,4579
|
|
52
52
|
onnx2tf/ops/Erf.py,sha256=ayvSp8Pr9h-VYuIiMorwOC0r9aQ4i4S1Uvaho9R6PYo,4962
|
|
53
53
|
onnx2tf/ops/Exp.py,sha256=MM_Osse7UbJgld2u0fGMcjniJCs40uDztuOodVUqWMU,3583
|
|
54
|
-
onnx2tf/ops/Expand.py,sha256=
|
|
54
|
+
onnx2tf/ops/Expand.py,sha256=u_LrCaWqb-Pdz2F8yWJUFx-E_SNE888pPmHP4-HGx2M,15339
|
|
55
55
|
onnx2tf/ops/EyeLike.py,sha256=VHRlr_WpIGVpZSqfjN7zWQF6XT2KjNVJnjVccxB4P6U,5877
|
|
56
|
-
onnx2tf/ops/Flatten.py,sha256=
|
|
56
|
+
onnx2tf/ops/Flatten.py,sha256=RZZJF8RnZaUf_jCEdTgLppPa6FoeM7BLxHrIkHv1t5c,10292
|
|
57
57
|
onnx2tf/ops/Floor.py,sha256=8izJrNmw8wNmjF_YabIpLs4jm82J-gKcyAQbwV7Yqpc,3589
|
|
58
58
|
onnx2tf/ops/FusedConv.py,sha256=gslI50V3yvt4l0mmodnyHFAu0cORx1J_ZL5cE0rZ8qs,4523
|
|
59
59
|
onnx2tf/ops/GRU.py,sha256=kBHiZlhlPIV2DQCoFYFHxCTwOATeguJy1MSfj2kxqDM,30732
|
|
@@ -168,7 +168,7 @@ onnx2tf/ops/Sign.py,sha256=rJNyo_YTLO5x4yoF_Z_wpaIX4dSOL-vdmKH0SbVDwJc,3585
|
|
|
168
168
|
onnx2tf/ops/Sin.py,sha256=jrv76uQPIfB7UdLGf42MOlRUPM6fQ3GR6BvSybpptFo,3608
|
|
169
169
|
onnx2tf/ops/Sinh.py,sha256=9zXIQWcZiZmu3RnQuQpW-PEgBLOKY51SY0OBu1B5eh8,3706
|
|
170
170
|
onnx2tf/ops/Size.py,sha256=vFD5eae9Jko3tHbBtydj2d3T3tbb4r0xua7OIH40p9M,2665
|
|
171
|
-
onnx2tf/ops/Slice.py,sha256=
|
|
171
|
+
onnx2tf/ops/Slice.py,sha256=ChqpC_l-c32aZzI7o2GP7SyRz142Gwo0ctc75nkXFvE,26788
|
|
172
172
|
onnx2tf/ops/Softmax.py,sha256=CEnHcSm25v1QC4QVDg4fz1NooYY1v-Uq4GORd8dnnr8,14773
|
|
173
173
|
onnx2tf/ops/Softplus.py,sha256=R44YMo8G2Ig15jBO6T2VOI6RhpUmjD70qvSCXFylU-Q,3605
|
|
174
174
|
onnx2tf/ops/Softsign.py,sha256=2ZdKH3KVHZXDzyO7S8f-O_aqRugurbRxd1i2g_fwCos,3600
|
|
@@ -194,12 +194,12 @@ onnx2tf/ops/Where.py,sha256=MaCcY9g4mKZQqCgh4xtoylicP-xVu9f4boKiu_q9Ow8,7711
|
|
|
194
194
|
onnx2tf/ops/Xor.py,sha256=2ceqxHSI1Wtez_CIh8gFfvcu45Xboqfyp1iy3v2vuIs,4590
|
|
195
195
|
onnx2tf/ops/__init__.py,sha256=jnmUWWa-3dHzBZV9bmPzXu6eoz2dumJTzO7i8JdcgSM,25
|
|
196
196
|
onnx2tf/utils/__init__.py,sha256=E9FM9He68VIASDnYp-OrxvHFVn55GzWqw2OEkCqn1zg,27
|
|
197
|
-
onnx2tf/utils/common_functions.py,sha256=
|
|
197
|
+
onnx2tf/utils/common_functions.py,sha256=j8bRC3RK5NlNAV9vwxj38DwDaaCLR2iprRdDjBgv_RA,260619
|
|
198
198
|
onnx2tf/utils/enums.py,sha256=7c5TqetqB07VjyHoxJHfLgtqBqk9ZRyUF33fPOJR1IM,1649
|
|
199
199
|
onnx2tf/utils/iterative_json_optimizer.py,sha256=qqeIxWGxrhcCYk8-ebWnblnOkzDCwi-nseipHzHR_bk,10436
|
|
200
200
|
onnx2tf/utils/json_auto_generator.py,sha256=OC-SfKtUg7zUxaXTAg6kT0ShzIc3ByjDa3FNp173DtA,60302
|
|
201
201
|
onnx2tf/utils/logging.py,sha256=yUCmPuJ_XiUItM3sZMcaMO24JErkQy7zZwVTYWAuiKg,1982
|
|
202
|
-
onnx2tf-1.29.
|
|
203
|
-
onnx2tf-1.29.
|
|
204
|
-
onnx2tf-1.29.
|
|
205
|
-
onnx2tf-1.29.
|
|
202
|
+
onnx2tf-1.29.14.dist-info/WHEEL,sha256=e_m4S054HL0hyR3CpOk-b7Q7fDX6BuFkgL5OjAExXas,80
|
|
203
|
+
onnx2tf-1.29.14.dist-info/entry_points.txt,sha256=GuhvLu7ZlYECumbmoiFlKX0mFPtFi_Ti9L-E5yuQqKs,42
|
|
204
|
+
onnx2tf-1.29.14.dist-info/METADATA,sha256=TT-jjFuqKAE7Tyt9Crx-og515ebykFCODKhYQ-8T-x0,154244
|
|
205
|
+
onnx2tf-1.29.14.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|