onnx2tf 1.29.19__py3-none-any.whl → 1.29.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -57,9 +57,10 @@ def make_node(
57
57
  graph_node.inputs[0],
58
58
  before_op_output_shape_trans,
59
59
  )
60
+ # Indices must not be layout-transposed.
60
61
  graph_node_input_2 = get_constant_or_variable(
61
62
  graph_node.inputs[1],
62
- before_op_output_shape_trans,
63
+ False,
63
64
  )
64
65
  graph_node_output: gs.Variable = graph_node.outputs[0]
65
66
  shape = graph_node_output.shape
@@ -77,12 +78,29 @@ def make_node(
77
78
  param_name=graph_node.inputs[0].name,
78
79
  **kwargs,
79
80
  )
80
- indices_tensor = pre_process_transpose(
81
- value_before_transpose=indices_tensor,
82
- param_target='inputs',
83
- param_name=graph_node.inputs[1].name,
84
- **kwargs,
85
- )
81
+ # If input is transposed by replacement params, align indices tensor shape.
82
+ op_rep_params = kwargs.get('op_rep_params', [])
83
+ params_perm = None
84
+ indices_perm = None
85
+ for op_rep_param in op_rep_params:
86
+ if op_rep_param['param_target'] == 'inputs' \
87
+ and op_rep_param['param_name'] == graph_node.inputs[0].name:
88
+ params_perm = op_rep_param.get('pre_process_transpose_perm', None)
89
+ if op_rep_param['param_target'] == 'inputs' \
90
+ and op_rep_param['param_name'] == graph_node.inputs[1].name:
91
+ indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
92
+ target_perm = indices_perm if indices_perm is not None else params_perm
93
+ if target_perm is not None:
94
+ try:
95
+ rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
96
+ if rank is None or rank == len(target_perm):
97
+ indices_tensor = transpose_with_flexing_deterrence(
98
+ input_tensor=indices_tensor,
99
+ perm=target_perm,
100
+ **kwargs,
101
+ )
102
+ except Exception:
103
+ pass
86
104
 
87
105
  tensor_rank = len(input_tensor.shape)
88
106
 
onnx2tf/ops/GatherND.py CHANGED
@@ -51,9 +51,10 @@ def make_node(
51
51
  graph_node.inputs[0],
52
52
  before_op_output_shape_trans,
53
53
  )
54
+ # Indices must not be layout-transposed.
54
55
  graph_node_input_2 = get_constant_or_variable(
55
56
  graph_node.inputs[1],
56
- before_op_output_shape_trans,
57
+ False,
57
58
  )
58
59
  graph_node_output: gs.Variable = graph_node.outputs[0]
59
60
  shape = graph_node_output.shape
