onnx2tf 1.29.19__py3-none-any.whl → 1.29.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +996 -27
- onnx2tf/ops/GatherElements.py +25 -7
- onnx2tf/ops/GatherND.py +28 -1
- onnx2tf/ops/ScatterElements.py +25 -7
- onnx2tf/ops/ScatterND.py +45 -6
- onnx2tf/ops/TensorScatter.py +20 -6
- onnx2tf/utils/common_functions.py +198 -4
- {onnx2tf-1.29.19.dist-info → onnx2tf-1.29.21.dist-info}/METADATA +51 -8
- {onnx2tf-1.29.19.dist-info → onnx2tf-1.29.21.dist-info}/RECORD +12 -12
- {onnx2tf-1.29.19.dist-info → onnx2tf-1.29.21.dist-info}/WHEEL +0 -0
- {onnx2tf-1.29.19.dist-info → onnx2tf-1.29.21.dist-info}/entry_points.txt +0 -0
onnx2tf/ops/GatherElements.py
CHANGED
|
@@ -57,9 +57,10 @@ def make_node(
|
|
|
57
57
|
graph_node.inputs[0],
|
|
58
58
|
before_op_output_shape_trans,
|
|
59
59
|
)
|
|
60
|
+
# Indices must not be layout-transposed.
|
|
60
61
|
graph_node_input_2 = get_constant_or_variable(
|
|
61
62
|
graph_node.inputs[1],
|
|
62
|
-
|
|
63
|
+
False,
|
|
63
64
|
)
|
|
64
65
|
graph_node_output: gs.Variable = graph_node.outputs[0]
|
|
65
66
|
shape = graph_node_output.shape
|
|
@@ -77,12 +78,29 @@ def make_node(
|
|
|
77
78
|
param_name=graph_node.inputs[0].name,
|
|
78
79
|
**kwargs,
|
|
79
80
|
)
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
81
|
+
# If input is transposed by replacement params, align indices tensor shape.
|
|
82
|
+
op_rep_params = kwargs.get('op_rep_params', [])
|
|
83
|
+
params_perm = None
|
|
84
|
+
indices_perm = None
|
|
85
|
+
for op_rep_param in op_rep_params:
|
|
86
|
+
if op_rep_param['param_target'] == 'inputs' \
|
|
87
|
+
and op_rep_param['param_name'] == graph_node.inputs[0].name:
|
|
88
|
+
params_perm = op_rep_param.get('pre_process_transpose_perm', None)
|
|
89
|
+
if op_rep_param['param_target'] == 'inputs' \
|
|
90
|
+
and op_rep_param['param_name'] == graph_node.inputs[1].name:
|
|
91
|
+
indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
|
|
92
|
+
target_perm = indices_perm if indices_perm is not None else params_perm
|
|
93
|
+
if target_perm is not None:
|
|
94
|
+
try:
|
|
95
|
+
rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
|
|
96
|
+
if rank is None or rank == len(target_perm):
|
|
97
|
+
indices_tensor = transpose_with_flexing_deterrence(
|
|
98
|
+
input_tensor=indices_tensor,
|
|
99
|
+
perm=target_perm,
|
|
100
|
+
**kwargs,
|
|
101
|
+
)
|
|
102
|
+
except Exception:
|
|
103
|
+
pass
|
|
86
104
|
|
|
87
105
|
tensor_rank = len(input_tensor.shape)
|
|
88
106
|
|
onnx2tf/ops/GatherND.py
CHANGED
|
@@ -51,9 +51,10 @@ def make_node(
|
|
|
51
51
|
graph_node.inputs[0],
|
|
52
52
|
before_op_output_shape_trans,
|
|
53
53
|
)
|
|
54
|
+
# Indices must not be layout-transposed.
|
|
54
55
|
graph_node_input_2 = get_constant_or_variable(
|
|
55
56
|
graph_node.inputs[1],
|
|
56
|
-
|
|
57
|
+
False,
|
|
57
58
|
)
|
|
58
59
|
graph_node_output: gs.Variable = graph_node.outputs[0]
|
|
59
60
|
shape = graph_node_output.shape
|
|
@@ -89,6 +90,32 @@ def make_node(
|
|
|
89
90
|
|
|
90
91
|
replace_gathernd_to_pseudo_gathernd = "gathernd" in kwargs['replace_to_pseudo_operators']
|
|
91
92
|
|
|
93
|
+
# If params is transposed, adjust indices to match the transposed layout.
|
|
94
|
+
op_rep_params = kwargs.get('op_rep_params', [])
|
|
95
|
+
params_perm = None
|
|
96
|
+
indices_perm_specified = False
|
|
97
|
+
for op_rep_param in op_rep_params:
|
|
98
|
+
if op_rep_param['param_target'] == 'inputs' and op_rep_param['param_name'] == graph_node.inputs[0].name:
|
|
99
|
+
params_perm = op_rep_param.get('pre_process_transpose_perm', None)
|
|
100
|
+
if op_rep_param['param_target'] == 'inputs' and op_rep_param['param_name'] == graph_node.inputs[1].name:
|
|
101
|
+
if op_rep_param.get('pre_process_transpose_perm', None) is not None:
|
|
102
|
+
indices_perm_specified = True
|
|
103
|
+
if params_perm is not None and not indices_perm_specified:
|
|
104
|
+
# Only handle standard layout swaps that keep batch dims at the front.
|
|
105
|
+
if batch_dims <= len(params_perm) \
|
|
106
|
+
and list(params_perm[:batch_dims]) == list(range(batch_dims)):
|
|
107
|
+
perm_tail = [p - batch_dims for p in params_perm if p >= batch_dims]
|
|
108
|
+
try:
|
|
109
|
+
if isinstance(indices_tensor, np.ndarray):
|
|
110
|
+
if indices_tensor.shape and indices_tensor.shape[-1] == len(perm_tail):
|
|
111
|
+
indices_tensor = indices_tensor[..., perm_tail]
|
|
112
|
+
else:
|
|
113
|
+
idx_last = indices_tensor.shape[-1] if indices_tensor.shape is not None else None
|
|
114
|
+
if idx_last is None or idx_last == len(perm_tail):
|
|
115
|
+
indices_tensor = tf.gather(indices_tensor, perm_tail, axis=-1)
|
|
116
|
+
except Exception:
|
|
117
|
+
pass
|
|
118
|
+
|
|
92
119
|
# Preserving Graph Structure (Dict)
|
|
93
120
|
tf_layers_dict[graph_node_output.name] = {
|
|
94
121
|
'optype': graph_node.op,
|
onnx2tf/ops/ScatterElements.py
CHANGED
|
@@ -55,9 +55,10 @@ def make_node(
|
|
|
55
55
|
graph_node.inputs[0],
|
|
56
56
|
before_op_output_shape_trans,
|
|
57
57
|
)
|
|
58
|
+
# Indices must not be layout-transposed.
|
|
58
59
|
graph_node_input_2 = get_constant_or_variable(
|
|
59
60
|
graph_node.inputs[1],
|
|
60
|
-
|
|
61
|
+
False,
|
|
61
62
|
)
|
|
62
63
|
graph_node_input_3 = get_constant_or_variable(
|
|
63
64
|
graph_node.inputs[2],
|
|
@@ -81,12 +82,29 @@ def make_node(
|
|
|
81
82
|
indices_tensor = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
|
|
82
83
|
if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
|
|
83
84
|
# Pre-process transpose
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
85
|
+
# If input is transposed by replacement params, align indices tensor shape.
|
|
86
|
+
op_rep_params = kwargs.get('op_rep_params', [])
|
|
87
|
+
params_perm = None
|
|
88
|
+
indices_perm = None
|
|
89
|
+
for op_rep_param in op_rep_params:
|
|
90
|
+
if op_rep_param['param_target'] == 'inputs' \
|
|
91
|
+
and op_rep_param['param_name'] == graph_node.inputs[0].name:
|
|
92
|
+
params_perm = op_rep_param.get('pre_process_transpose_perm', None)
|
|
93
|
+
if op_rep_param['param_target'] == 'inputs' \
|
|
94
|
+
and op_rep_param['param_name'] == graph_node.inputs[1].name:
|
|
95
|
+
indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
|
|
96
|
+
target_perm = indices_perm if indices_perm is not None else params_perm
|
|
97
|
+
if target_perm is not None:
|
|
98
|
+
try:
|
|
99
|
+
rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
|
|
100
|
+
if rank is None or rank == len(target_perm):
|
|
101
|
+
indices_tensor = transpose_with_flexing_deterrence(
|
|
102
|
+
input_tensor=indices_tensor,
|
|
103
|
+
perm=target_perm,
|
|
104
|
+
**kwargs,
|
|
105
|
+
)
|
|
106
|
+
except Exception:
|
|
107
|
+
pass
|
|
90
108
|
updates_tensor = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
|
|
91
109
|
if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
|
|
92
110
|
# Pre-process transpose
|
onnx2tf/ops/ScatterND.py
CHANGED
|
@@ -13,6 +13,7 @@ from onnx2tf.utils.common_functions import (
|
|
|
13
13
|
get_replacement_parameter,
|
|
14
14
|
pre_process_transpose,
|
|
15
15
|
post_process_transpose,
|
|
16
|
+
transpose_with_flexing_deterrence,
|
|
16
17
|
)
|
|
17
18
|
|
|
18
19
|
|
|
@@ -79,6 +80,32 @@ def make_node(
|
|
|
79
80
|
and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
|
|
80
81
|
}
|
|
81
82
|
|
|
83
|
+
op_rep_params = kwargs.get('op_rep_params', [])
|
|
84
|
+
params_perm = None
|
|
85
|
+
indices_perm = None
|
|
86
|
+
for op_rep_param in op_rep_params:
|
|
87
|
+
if op_rep_param['param_target'] == 'inputs' \
|
|
88
|
+
and op_rep_param['param_name'] == graph_node.inputs[0].name:
|
|
89
|
+
params_perm = op_rep_param.get('pre_process_transpose_perm', None)
|
|
90
|
+
if op_rep_param['param_target'] == 'inputs' \
|
|
91
|
+
and op_rep_param['param_name'] == graph_node.inputs[1].name:
|
|
92
|
+
indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
|
|
93
|
+
|
|
94
|
+
def reorder_indices_last_dim(target_indices, perm):
|
|
95
|
+
if perm is None:
|
|
96
|
+
return target_indices
|
|
97
|
+
try:
|
|
98
|
+
if isinstance(target_indices, np.ndarray):
|
|
99
|
+
if target_indices.shape and target_indices.shape[-1] == len(perm):
|
|
100
|
+
return target_indices[..., perm]
|
|
101
|
+
else:
|
|
102
|
+
idx_last = target_indices.shape[-1] if target_indices.shape is not None else None
|
|
103
|
+
if idx_last is None or idx_last == len(perm):
|
|
104
|
+
return tf.gather(target_indices, perm, axis=-1)
|
|
105
|
+
except Exception:
|
|
106
|
+
pass
|
|
107
|
+
return target_indices
|
|
108
|
+
|
|
82
109
|
# Pre-process transpose
|
|
83
110
|
input_tensor = pre_process_transpose(
|
|
84
111
|
value_before_transpose=input_tensor,
|
|
@@ -86,18 +113,26 @@ def make_node(
|
|
|
86
113
|
param_name=graph_node.inputs[0].name,
|
|
87
114
|
**kwargs,
|
|
88
115
|
)
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
116
|
+
# Indices must not be layout-transposed; apply explicit perm only if specified.
|
|
117
|
+
if indices_perm is not None:
|
|
118
|
+
try:
|
|
119
|
+
rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
|
|
120
|
+
if rank is None or rank == len(indices_perm):
|
|
121
|
+
indices_tensor = transpose_with_flexing_deterrence(
|
|
122
|
+
input_tensor=indices_tensor,
|
|
123
|
+
perm=indices_perm,
|
|
124
|
+
**kwargs,
|
|
125
|
+
)
|
|
126
|
+
except Exception:
|
|
127
|
+
pass
|
|
95
128
|
updates_tensor = pre_process_transpose(
|
|
96
129
|
value_before_transpose=updates_tensor,
|
|
97
130
|
param_target='inputs',
|
|
98
131
|
param_name=graph_node.inputs[2].name,
|
|
99
132
|
**kwargs,
|
|
100
133
|
)
|
|
134
|
+
if params_perm is not None and indices_perm is None:
|
|
135
|
+
indices_tensor = reorder_indices_last_dim(indices_tensor, params_perm)
|
|
101
136
|
|
|
102
137
|
# When NHWC is fixed, return to NCHW format before processing.
|
|
103
138
|
data_nhwc = tf_layers_dict[graph_node_input_1.name]['nhwc'] \
|
|
@@ -119,6 +154,8 @@ def make_node(
|
|
|
119
154
|
and len(input_tensor.shape) >= 3:
|
|
120
155
|
perm = [0, len(input_tensor.shape)-1] + [i for i in range(1, len(input_tensor.shape)-1)]
|
|
121
156
|
input_tensor = tf.transpose(a=input_tensor, perm=perm)
|
|
157
|
+
if indices_perm is None:
|
|
158
|
+
indices_tensor = reorder_indices_last_dim(indices_tensor, perm)
|
|
122
159
|
nchw = True
|
|
123
160
|
elif not data_nhwc \
|
|
124
161
|
and len(input_tensor.shape) >= 3 \
|
|
@@ -126,6 +163,8 @@ def make_node(
|
|
|
126
163
|
and input_tensor.shape != graph_node.inputs[0].shape:
|
|
127
164
|
perm = [0, len(input_tensor.shape)-1] + [i for i in range(1, len(input_tensor.shape)-1)]
|
|
128
165
|
input_tensor = tf.transpose(a=input_tensor, perm=perm)
|
|
166
|
+
if indices_perm is None:
|
|
167
|
+
indices_tensor = reorder_indices_last_dim(indices_tensor, perm)
|
|
129
168
|
nchw = True
|
|
130
169
|
## indices
|
|
131
170
|
if indices_nhwc \
|
onnx2tf/ops/TensorScatter.py
CHANGED
|
@@ -14,6 +14,7 @@ from onnx2tf.utils.common_functions import (
|
|
|
14
14
|
get_replacement_parameter,
|
|
15
15
|
pre_process_transpose,
|
|
16
16
|
post_process_transpose,
|
|
17
|
+
transpose_with_flexing_deterrence,
|
|
17
18
|
)
|
|
18
19
|
from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
|
|
19
20
|
from onnx2tf.utils.logging import *
|
|
@@ -112,12 +113,25 @@ def make_node(
|
|
|
112
113
|
**kwargs,
|
|
113
114
|
)
|
|
114
115
|
if write_indices is not None:
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
116
|
+
# Indices must not be layout-transposed; apply explicit perm only if specified.
|
|
117
|
+
op_rep_params = kwargs.get('op_rep_params', [])
|
|
118
|
+
indices_perm = None
|
|
119
|
+
for op_rep_param in op_rep_params:
|
|
120
|
+
if op_rep_param['param_target'] == 'inputs' \
|
|
121
|
+
and op_rep_param['param_name'] == graph_node.inputs[2].name:
|
|
122
|
+
indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
|
|
123
|
+
break
|
|
124
|
+
if indices_perm is not None:
|
|
125
|
+
try:
|
|
126
|
+
rank = len(write_indices.shape) if hasattr(write_indices, "shape") else None
|
|
127
|
+
if rank is None or rank == len(indices_perm):
|
|
128
|
+
write_indices = transpose_with_flexing_deterrence(
|
|
129
|
+
input_tensor=write_indices,
|
|
130
|
+
perm=indices_perm,
|
|
131
|
+
**kwargs,
|
|
132
|
+
)
|
|
133
|
+
except Exception:
|
|
134
|
+
pass
|
|
121
135
|
|
|
122
136
|
# Generation of TF OP
|
|
123
137
|
past_cache = _as_tensor(past_cache)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import math
|
|
2
|
+
import ast
|
|
2
3
|
import os
|
|
3
4
|
import io
|
|
4
5
|
import re
|
|
@@ -26,6 +27,7 @@ from tensorflow.python.keras.layers import Lambda
|
|
|
26
27
|
from tensorflow.python.keras.utils import conv_utils
|
|
27
28
|
import onnx
|
|
28
29
|
from onnx.serialization import ProtoSerializer
|
|
30
|
+
from onnx.external_data_helper import uses_external_data
|
|
29
31
|
import onnx_graphsurgeon as gs
|
|
30
32
|
try:
|
|
31
33
|
import onnxruntime as ort
|
|
@@ -44,6 +46,61 @@ from onnx2tf.utils.enums import (
|
|
|
44
46
|
INF_INDEX_VALUE: int = 4294967296
|
|
45
47
|
ONNX_INF_INDEX_VALUE = sys.maxsize # 9223372036854775807
|
|
46
48
|
|
|
49
|
+
# Module-wide fallbacks for dummy inference. dummy_onnx_inference and
# dummy_tf_inference read these whenever their own ``shape_hints`` /
# ``value_hints`` argument is None. They are assigned via the setters below.
_DEFAULT_DUMMY_SHAPE_HINTS: Optional[List[str]] = None
_DEFAULT_DUMMY_VALUE_HINTS: Optional[List[str]] = None


def set_dummy_shape_hints(shape_hints: Optional[List[str]]) -> None:
    """Install ``shape_hints`` as the module-level default shape hints.

    The given list object is stored as-is (no copy is made); passing
    ``None`` clears the default again.
    """
    global _DEFAULT_DUMMY_SHAPE_HINTS
    _DEFAULT_DUMMY_SHAPE_HINTS = shape_hints


def set_dummy_value_hints(value_hints: Optional[List[str]]) -> None:
    """Install ``value_hints`` as the module-level default value hints.

    Entries use the ``"input_name:value"`` / ``"*:default_value"`` format.
    The list object is stored as-is; passing ``None`` clears the default.
    """
    global _DEFAULT_DUMMY_VALUE_HINTS
    _DEFAULT_DUMMY_VALUE_HINTS = value_hints
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _parse_value_hint_scalar(value: str) -> Optional[Any]:
    """Parse the value part of a ``--value_hints`` entry into a scalar.

    ``ast.literal_eval`` is tried first so that literals keep their type
    (``"3"`` -> ``3``, ``"2.5"`` -> ``2.5``, ``"True"`` -> ``True``). If that
    fails, ``float()`` is tried so spellings such as ``"nan"`` / ``"inf"``
    (which ``literal_eval`` rejects as bare names) are still accepted.

    Args:
        value: Text following the ``name:`` prefix of a hint.

    Returns:
        An ``int`` / ``float`` / ``bool`` / numpy scalar, or ``None`` when the
        text is not a single numeric or boolean literal (containers, strings
        and malformed numbers all yield ``None``).
    """
    try:
        parsed = ast.literal_eval(value)
    except Exception:
        # literal_eval can raise several exception types on arbitrary
        # user text (ValueError, SyntaxError, MemoryError, ...).
        try:
            parsed = float(value)
        except (TypeError, ValueError):
            return None
    # Whitelist scalar types only; lists/tuples/dicts/sets/strings etc.
    # are rejected.
    if isinstance(parsed, (int, float, bool, np.number)):
        return parsed
    return None


def _parse_value_hints(
    value_hints: Optional[List[str]]
) -> Tuple[Dict[str, Any], Optional[Any], bool]:
    """Parse ``--value_hints`` entries of the form ``"input_name:value"``.

    The input name and value are separated at the *last* ``':'``. Scalar
    values never contain a colon, so ONNX input names that do contain colons
    (e.g. ``"serving_default_x:0"``) are parsed correctly. The name ``"*"``
    sets a default value for every input that has no explicit hint.

    Non-string entries and entries without a colon are skipped silently.
    Entries whose value is not a numeric/boolean scalar are skipped with a
    warning.

    Args:
        value_hints: Raw hint strings, or ``None``/empty for no hints.

    Returns:
        A tuple ``(per_input_hints, default_value, has_default)``.
    """
    if not value_hints:
        return {}, None, False
    hints: Dict[str, Any] = {}
    default_value: Optional[Any] = None
    for hint in value_hints:
        if not isinstance(hint, str):
            continue
        # Split at the last ':' so names like "x:0" survive intact.
        input_name, sep, value_str = hint.rpartition(':')
        if not sep:
            continue
        parsed_value = _parse_value_hint_scalar(value_str)
        if parsed_value is None:
            warn(f'Invalid --value_hints entry ignored: {hint}')
            continue
        if input_name == '*':
            default_value = parsed_value
        else:
            hints[input_name] = parsed_value
    return hints, default_value, default_value is not None
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
|
|
47
104
|
|
|
48
105
|
def get_replacement_parameter(func):
|
|
49
106
|
@wraps(func)
|
|
@@ -3851,6 +3908,7 @@ def dummy_onnx_inference(
|
|
|
3851
3908
|
enable_ort_output_memmap: bool = False,
|
|
3852
3909
|
ort_output_memmap_dir: Optional[str] = None,
|
|
3853
3910
|
shape_hints: Optional[List[str]] = None,
|
|
3911
|
+
value_hints: Optional[List[str]] = None,
|
|
3854
3912
|
input_datas_for_validation: Optional[Dict[str, np.ndarray]] = None,
|
|
3855
3913
|
) -> List[np.ndarray]:
|
|
3856
3914
|
"""Perform inference on ONNX subgraphs with an all-1 dummy tensor.
|
|
@@ -3888,6 +3946,10 @@ def dummy_onnx_inference(
|
|
|
3888
3946
|
Directory to store memmap files. If not specified, a temporary
|
|
3889
3947
|
directory is created and removed on exit.
|
|
3890
3948
|
|
|
3949
|
+
value_hints: Optional[List[str]]
|
|
3950
|
+
Value hints for dummy inference input tensors.
|
|
3951
|
+
Format: ["input_name:value", "*:default_value"].
|
|
3952
|
+
|
|
3891
3953
|
input_datas_for_validation: Optional[Dict[str, np.ndarray]]
|
|
3892
3954
|
Optional dict to be filled with the input tensors used for inference.
|
|
3893
3955
|
|
|
@@ -3896,6 +3958,11 @@ def dummy_onnx_inference(
|
|
|
3896
3958
|
outputs: List[np.ndarray]
|
|
3897
3959
|
Results of inference using dummy tensor
|
|
3898
3960
|
"""
|
|
3961
|
+
if shape_hints is None:
|
|
3962
|
+
shape_hints = _DEFAULT_DUMMY_SHAPE_HINTS
|
|
3963
|
+
if value_hints is None:
|
|
3964
|
+
value_hints = _DEFAULT_DUMMY_VALUE_HINTS
|
|
3965
|
+
|
|
3899
3966
|
# Separate onnx at specified output_names position
|
|
3900
3967
|
domain: str = onnx_graph.domain
|
|
3901
3968
|
ir_version: int = onnx_graph.ir_version
|
|
@@ -4046,7 +4113,11 @@ def dummy_onnx_inference(
|
|
|
4046
4113
|
input_sizes[i] = updated_shape
|
|
4047
4114
|
|
|
4048
4115
|
input_dtypes: List[Any] = [inp.dtype for inp in onnx_inputs]
|
|
4116
|
+
input_size_map = {
|
|
4117
|
+
name: tuple(size) for name, size in zip(input_names, input_sizes)
|
|
4118
|
+
}
|
|
4049
4119
|
input_datas = {}
|
|
4120
|
+
value_hints_dict, default_value, has_default = _parse_value_hints(value_hints)
|
|
4050
4121
|
|
|
4051
4122
|
# -cid
|
|
4052
4123
|
if custom_input_op_name_np_data_path:
|
|
@@ -4059,7 +4130,16 @@ def dummy_onnx_inference(
|
|
|
4059
4130
|
if input_op_info is not None:
|
|
4060
4131
|
ncw_nchw_ncdhw_perm: List = input_op_info.get('ncw_nchw_ncdhw_perm', None)
|
|
4061
4132
|
if ncw_nchw_ncdhw_perm is not None:
|
|
4062
|
-
|
|
4133
|
+
expected_shape = input_size_map.get(
|
|
4134
|
+
input_op_name,
|
|
4135
|
+
tuple(custom_input_data.shape),
|
|
4136
|
+
)
|
|
4137
|
+
if tuple(custom_input_data.shape) != expected_shape:
|
|
4138
|
+
permuted_shape = tuple(
|
|
4139
|
+
custom_input_data.shape[i] for i in ncw_nchw_ncdhw_perm
|
|
4140
|
+
)
|
|
4141
|
+
if permuted_shape == expected_shape:
|
|
4142
|
+
custom_input_data = custom_input_data.transpose(ncw_nchw_ncdhw_perm)
|
|
4063
4143
|
onnx_batch_size = input_op_info['shape'][0]
|
|
4064
4144
|
cdata_batch_size = custom_input_data.shape[0]
|
|
4065
4145
|
if isinstance(onnx_batch_size, int) and onnx_batch_size != cdata_batch_size and cdata_batch_size > 1:
|
|
@@ -4071,7 +4151,17 @@ def dummy_onnx_inference(
|
|
|
4071
4151
|
|
|
4072
4152
|
else:
|
|
4073
4153
|
for input_name, input_size, input_dtype in zip(input_names, input_sizes, input_dtypes):
|
|
4074
|
-
|
|
4154
|
+
hint_value = value_hints_dict.get(
|
|
4155
|
+
input_name,
|
|
4156
|
+
default_value if has_default else None,
|
|
4157
|
+
)
|
|
4158
|
+
if hint_value is not None:
|
|
4159
|
+
input_datas[input_name] = np.full(
|
|
4160
|
+
input_size,
|
|
4161
|
+
hint_value,
|
|
4162
|
+
dtype=input_dtype,
|
|
4163
|
+
)
|
|
4164
|
+
elif test_data_nhwc is None:
|
|
4075
4165
|
input_datas[input_name] = np.ones(
|
|
4076
4166
|
input_size,
|
|
4077
4167
|
dtype=input_dtype,
|
|
@@ -4230,7 +4320,9 @@ def dummy_tf_inference(
|
|
|
4230
4320
|
verification_datas: Optional[List[np.ndarray]] = None,
|
|
4231
4321
|
custom_input_op_name_np_data_path: Optional[str] = None,
|
|
4232
4322
|
shape_hints: Optional[List[str]] = None,
|
|
4323
|
+
value_hints: Optional[List[str]] = None,
|
|
4233
4324
|
input_datas_for_validation: Optional[Dict[str, np.ndarray]] = None,
|
|
4325
|
+
prefilled_input_datas: Optional[Dict[str, np.ndarray]] = None,
|
|
4234
4326
|
keep_shape_absolutely_input_names: Optional[List[str]] = None,
|
|
4235
4327
|
keep_ncw_or_nchw_or_ncdhw_input_names: Optional[List[str]] = None,
|
|
4236
4328
|
keep_nwc_or_nhwc_or_ndhwc_input_names: Optional[List[str]] = None,
|
|
@@ -4253,6 +4345,11 @@ def dummy_tf_inference(
|
|
|
4253
4345
|
|
|
4254
4346
|
custom_input_op_name_np_data_path
|
|
4255
4347
|
Path to Numpy file for custom data used for dummy inference
|
|
4348
|
+
|
|
4349
|
+
value_hints: Optional[List[str]]
|
|
4350
|
+
Value hints for dummy inference input tensors.
|
|
4351
|
+
Format: ["input_name:value", "*:default_value"].
|
|
4352
|
+
|
|
4256
4353
|
input_datas_for_validation: Optional[Dict[str, np.ndarray]]
|
|
4257
4354
|
Optional dict to be filled with the input tensors used for inference.
|
|
4258
4355
|
|
|
@@ -4262,8 +4359,15 @@ def dummy_tf_inference(
|
|
|
4262
4359
|
Results of inference using dummy tensor.
|
|
4263
4360
|
Dict of tensorflow node and corresponding ndarray output.
|
|
4264
4361
|
"""
|
|
4362
|
+
if shape_hints is None:
|
|
4363
|
+
shape_hints = _DEFAULT_DUMMY_SHAPE_HINTS
|
|
4364
|
+
if value_hints is None:
|
|
4365
|
+
value_hints = _DEFAULT_DUMMY_VALUE_HINTS
|
|
4366
|
+
|
|
4265
4367
|
input_names: List[str] = [inp.name for inp in inputs]
|
|
4266
4368
|
input_sizes: List[int] = [inp.shape for inp in inputs]
|
|
4369
|
+
input_size_map = {name: size for name, size in zip(input_names, input_sizes)}
|
|
4370
|
+
input_index_map = {name: i for i, name in enumerate(input_names)}
|
|
4267
4371
|
|
|
4268
4372
|
if shape_hints is None:
|
|
4269
4373
|
new_input_sizes = []
|
|
@@ -4335,13 +4439,24 @@ def dummy_tf_inference(
|
|
|
4335
4439
|
|
|
4336
4440
|
input_dtypes: List[Any] = [inp.dtype for inp in inputs]
|
|
4337
4441
|
input_datas = {}
|
|
4442
|
+
value_hints_dict, default_value, has_default = _parse_value_hints(value_hints)
|
|
4338
4443
|
|
|
4339
4444
|
# -cid
|
|
4340
4445
|
if custom_input_op_name_np_data_path:
|
|
4341
|
-
for
|
|
4446
|
+
for param in custom_input_op_name_np_data_path:
|
|
4447
|
+
if len(param) < 2:
|
|
4448
|
+
continue
|
|
4449
|
+
input_name = str(param[0])
|
|
4342
4450
|
numpy_file_path = str(param[1])
|
|
4451
|
+
if input_name not in input_index_map:
|
|
4452
|
+
continue
|
|
4453
|
+
idx = input_index_map[input_name]
|
|
4454
|
+
tf_input_name = input_names[idx]
|
|
4455
|
+
if prefilled_input_datas and tf_input_name in prefilled_input_datas:
|
|
4456
|
+
continue
|
|
4343
4457
|
custom_input_data = np.load(numpy_file_path)
|
|
4344
4458
|
input_size = input_sizes[idx]
|
|
4459
|
+
input_dtype = input_dtypes[idx] if idx < len(input_dtypes) else np.float32
|
|
4345
4460
|
|
|
4346
4461
|
tf_batch_size = input_size[0]
|
|
4347
4462
|
cdata_batch_size = custom_input_data.shape[0]
|
|
@@ -4351,6 +4466,24 @@ def dummy_tf_inference(
|
|
|
4351
4466
|
custom_input_data = custom_input_data[0:1, ...]
|
|
4352
4467
|
|
|
4353
4468
|
if list(custom_input_data.shape) != input_size:
|
|
4469
|
+
auto_split_input = (
|
|
4470
|
+
'onnx2tf_split_' in numpy_file_path
|
|
4471
|
+
or os.path.basename(numpy_file_path).startswith('part_')
|
|
4472
|
+
)
|
|
4473
|
+
if auto_split_input:
|
|
4474
|
+
warn(
|
|
4475
|
+
'Auto-split custom input shape does not match TF input shape. '
|
|
4476
|
+
f'input_name={input_name} '
|
|
4477
|
+
f'tf_shape={input_size} '
|
|
4478
|
+
f'numpy_shape={list(custom_input_data.shape)} '
|
|
4479
|
+
f'path={numpy_file_path} '
|
|
4480
|
+
'Fallback to dummy input for this tensor.'
|
|
4481
|
+
)
|
|
4482
|
+
input_datas[input_names[idx]] = np.ones(
|
|
4483
|
+
input_size,
|
|
4484
|
+
dtype=TF_DTYPES_TO_NUMPY_DTYPES[input_dtype],
|
|
4485
|
+
)
|
|
4486
|
+
continue
|
|
4354
4487
|
error_msg = f'' + \
|
|
4355
4488
|
Color.RED(f'ERROR:') + ' ' + \
|
|
4356
4489
|
f"The format of custom input data is different from Tensorflow's format. " + \
|
|
@@ -4363,7 +4496,17 @@ def dummy_tf_inference(
|
|
|
4363
4496
|
else:
|
|
4364
4497
|
if verification_datas is None:
|
|
4365
4498
|
for input_name, input_size, input_dtype in zip(input_names, input_sizes, input_dtypes):
|
|
4366
|
-
|
|
4499
|
+
hint_value = value_hints_dict.get(
|
|
4500
|
+
input_name,
|
|
4501
|
+
default_value if has_default else None,
|
|
4502
|
+
)
|
|
4503
|
+
if hint_value is not None:
|
|
4504
|
+
input_datas[input_name] = np.full(
|
|
4505
|
+
input_size,
|
|
4506
|
+
hint_value,
|
|
4507
|
+
dtype=TF_DTYPES_TO_NUMPY_DTYPES[input_dtype],
|
|
4508
|
+
)
|
|
4509
|
+
elif test_data_nhwc is None:
|
|
4367
4510
|
input_datas[input_name] = np.ones(
|
|
4368
4511
|
input_size,
|
|
4369
4512
|
dtype=TF_DTYPES_TO_NUMPY_DTYPES[input_dtype],
|
|
@@ -4397,6 +4540,33 @@ def dummy_tf_inference(
|
|
|
4397
4540
|
if input_datas_for_validation is not None:
|
|
4398
4541
|
input_datas_for_validation.update(input_datas)
|
|
4399
4542
|
|
|
4543
|
+
if prefilled_input_datas:
|
|
4544
|
+
for input_name, input_data in prefilled_input_datas.items():
|
|
4545
|
+
expected = None
|
|
4546
|
+
if input_name in input_datas:
|
|
4547
|
+
expected = input_datas[input_name].shape
|
|
4548
|
+
elif input_name in input_size_map:
|
|
4549
|
+
expected = input_size_map[input_name]
|
|
4550
|
+
else:
|
|
4551
|
+
continue
|
|
4552
|
+
data = input_data
|
|
4553
|
+
try:
|
|
4554
|
+
if expected is not None and tuple(data.shape) != tuple(expected):
|
|
4555
|
+
if data.size == np.prod(expected):
|
|
4556
|
+
data = data.reshape(expected)
|
|
4557
|
+
else:
|
|
4558
|
+
continue
|
|
4559
|
+
target_dtype = None
|
|
4560
|
+
if input_name in input_datas:
|
|
4561
|
+
target_dtype = input_datas[input_name].dtype
|
|
4562
|
+
elif input_name in input_index_map:
|
|
4563
|
+
target_dtype = input_dtypes[input_index_map[input_name]]
|
|
4564
|
+
if target_dtype is not None and data.dtype != target_dtype:
|
|
4565
|
+
data = data.astype(target_dtype)
|
|
4566
|
+
input_datas[input_name] = data
|
|
4567
|
+
except Exception:
|
|
4568
|
+
continue
|
|
4569
|
+
|
|
4400
4570
|
outputs = model(
|
|
4401
4571
|
inputs={
|
|
4402
4572
|
input.name: input_datas[input.name] for input in inputs
|
|
@@ -6057,6 +6227,8 @@ def acquisition_of_validation_data(
|
|
|
6057
6227
|
kwargs['test_data_nhwc']
|
|
6058
6228
|
custom_input_op_name_np_data_path: str = \
|
|
6059
6229
|
kwargs['custom_input_op_name_np_data_path']
|
|
6230
|
+
tf_input_cache: Optional[Dict[str, np.ndarray]] = \
|
|
6231
|
+
kwargs.get('tf_input_cache', None)
|
|
6060
6232
|
|
|
6061
6233
|
# Get the output tensor of one previous OP of TensorFlow only once
|
|
6062
6234
|
tf_model_inputs = get_tf_model_inputs(
|
|
@@ -6108,6 +6280,7 @@ def acquisition_of_validation_data(
|
|
|
6108
6280
|
inputs=tf_model_inputs,
|
|
6109
6281
|
test_data_nhwc=test_data_nhwc,
|
|
6110
6282
|
custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
|
|
6283
|
+
prefilled_input_datas=tf_input_cache,
|
|
6111
6284
|
)
|
|
6112
6285
|
except Exception as ex:
|
|
6113
6286
|
pass
|
|
@@ -6526,3 +6699,24 @@ def define_reduceXXX(
|
|
|
6526
6699
|
keepdims=target_keepdims,
|
|
6527
6700
|
)
|
|
6528
6701
|
return reduced_tensor
|
|
6702
|
+
|
|
6703
|
+
def check_has_external_data(input_onnx_file_path: str) -> bool:
    """Return True if any tensor of the ONNX model is stored as external data.

    The model is loaded with ``load_external_data=False``, so only the
    protobuf itself is read. Tensors are collected from:

    * graph initializers,
    * sparse initializers (their ``values`` / ``indices`` TensorProtos),
    * tensor / sparse-tensor valued node attributes,
    * graph-valued attributes (subgraphs), recursively,
    * node bodies of model-local functions (``model.functions``).

    Args:
        input_onnx_file_path: Path to the ``.onnx`` file.

    Returns:
        True when ``uses_external_data`` reports True for at least one tensor.
    """
    model = onnx.load(input_onnx_file_path, load_external_data=False)

    def iter_tensors_in_sparse(st):
        # SparseTensorProto has no ``data_location`` field itself, so
        # uses_external_data() must be applied to its dense TensorProto parts.
        yield st.values
        yield st.indices

    def iter_tensors_in_nodes(nodes):
        for n in nodes:
            for a in n.attribute:
                if a.type == onnx.AttributeProto.TENSOR:
                    yield a.t
                elif a.type == onnx.AttributeProto.TENSORS:
                    yield from a.tensors
                elif a.type == onnx.AttributeProto.SPARSE_TENSOR:
                    yield from iter_tensors_in_sparse(a.sparse_tensor)
                elif a.type == onnx.AttributeProto.SPARSE_TENSORS:
                    for st in a.sparse_tensors:
                        yield from iter_tensors_in_sparse(st)
                elif a.type == onnx.AttributeProto.GRAPH:
                    yield from iter_tensors_in_graph(a.g)
                elif a.type == onnx.AttributeProto.GRAPHS:
                    for sg in a.graphs:
                        yield from iter_tensors_in_graph(sg)

    def iter_tensors_in_graph(g):
        yield from g.initializer
        for st in g.sparse_initializer:
            yield from iter_tensors_in_sparse(st)
        yield from iter_tensors_in_nodes(g.node)

    def iter_all_tensors():
        yield from iter_tensors_in_graph(model.graph)
        # Model-local functions (IR >= 8) can carry Constant-like tensor
        # attributes too; getattr keeps older protobufs working.
        for f in getattr(model, 'functions', []):
            yield from iter_tensors_in_nodes(f.node)

    return any(uses_external_data(t) for t in iter_all_tensors())
|