@@ -89,6 +90,32 @@ def make_node(
89
90
 
90
91
  replace_gathernd_to_pseudo_gathernd = "gathernd" in kwargs['replace_to_pseudo_operators']
91
92
 
93
+ # If params is transposed, adjust indices to match the transposed layout.
94
+ op_rep_params = kwargs.get('op_rep_params', [])
95
+ params_perm = None
96
+ indices_perm_specified = False
97
+ for op_rep_param in op_rep_params:
98
+ if op_rep_param['param_target'] == 'inputs' and op_rep_param['param_name'] == graph_node.inputs[0].name:
99
+ params_perm = op_rep_param.get('pre_process_transpose_perm', None)
100
+ if op_rep_param['param_target'] == 'inputs' and op_rep_param['param_name'] == graph_node.inputs[1].name:
101
+ if op_rep_param.get('pre_process_transpose_perm', None) is not None:
102
+ indices_perm_specified = True
103
+ if params_perm is not None and not indices_perm_specified:
104
+ # Only handle standard layout swaps that keep batch dims at the front.
105
+ if batch_dims <= len(params_perm) \
106
+ and list(params_perm[:batch_dims]) == list(range(batch_dims)):
107
+ perm_tail = [p - batch_dims for p in params_perm if p >= batch_dims]
108
+ try:
109
+ if isinstance(indices_tensor, np.ndarray):
110
+ if indices_tensor.shape and indices_tensor.shape[-1] == len(perm_tail):
111
+ indices_tensor = indices_tensor[..., perm_tail]
112
+ else:
113
+ idx_last = indices_tensor.shape[-1] if indices_tensor.shape is not None else None
114
+ if idx_last is None or idx_last == len(perm_tail):
115
+ indices_tensor = tf.gather(indices_tensor, perm_tail, axis=-1)
116
+ except Exception:
117
+ pass
118
+
92
119
  # Preserving Graph Structure (Dict)
93
120
  tf_layers_dict[graph_node_output.name] = {
94
121
  'optype': graph_node.op,
@@ -55,9 +55,10 @@ def make_node(
55
55
  graph_node.inputs[0],
56
56
  before_op_output_shape_trans,
57
57
  )
58
+ # Indices must not be layout-transposed.
58
59
  graph_node_input_2 = get_constant_or_variable(
59
60
  graph_node.inputs[1],
60
- before_op_output_shape_trans,
61
+ False,
61
62
  )
62
63
  graph_node_input_3 = get_constant_or_variable(
63
64
  graph_node.inputs[2],
@@ -81,12 +82,29 @@ def make_node(
81
82
  indices_tensor = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
82
83
  if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
83
84
  # Pre-process transpose
84
- indices_tensor = pre_process_transpose(
85
- value_before_transpose=indices_tensor,
86
- param_target='inputs',
87
- param_name=graph_node.inputs[1].name,
88
- **kwargs,
89
- )
85
+ # If input is transposed by replacement params, align indices tensor shape.
86
+ op_rep_params = kwargs.get('op_rep_params', [])
87
+ params_perm = None
88
+ indices_perm = None
89
+ for op_rep_param in op_rep_params:
90
+ if op_rep_param['param_target'] == 'inputs' \
91
+ and op_rep_param['param_name'] == graph_node.inputs[0].name:
92
+ params_perm = op_rep_param.get('pre_process_transpose_perm', None)
93
+ if op_rep_param['param_target'] == 'inputs' \
94
+ and op_rep_param['param_name'] == graph_node.inputs[1].name:
95
+ indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
96
+ target_perm = indices_perm if indices_perm is not None else params_perm
97
+ if target_perm is not None:
98
+ try:
99
+ rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
100
+ if rank is None or rank == len(target_perm):
101
+ indices_tensor = transpose_with_flexing_deterrence(
102
+ input_tensor=indices_tensor,
103
+ perm=target_perm,
104
+ **kwargs,
105
+ )
106
+ except Exception:
107
+ pass
90
108
  updates_tensor = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
91
109
  if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
92
110
  # Pre-process transpose
onnx2tf/ops/ScatterND.py CHANGED
@@ -13,6 +13,7 @@ from onnx2tf.utils.common_functions import (
13
13
  get_replacement_parameter,
14
14
  pre_process_transpose,
15
15
  post_process_transpose,
16
+ transpose_with_flexing_deterrence,
16
17
  )
17
18
 
18
19
 
@@ -79,6 +80,32 @@ def make_node(
79
80
  and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
80
81
  }
81
82
 
83
+ op_rep_params = kwargs.get('op_rep_params', [])
84
+ params_perm = None
85
+ indices_perm = None
86
+ for op_rep_param in op_rep_params:
87
+ if op_rep_param['param_target'] == 'inputs' \
88
+ and op_rep_param['param_name'] == graph_node.inputs[0].name:
89
+ params_perm = op_rep_param.get('pre_process_transpose_perm', None)
90
+ if op_rep_param['param_target'] == 'inputs' \
91
+ and op_rep_param['param_name'] == graph_node.inputs[1].name:
92
+ indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
93
+
94
+ def reorder_indices_last_dim(target_indices, perm):
95
+ if perm is None:
96
+ return target_indices
97
+ try:
98
+ if isinstance(target_indices, np.ndarray):
99
+ if target_indices.shape and target_indices.shape[-1] == len(perm):
100
+ return target_indices[..., perm]
101
+ else:
102
+ idx_last = target_indices.shape[-1] if target_indices.shape is not None else None
103
+ if idx_last is None or idx_last == len(perm):
104
+ return tf.gather(target_indices, perm, axis=-1)
105
+ except Exception:
106
+ pass
107
+ return target_indices
108
+
82
109
  # Pre-process transpose
83
110
  input_tensor = pre_process_transpose(
84
111
  value_before_transpose=input_tensor,
@@ -86,18 +113,26 @@ def make_node(
86
113
  param_name=graph_node.inputs[0].name,
87
114
  **kwargs,
88
115
  )
89
- indices_tensor = pre_process_transpose(
90
- value_before_transpose=indices_tensor,
91
- param_target='inputs',
92
- param_name=graph_node.inputs[1].name,
93
- **kwargs,
94
- )
116
+ # Indices must not be layout-transposed; apply explicit perm only if specified.
117
+ if indices_perm is not None:
118
+ try:
119
+ rank = len(indices_tensor.shape) if hasattr(indices_tensor, "shape") else None
120
+ if rank is None or rank == len(indices_perm):
121
+ indices_tensor = transpose_with_flexing_deterrence(
122
+ input_tensor=indices_tensor,
123
+ perm=indices_perm,
124
+ **kwargs,
125
+ )
126
+ except Exception:
127
+ pass
95
128
  updates_tensor = pre_process_transpose(
96
129
  value_before_transpose=updates_tensor,
97
130
  param_target='inputs',
98
131
  param_name=graph_node.inputs[2].name,
99
132
  **kwargs,
100
133
  )
134
+ if params_perm is not None and indices_perm is None:
135
+ indices_tensor = reorder_indices_last_dim(indices_tensor, params_perm)
101
136
 
102
137
  # When NHWC is fixed, return to NCHW format before processing.
103
138
  data_nhwc = tf_layers_dict[graph_node_input_1.name]['nhwc'] \
@@ -119,6 +154,8 @@ def make_node(
119
154
  and len(input_tensor.shape) >= 3:
120
155
  perm = [0, len(input_tensor.shape)-1] + [i for i in range(1, len(input_tensor.shape)-1)]
121
156
  input_tensor = tf.transpose(a=input_tensor, perm=perm)
157
+ if indices_perm is None:
158
+ indices_tensor = reorder_indices_last_dim(indices_tensor, perm)
122
159
  nchw = True
123
160
  elif not data_nhwc \
124
161
  and len(input_tensor.shape) >= 3 \
@@ -126,6 +163,8 @@ def make_node(
126
163
  and input_tensor.shape != graph_node.inputs[0].shape:
127
164
  perm = [0, len(input_tensor.shape)-1] + [i for i in range(1, len(input_tensor.shape)-1)]
128
165
  input_tensor = tf.transpose(a=input_tensor, perm=perm)
166
+ if indices_perm is None:
167
+ indices_tensor = reorder_indices_last_dim(indices_tensor, perm)
129
168
  nchw = True
130
169
  ## indices
131
170
  if indices_nhwc \
@@ -14,6 +14,7 @@ from onnx2tf.utils.common_functions import (
14
14
  get_replacement_parameter,
15
15
  pre_process_transpose,
16
16
  post_process_transpose,
17
+ transpose_with_flexing_deterrence,
17
18
  )
18
19
  from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
19
20
  from onnx2tf.utils.logging import *
@@ -112,12 +113,25 @@ def make_node(
112
113
  **kwargs,
113
114
  )
114
115
  if write_indices is not None:
115
- write_indices = pre_process_transpose(
116
- value_before_transpose=write_indices,
117
- param_target='inputs',
118
- param_name=graph_node.inputs[2].name,
119
- **kwargs,
120
- )
116
+ # Indices must not be layout-transposed; apply explicit perm only if specified.
117
+ op_rep_params = kwargs.get('op_rep_params', [])
118
+ indices_perm = None
119
+ for op_rep_param in op_rep_params:
120
+ if op_rep_param['param_target'] == 'inputs' \
121
+ and op_rep_param['param_name'] == graph_node.inputs[2].name:
122
+ indices_perm = op_rep_param.get('pre_process_transpose_perm', None)
123
+ break
124
+ if indices_perm is not None:
125
+ try:
126
+ rank = len(write_indices.shape) if hasattr(write_indices, "shape") else None
127
+ if rank is None or rank == len(indices_perm):
128
+ write_indices = transpose_with_flexing_deterrence(
129
+ input_tensor=write_indices,
130
+ perm=indices_perm,
131
+ **kwargs,
132
+ )
133
+ except Exception:
134
+ pass
121
135
 
122
136
  # Generation of TF OP
123
137
  past_cache = _as_tensor(past_cache)
@@ -1,4 +1,5 @@
1
1
  import math
2
+ import ast
2
3
  import os
3
4
  import io
4
5
  import re
@@ -26,6 +27,7 @@ from tensorflow.python.keras.layers import Lambda
26
27
  from tensorflow.python.keras.utils import conv_utils
27
28
  import onnx
28
29
  from onnx.serialization import ProtoSerializer
30
+ from onnx.external_data_helper import uses_external_data
29
31
  import onnx_graphsurgeon as gs
30
32
  try:
31
33
  import onnxruntime as ort
@@ -44,6 +46,61 @@ from onnx2tf.utils.enums import (
44
46
  INF_INDEX_VALUE: int = 4294967296
45
47
  ONNX_INF_INDEX_VALUE = sys.maxsize # 9223372036854775807
46
48
 
49
+ _DEFAULT_DUMMY_SHAPE_HINTS: Optional[List[str]] = None
50
+ _DEFAULT_DUMMY_VALUE_HINTS: Optional[List[str]] = None
51
+
52
+
53
def set_dummy_shape_hints(shape_hints: Optional[List[str]]) -> None:
    """Set the module-wide default shape hints.

    The stored value is what ends up in ``_DEFAULT_DUMMY_SHAPE_HINTS``.

    Parameters
    ----------
    shape_hints: Optional[List[str]]
        Shape hint strings to remember, or None to clear the default.
    """
    global _DEFAULT_DUMMY_SHAPE_HINTS
    # Store a private copy: keeping the caller's list object would let a later
    # mutation on the caller side silently change the process-wide default.
    _DEFAULT_DUMMY_SHAPE_HINTS = \
        list(shape_hints) if shape_hints is not None else None
56
+
57
+
58
def set_dummy_value_hints(value_hints: Optional[List[str]]) -> None:
    """Set the module-wide default value hints.

    The stored value is what ends up in ``_DEFAULT_DUMMY_VALUE_HINTS``.

    Parameters
    ----------
    value_hints: Optional[List[str]]
        Value hint strings to remember, or None to clear the default.
    """
    global _DEFAULT_DUMMY_VALUE_HINTS
    # Store a private copy: keeping the caller's list object would let a later
    # mutation on the caller side silently change the process-wide default.
    _DEFAULT_DUMMY_VALUE_HINTS = \
        list(value_hints) if value_hints is not None else None
61
+
62
+
63
+ def _parse_value_hint_scalar(value: str) -> Optional[Any]:
64
+ try:
65
+ parsed = ast.literal_eval(value)
66
+ except Exception:
67
+ try:
68
+ parsed = float(value)
69
+ except Exception:
70
+ return None
71
+ if isinstance(parsed, (list, tuple, dict, set, np.ndarray)):
72
+ return None
73
+ if isinstance(parsed, (int, float, bool, np.number)):
74
+ return parsed
75
+ return None
76
+
77
+
78
def _parse_value_hints(
    value_hints: Optional[List[str]]
) -> Tuple[Dict[str, Any], Optional[Any], bool]:
    """Parse value hint strings into per-input values and a wildcard default.

    Each entry has the form "input_name:value". The name "*" sets the default
    value used for inputs without an entry of their own.

    Parameters
    ----------
    value_hints: Optional[List[str]]
        Hint strings. None or an empty list yields an empty result.

    Returns
    -------
    hints: Dict[str, Any]
        Parsed value for each explicitly named input.
    default_value: Optional[Any]
        Value given for "*", or None when no wildcard entry was parsed.
    has_default: bool
        True when default_value is not None.
    """
    if not value_hints:
        return {}, None, False
    hints: Dict[str, Any] = {}
    default_value: Optional[Any] = None
    for hint in value_hints:
        if not isinstance(hint, str):
            continue
        # Split at the LAST colon. The value is a scalar literal and never
        # contains ':', whereas tensor names may (e.g. "input:0"); cutting at
        # the first colon would truncate such names and reject the hint.
        input_name, separator, value_str = hint.rpartition(':')
        if not separator:
            # No colon at all: not an "input_name:value" entry.
            continue
        parsed_value = _parse_value_hint_scalar(value_str)
        if parsed_value is None:
            warn(f'Invalid --value_hints entry ignored: {hint}')
            continue
        if input_name == '*':
            default_value = parsed_value
        else:
            hints[input_name] = parsed_value
    return hints, default_value, default_value is not None
101
+
102
+
103
+
47
104
 
48
105
  def get_replacement_parameter(func):
49
106
  @wraps(func)
@@ -3851,6 +3908,7 @@ def dummy_onnx_inference(
3851
3908
  enable_ort_output_memmap: bool = False,
3852
3909
  ort_output_memmap_dir: Optional[str] = None,
3853
3910
  shape_hints: Optional[List[str]] = None,
3911
+ value_hints: Optional[List[str]] = None,
3854
3912
  input_datas_for_validation: Optional[Dict[str, np.ndarray]] = None,
3855
3913
  ) -> List[np.ndarray]:
3856
3914
  """Perform inference on ONNX subgraphs with an all-1 dummy tensor.
@@ -3888,6 +3946,10 @@ def dummy_onnx_inference(
3888
3946
  Directory to store memmap files. If not specified, a temporary
3889
3947
  directory is created and removed on exit.
3890
3948
 
3949
+ value_hints: Optional[List[str]]
3950
+ Value hints for dummy inference input tensors.
3951
+ Format: ["input_name:value", "*:default_value"].
3952
+
3891
3953
  input_datas_for_validation: Optional[Dict[str, np.ndarray]]
3892
3954
  Optional dict to be filled with the input tensors used for inference.
3893
3955
 
@@ -3896,6 +3958,11 @@ def dummy_onnx_inference(
3896
3958
  outputs: List[np.ndarray]
3897
3959
  Results of inference using dummy tensor
3898
3960
  """
3961
+ if shape_hints is None:
3962
+ shape_hints = _DEFAULT_DUMMY_SHAPE_HINTS
3963
+ if value_hints is None:
3964
+ value_hints = _DEFAULT_DUMMY_VALUE_HINTS
3965
+
3899
3966
  # Separate onnx at specified output_names position
3900
3967
  domain: str = onnx_graph.domain
3901
3968
  ir_version: int = onnx_graph.ir_version
@@ -4046,7 +4113,11 @@ def dummy_onnx_inference(
4046
4113
  input_sizes[i] = updated_shape
4047
4114
 
4048
4115
  input_dtypes: List[Any] = [inp.dtype for inp in onnx_inputs]
4116
+ input_size_map = {
4117
+ name: tuple(size) for name, size in zip(input_names, input_sizes)
4118
+ }
4049
4119
  input_datas = {}
4120
+ value_hints_dict, default_value, has_default = _parse_value_hints(value_hints)
4050
4121
 
4051
4122
  # -cid
4052
4123
  if custom_input_op_name_np_data_path:
@@ -4059,7 +4130,16 @@ def dummy_onnx_inference(
4059
4130
  if input_op_info is not None:
4060
4131
  ncw_nchw_ncdhw_perm: List = input_op_info.get('ncw_nchw_ncdhw_perm', None)
4061
4132
  if ncw_nchw_ncdhw_perm is not None:
4062
- custom_input_data = custom_input_data.transpose(ncw_nchw_ncdhw_perm)
4133
+ expected_shape = input_size_map.get(
4134
+ input_op_name,
4135
+ tuple(custom_input_data.shape),
4136
+ )
4137
+ if tuple(custom_input_data.shape) != expected_shape:
4138
+ permuted_shape = tuple(
4139
+ custom_input_data.shape[i] for i in ncw_nchw_ncdhw_perm
4140
+ )
4141
+ if permuted_shape == expected_shape:
4142
+ custom_input_data = custom_input_data.transpose(ncw_nchw_ncdhw_perm)
4063
4143
  onnx_batch_size = input_op_info['shape'][0]
4064
4144
  cdata_batch_size = custom_input_data.shape[0]
4065
4145
  if isinstance(onnx_batch_size, int) and onnx_batch_size != cdata_batch_size and cdata_batch_size > 1:
@@ -4071,7 +4151,17 @@ def dummy_onnx_inference(
4071
4151
 
4072
4152
  else:
4073
4153
  for input_name, input_size, input_dtype in zip(input_names, input_sizes, input_dtypes):
4074
- if test_data_nhwc is None:
4154
+ hint_value = value_hints_dict.get(
4155
+ input_name,
4156
+ default_value if has_default else None,
4157
+ )
4158
+ if hint_value is not None:
4159
+ input_datas[input_name] = np.full(
4160
+ input_size,
4161
+ hint_value,
4162
+ dtype=input_dtype,
4163
+ )
4164
+ elif test_data_nhwc is None:
4075
4165
  input_datas[input_name] = np.ones(
4076
4166
  input_size,
4077
4167
  dtype=input_dtype,
@@ -4230,7 +4320,9 @@ def dummy_tf_inference(
4230
4320
  verification_datas: Optional[List[np.ndarray]] = None,
4231
4321
  custom_input_op_name_np_data_path: Optional[str] = None,
4232
4322
  shape_hints: Optional[List[str]] = None,
4323
+ value_hints: Optional[List[str]] = None,
4233
4324
  input_datas_for_validation: Optional[Dict[str, np.ndarray]] = None,
4325
+ prefilled_input_datas: Optional[Dict[str, np.ndarray]] = None,
4234
4326
  keep_shape_absolutely_input_names: Optional[List[str]] = None,
4235
4327
  keep_ncw_or_nchw_or_ncdhw_input_names: Optional[List[str]] = None,
4236
4328
  keep_nwc_or_nhwc_or_ndhwc_input_names: Optional[List[str]] = None,
@@ -4253,6 +4345,11 @@ def dummy_tf_inference(
4253
4345
 
4254
4346
  custom_input_op_name_np_data_path
4255
4347
  Path to Numpy file for custom data used for dummy inference
4348
+
4349
+ value_hints: Optional[List[str]]
4350
+ Value hints for dummy inference input tensors.
4351
+ Format: ["input_name:value", "*:default_value"].
4352
+
4256
4353
  input_datas_for_validation: Optional[Dict[str, np.ndarray]]
4257
4354
  Optional dict to be filled with the input tensors used for inference.
4258
4355
 
@@ -4262,8 +4359,15 @@ def dummy_tf_inference(
4262
4359
  Results of inference using dummy tensor.
4263
4360
  Dict of tensorflow node and corresponding ndarray output.
4264
4361
  """
4362
+ if shape_hints is None:
4363
+ shape_hints = _DEFAULT_DUMMY_SHAPE_HINTS
4364
+ if value_hints is None:
4365
+ value_hints = _DEFAULT_DUMMY_VALUE_HINTS
4366
+
4265
4367
  input_names: List[str] = [inp.name for inp in inputs]
4266
4368
  input_sizes: List[int] = [inp.shape for inp in inputs]
4369
+ input_size_map = {name: size for name, size in zip(input_names, input_sizes)}
4370
+ input_index_map = {name: i for i, name in enumerate(input_names)}
4267
4371
 
4268
4372
  if shape_hints is None:
4269
4373
  new_input_sizes = []
@@ -4335,13 +4439,24 @@ def dummy_tf_inference(
4335
4439
 
4336
4440
  input_dtypes: List[Any] = [inp.dtype for inp in inputs]
4337
4441
  input_datas = {}
4442
+ value_hints_dict, default_value, has_default = _parse_value_hints(value_hints)
4338
4443
 
4339
4444
  # -cid
4340
4445
  if custom_input_op_name_np_data_path:
4341
- for idx, param in enumerate(custom_input_op_name_np_data_path):
4446
+ for param in custom_input_op_name_np_data_path:
4447
+ if len(param) < 2:
4448
+ continue
4449
+ input_name = str(param[0])
4342
4450
  numpy_file_path = str(param[1])
4451
+ if input_name not in input_index_map:
4452
+ continue
4453
+ idx = input_index_map[input_name]
4454
+ tf_input_name = input_names[idx]
4455
+ if prefilled_input_datas and tf_input_name in prefilled_input_datas:
4456
+ continue
4343
4457
  custom_input_data = np.load(numpy_file_path)
4344
4458
  input_size = input_sizes[idx]
4459
+ input_dtype = input_dtypes[idx] if idx < len(input_dtypes) else np.float32
4345
4460
 
4346
4461
  tf_batch_size = input_size[0]
4347
4462
  cdata_batch_size = custom_input_data.shape[0]
@@ -4351,6 +4466,24 @@ def dummy_tf_inference(
4351
4466
  custom_input_data = custom_input_data[0:1, ...]
4352
4467
 
4353
4468
  if list(custom_input_data.shape) != input_size:
4469
+ auto_split_input = (
4470
+ 'onnx2tf_split_' in numpy_file_path
4471
+ or os.path.basename(numpy_file_path).startswith('part_')
4472
+ )
4473
+ if auto_split_input:
4474
+ warn(
4475
+ 'Auto-split custom input shape does not match TF input shape. '
4476
+ f'input_name={input_name} '
4477
+ f'tf_shape={input_size} '
4478
+ f'numpy_shape={list(custom_input_data.shape)} '
4479
+ f'path={numpy_file_path} '
4480
+ 'Fallback to dummy input for this tensor.'
4481
+ )
4482
+ input_datas[input_names[idx]] = np.ones(
4483
+ input_size,
4484
+ dtype=TF_DTYPES_TO_NUMPY_DTYPES[input_dtype],
4485
+ )
4486
+ continue
4354
4487
  error_msg = f'' + \
4355
4488
  Color.RED(f'ERROR:') + ' ' + \
4356
4489
  f"The format of custom input data is different from Tensorflow's format. " + \
@@ -4363,7 +4496,17 @@ def dummy_tf_inference(
4363
4496
  else:
4364
4497
  if verification_datas is None:
4365
4498
  for input_name, input_size, input_dtype in zip(input_names, input_sizes, input_dtypes):
4366
- if test_data_nhwc is None:
4499
+ hint_value = value_hints_dict.get(
4500
+ input_name,
4501
+ default_value if has_default else None,
4502
+ )
4503
+ if hint_value is not None:
4504
+ input_datas[input_name] = np.full(
4505
+ input_size,
4506
+ hint_value,
4507
+ dtype=TF_DTYPES_TO_NUMPY_DTYPES[input_dtype],
4508
+ )
4509
+ elif test_data_nhwc is None:
4367
4510
  input_datas[input_name] = np.ones(
4368
4511
  input_size,
4369
4512
  dtype=TF_DTYPES_TO_NUMPY_DTYPES[input_dtype],
@@ -4397,6 +4540,33 @@ def dummy_tf_inference(
4397
4540
  if input_datas_for_validation is not None:
4398
4541
  input_datas_for_validation.update(input_datas)
4399
4542
 
4543
+ if prefilled_input_datas:
4544
+ for input_name, input_data in prefilled_input_datas.items():
4545
+ expected = None
4546
+ if input_name in input_datas:
4547
+ expected = input_datas[input_name].shape
4548
+ elif input_name in input_size_map:
4549
+ expected = input_size_map[input_name]
4550
+ else:
4551
+ continue
4552
+ data = input_data
4553
+ try:
4554
+ if expected is not None and tuple(data.shape) != tuple(expected):
4555
+ if data.size == np.prod(expected):
4556
+ data = data.reshape(expected)
4557
+ else:
4558
+ continue
4559
+ target_dtype = None
4560
+ if input_name in input_datas:
4561
+ target_dtype = input_datas[input_name].dtype
4562
+ elif input_name in input_index_map:
4563
+ target_dtype = input_dtypes[input_index_map[input_name]]
4564
+ if target_dtype is not None and data.dtype != target_dtype:
4565
+ data = data.astype(target_dtype)
4566
+ input_datas[input_name] = data
4567
+ except Exception:
4568
+ continue
4569
+
4400
4570
  outputs = model(
4401
4571
  inputs={
4402
4572
  input.name: input_datas[input.name] for input in inputs
@@ -6057,6 +6227,8 @@ def acquisition_of_validation_data(
6057
6227
  kwargs['test_data_nhwc']
6058
6228
  custom_input_op_name_np_data_path: str = \
6059
6229
  kwargs['custom_input_op_name_np_data_path']
6230
+ tf_input_cache: Optional[Dict[str, np.ndarray]] = \
6231
+ kwargs.get('tf_input_cache', None)
6060
6232
 
6061
6233
  # Get the output tensor of one previous OP of TensorFlow only once
6062
6234
  tf_model_inputs = get_tf_model_inputs(
@@ -6108,6 +6280,7 @@ def acquisition_of_validation_data(
6108
6280
  inputs=tf_model_inputs,
6109
6281
  test_data_nhwc=test_data_nhwc,
6110
6282
  custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
6283
+ prefilled_input_datas=tf_input_cache,
6111
6284
  )
6112
6285
  except Exception as ex:
6113
6286
  pass
@@ -6526,3 +6699,24 @@ def define_reduceXXX(
6526
6699
  keepdims=target_keepdims,
6527
6700
  )
6528
6701
  return reduced_tensor
6702
+
6703
def check_has_external_data(input_onnx_file_path: str) -> bool:
    """Report whether any tensor in an ONNX model is stored as external data.

    The scan covers dense initializers, sparse initializers, tensor-valued
    node attributes, and recurses into subgraphs held by node attributes.

    Parameters
    ----------
    input_onnx_file_path: str
        Path of the ONNX model file to inspect.

    Returns
    -------
    has_external_data: bool
        True if at least one inspected tensor is flagged as external data.
    """
    # Only the model protobuf is needed; the external payload files
    # themselves are not read.
    model = onnx.load(input_onnx_file_path, load_external_data=False)

    def iter_sparse_parts(sparse):
        # uses_external_data() probes `data_location`, a field that exists on
        # TensorProto only. A SparseTensorProto keeps its storage in the nested
        # `values` and `indices` TensorProtos, so those are inspected instead
        # of the sparse wrapper itself.
        yield sparse.values
        yield sparse.indices

    def iter_tensors_in_graph(g):
        yield from g.initializer
        for sparse in g.sparse_initializer:
            yield from iter_sparse_parts(sparse)
        for n in g.node:
            for a in n.attribute:
                if a.type == onnx.AttributeProto.TENSOR:
                    yield a.t
                elif a.type == onnx.AttributeProto.TENSORS:
                    yield from a.tensors
                elif a.type == onnx.AttributeProto.SPARSE_TENSOR:
                    yield from iter_sparse_parts(a.sparse_tensor)
                elif a.type == onnx.AttributeProto.SPARSE_TENSORS:
                    for sparse in a.sparse_tensors:
                        yield from iter_sparse_parts(sparse)
                elif a.type == onnx.AttributeProto.GRAPH:
                    # Control-flow style ops carry nested graphs.
                    yield from iter_tensors_in_graph(a.g)
                elif a.type == onnx.AttributeProto.GRAPHS:
                    for sg in a.graphs:
                        yield from iter_tensors_in_graph(sg)

    return any(uses_external_data(t) for t in iter_tensors_in_graph(model.graph))