onnx2tf 1.29.19__py3-none-any.whl → 1.29.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2tf/onnx2tf.py CHANGED
@@ -2,6 +2,8 @@
 
  import os
  import re
+ import shutil
+ import tempfile
  __path__ = (os.path.dirname(__file__), )
  with open(os.path.join(__path__[0], '__init__.py')) as f:
  init_text = f.read()
@@ -41,6 +43,7 @@ from typing import Optional, List, Any, Dict
  from argparse import ArgumentParser
 
  import importlib
+ import onnx2tf.utils.common_functions as common_functions
  from onnx2tf.utils.common_functions import (
  dummy_onnx_inference,
  dummy_tf_inference,
@@ -51,6 +54,7 @@ from onnx2tf.utils.common_functions import (
  get_tf_model_outputs,
  rewrite_tflite_inout_opname,
  check_cuda_enabled,
+ check_has_external_data,
  )
  from onnx2tf.utils.json_auto_generator import (
  generate_auto_replacement_json,
@@ -62,6 +66,349 @@ from onnx2tf.utils.enums import (
  from onnx2tf.utils.logging import *
  from sng4onnx import generate as op_name_auto_generate
 
+ def _sanitize_split_input_name(name: str) -> str:
+ if not name:
+ return 'tensor'
+ return re.sub(r'[^0-9A-Za-z._-]+', '_', name)
+
+ def _write_memmap_array(path: str, array: np.ndarray) -> str:
+ mm = np.lib.format.open_memmap(
+ path,
+ mode='w+',
+ dtype=array.dtype,
+ shape=array.shape,
+ )
+ mm[...] = array
+ mm.flush()
+ return path
+
+
+ def _tensorproto_nbytes(tensor: onnx.TensorProto) -> int:
+ if tensor is None:
+ return 0
+ if tensor.HasField('raw_data'):
+ return len(tensor.raw_data)
+ try:
+ np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor.data_type)
+ except Exception:
+ np_dtype = None
+ if np_dtype is None:
+ return 0
+ elem_size = np.dtype(np_dtype).itemsize
+ num_elems = int(np.prod(tensor.dims)) if len(tensor.dims) > 0 else 0
+ if num_elems == 0:
+ try:
+ field_name = onnx.helper.tensor_dtype_to_field(tensor.data_type)
+ if hasattr(tensor, field_name):
+ num_elems = len(getattr(tensor, field_name))
+ except Exception:
+ num_elems = 0
+ return num_elems * elem_size
+
+ def _collect_initializer_sizes(onnx_graph: onnx.ModelProto) -> Dict[str, int]:
+ initializer_sizes: Dict[str, int] = {}
+ if onnx_graph is None:
+ return initializer_sizes
+ for initializer in onnx_graph.graph.initializer:
+ if not initializer.name:
+ continue
+ try:
+ initializer_sizes[initializer.name] = _tensorproto_nbytes(initializer)
+ except Exception:
+ initializer_sizes[initializer.name] = 0
+ return initializer_sizes
+
+ def _collect_node_weight_keys(
+ *,
+ graph: gs.Graph,
+ initializer_sizes: Dict[str, int],
+ ) -> tuple[List[List[str]], Dict[str, int]]:
+ weight_sizes = dict(initializer_sizes)
+ node_weight_keys: List[List[str]] = []
+ for node in graph.nodes:
+ keys: List[str] = []
+ for inp in node.inputs:
+ if isinstance(inp, gs.Constant):
+ if isinstance(getattr(inp, 'values', None), np.ndarray):
+ key = f'const:{id(inp)}'
+ if key not in weight_sizes:
+ weight_sizes[key] = int(inp.values.nbytes)
+ keys.append(key)
+ continue
+ name = getattr(inp, 'name', '')
+ if name and name in initializer_sizes:
+ keys.append(name)
+ node_weight_keys.append(keys)
+ return node_weight_keys, weight_sizes
+
+ def _auto_partition_ranges(
+ *,
+ node_weight_keys: List[List[str]],
+ weight_sizes: Dict[str, int],
+ max_size_bytes: int,
+ reachable_node_indices: Optional[set] = None,
+ ) -> List[tuple]:
+ ranges: List[tuple] = []
+ if max_size_bytes <= 0 or not node_weight_keys:
+ return ranges
+ current_keys: set = set()
+ current_bytes = 0
+ start_idx = 0
+ for idx, keys in enumerate(node_weight_keys):
+ new_bytes = 0
+ for key in keys:
+ if key not in current_keys:
+ new_bytes += weight_sizes.get(key, 0)
+ current_keys.add(key)
+ current_bytes += new_bytes
+ if current_bytes >= max_size_bytes and idx > start_idx:
+ if reachable_node_indices is not None and idx not in reachable_node_indices:
+ continue
+ ranges.append((start_idx, idx))
+ start_idx = idx + 1
+ current_keys = set()
+ current_bytes = 0
+ if start_idx <= len(node_weight_keys) - 1:
+ ranges.append((start_idx, len(node_weight_keys) - 1))
+ return ranges
+
+ def _collect_reachable_node_indices(
+ graph: gs.Graph,
+ initializer_names: Optional[set] = None,
+ ) -> set:
+ reachable_nodes: set = set()
+ reachable_vars: set = set()
+ initializer_names = initializer_names or set()
+ for graph_input in graph.inputs:
+ name = getattr(graph_input, 'name', '')
+ if name and name not in initializer_names:
+ reachable_vars.add(name)
+ for idx, node in enumerate(graph.nodes):
+ is_reachable = False
+ for inp in node.inputs:
+ if isinstance(inp, gs.Variable):
+ name = getattr(inp, 'name', '')
+ if name in reachable_vars and name not in initializer_names:
+ is_reachable = True
+ break
+ if is_reachable:
+ reachable_nodes.add(idx)
+ for out in node.outputs:
+ name = getattr(out, 'name', '')
+ if name:
+ reachable_vars.add(name)
+ return reachable_nodes
+
+ def _collect_constant_only_node_indices(
+ graph: gs.Graph,
+ initializer_names: Optional[set] = None,
+ ) -> set:
+ initializer_names = initializer_names or set()
+ const_only_nodes: set = set()
+ for idx, node in enumerate(graph.nodes):
+ has_variable_input = False
+ for inp in node.inputs:
+ if isinstance(inp, gs.Constant):
+ continue
+ name = getattr(inp, 'name', '')
+ if name and name not in initializer_names:
+ has_variable_input = True
+ break
+ if not has_variable_input:
+ const_only_nodes.add(idx)
+ return const_only_nodes
+
+ def _complete_custom_inputs_for_graph(
+ *,
+ onnx_graph: onnx.ModelProto,
+ custom_inputs: List[List[Any]],
+ output_dir: str,
+ file_prefix: str,
+ shape_hints: Optional[List[str]] = None,
+ require_mean_std: bool = False,
+ ) -> List[List[Any]]:
+ gs_graph = gs.import_onnx(onnx_graph)
+ input_names: List[str] = [inp.name for inp in gs_graph.inputs]
+ input_sizes: List[List[Any]] = [inp.shape for inp in gs_graph.inputs]
+ input_dtypes: List[Any] = [inp.dtype for inp in gs_graph.inputs]
+
+ if shape_hints is None:
+ new_input_sizes = []
+ for input_size in input_sizes:
+ new_input_size = []
+ for idx, dim in enumerate(input_size):
+ if idx == 0 and input_sizes and input_sizes[0][0] is not None \
+ and not isinstance(input_sizes[0][0], str) \
+ and len(input_sizes[0]) == len(input_size) \
+ and (dim is None or isinstance(dim, str)):
+ new_input_size.append(input_sizes[0][0])
+ elif dim is None or isinstance(dim, str):
+ new_input_size.append(1)
+ else:
+ new_input_size.append(dim)
+ new_input_sizes.append(new_input_size)
+ input_sizes = new_input_sizes
+ else:
+ shape_hints_dict = {}
+ for hint in shape_hints:
+ parts = hint.split(':')
+ if len(parts) == 2:
+ input_name = parts[0]
+ shape_values = [int(val) for val in parts[1].split(',')]
+ shape_hints_dict[input_name] = shape_values
+ for i, (input_name, original_shape) in enumerate(zip(input_names, input_sizes)):
+ if input_name in shape_hints_dict:
+ updated_shape = shape_hints_dict[input_name]
+ for j, (orig_dim, hint_dim) in enumerate(zip(original_shape, updated_shape)):
+ if orig_dim is not None and not isinstance(orig_dim, str):
+ updated_shape[j] = orig_dim
+ else:
+ updated_shape[j] = hint_dim
+ input_sizes[i] = updated_shape
+
+ custom_map = {}
+ for item in custom_inputs or []:
+ if len(item) >= 2:
+ custom_map[item[0]] = item
+
+ results: List[List[Any]] = []
+ for input_name, input_size, input_dtype in zip(input_names, input_sizes, input_dtypes):
+ if input_name in custom_map:
+ item = list(custom_map[input_name])
+ if require_mean_std and len(item) == 2:
+ item = [item[0], item[1], 0.0, 1.0]
+ results.append(item)
+ continue
+ dtype = input_dtype if input_dtype is not None else np.float32
+ file_name = f'{file_prefix}_{_sanitize_split_input_name(input_name)}.npy'
+ file_path = os.path.join(output_dir, file_name)
+ mm = np.lib.format.open_memmap(
+ file_path,
+ mode='w+',
+ dtype=dtype,
+ shape=tuple(input_size),
+ )
+ mm[...] = 1
+ mm.flush()
+ if require_mean_std:
+ results.append([input_name, file_path, 0.0, 1.0])
+ else:
+ results.append([input_name, file_path])
+ return results
+
+ def _estimate_partition_weight_bytes(
+ *,
+ ranges: List[tuple],
+ node_weight_keys: List[List[str]],
+ weight_sizes: Dict[str, int],
+ ) -> List[int]:
+ partition_sizes: List[int] = []
+ for start_idx, end_idx in ranges:
+ seen: set = set()
+ total_bytes = 0
+ for idx in range(start_idx, end_idx + 1):
+ for key in node_weight_keys[idx]:
+ if key not in seen:
+ total_bytes += weight_sizes.get(key, 0)
+ seen.add(key)
+ partition_sizes.append(total_bytes)
+ return partition_sizes
+
+ def _build_partition_io(
+ *,
+ graph: gs.Graph,
+ ranges: List[tuple],
+ const_only_nodes: Optional[set] = None,
+ ) -> List[Dict[str, Any]]:
+ if not ranges:
+ return []
+ const_only_nodes = const_only_nodes or set()
+ producer_by_tensor: Dict[str, int] = {}
+ consumers_by_tensor: Dict[str, set] = {}
+ graph_output_names = [o.name for o in graph.outputs if o.name]
+ for idx, node in enumerate(graph.nodes):
+ for out in node.outputs:
+ name = getattr(out, 'name', '')
+ if name:
+ producer_by_tensor[name] = idx
+ for inp in node.inputs:
+ if isinstance(inp, gs.Constant):
+ continue
+ name = getattr(inp, 'name', '')
+ if not name:
+ continue
+ consumers_by_tensor.setdefault(name, set()).add(idx)
+
+ partitions: List[Dict[str, Any]] = []
+ for start_idx, end_idx in ranges:
+ node_idx_set = set(range(start_idx, end_idx + 1))
+ part_inputs: set = set()
+ part_outputs: set = set()
+ for idx in node_idx_set:
+ node = graph.nodes[idx]
+ for inp in node.inputs:
+ if isinstance(inp, gs.Constant):
+ continue
+ name = getattr(inp, 'name', '')
+ if not name:
+ continue
+ producer_idx = producer_by_tensor.get(name)
+ if producer_idx is None or producer_idx not in node_idx_set:
+ if producer_idx is not None and producer_idx in const_only_nodes:
+ continue
+ part_inputs.add(name)
+ for out in node.outputs:
+ name = getattr(out, 'name', '')
+ if not name:
+ continue
+ consumers = consumers_by_tensor.get(name, set())
+ if name in graph_output_names or any(c not in node_idx_set for c in consumers):
+ if idx in const_only_nodes and name not in graph_output_names:
+ continue
+ part_outputs.add(name)
+ partitions.append({
+ 'inputs': sorted(part_inputs),
+ 'outputs': sorted(part_outputs),
+ 'node_count': end_idx - start_idx + 1,
+ 'start_idx': start_idx,
+ 'end_idx': end_idx,
+ })
+ return partitions
+
+ def _merge_ranges_with_missing_io(
+ *,
+ graph: gs.Graph,
+ ranges: List[tuple],
+ const_only_nodes: Optional[set] = None,
+ ) -> tuple[List[tuple], List[Dict[str, Any]]]:
+ if not ranges:
+ return ranges, []
+ ranges = list(ranges)
+ const_only_nodes = const_only_nodes or set()
+ while True:
+ partitions = _build_partition_io(
+ graph=graph,
+ ranges=ranges,
+ const_only_nodes=const_only_nodes,
+ ) or []
+ if all(part['inputs'] and part['outputs'] for part in partitions):
+ return ranges, partitions
+ if len(ranges) <= 1:
+ return ranges, partitions
+ merged = False
+ for idx, part in enumerate(partitions):
+ if not part['inputs'] or not part['outputs']:
+ if idx > 0:
+ ranges[idx - 1] = (ranges[idx - 1][0], ranges[idx][1])
+ del ranges[idx]
+ else:
+ ranges[idx] = (ranges[idx][0], ranges[idx + 1][1])
+ del ranges[idx + 1]
+ merged = True
+ break
+ if not merged:
+ return ranges, partitions
+
  def fuse_expanded_qdq_to_qdq(
  *,
  graph: gs.Graph,
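
Note (added for illustration, not part of the diff): the greedy grouping in _auto_partition_ranges above can be seen on toy data. The node keys and byte sizes below are made up; shared initializers are counted once per partition, and a cut is made at the node where the running total first reaches the budget.

# Hypothetical example: four nodes, the first two sharing the 600 MB initializer 'w0'.
node_weight_keys = [['w0'], ['w0'], ['w1'], ['w2']]
weight_sizes = {'w0': 600 * 1024 * 1024, 'w1': 700 * 1024 * 1024, 'w2': 100 * 1024 * 1024}
ranges = _auto_partition_ranges(
    node_weight_keys=node_weight_keys,
    weight_sizes=weight_sizes,
    max_size_bytes=1024 * 1024 * 1024,
)
# The running total (600 MB, 600 MB, 1300 MB, ...) first crosses 1024 MB at node index 2,
# so ranges == [(0, 2), (3, 3)].
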
@@ -285,6 +632,7 @@ def convert(
  quant_norm_std: Optional[str] = '[[[[0.229, 0.224, 0.225]]]]',
  quant_type: Optional[str] = 'per-channel',
  custom_input_op_name_np_data_path: Optional[List] = None,
+ tf_input_cache: Optional[Dict[str, np.ndarray]] = None,
  input_quant_dtype: Optional[str] = 'int8',
  output_quant_dtype: Optional[str] = 'int8',
  not_use_onnxsim: Optional[bool] = False,
@@ -292,6 +640,7 @@ def convert(
  batch_size: Optional[int] = None,
  overwrite_input_shape: Optional[List[str]] = None,
  shape_hints: Optional[List[str]] = None,
+ value_hints: Optional[List[str]] = None,
  no_large_tensor: Optional[bool] = False,
  output_nms_with_dynamic_tensor: Optional[bool] = False,
  switch_nms_version: Optional[str] = 'v4',
@@ -321,6 +670,8 @@ def convert(
  param_replacement_file: Optional[str] = '',
  auto_generate_json: Optional[bool] = False,
  auto_generate_json_on_error: Optional[bool] = False,
+ enable_auto_split_model: Optional[bool] = False,
+ auto_split_max_size_mb: Optional[int] = 1024,
  check_gpu_delegate_compatibility: Optional[bool] = False,
  check_onnx_tf_outputs_elementwise_close: Optional[bool] = False,
  check_onnx_tf_outputs_elementwise_close_full: Optional[bool] = False,
@@ -451,6 +802,10 @@ def convert(
  ["input2","input2.npy",[0.3],[0.07]],\n
  ]
 
+ tf_input_cache: Optional[Dict[str, np.ndarray]]
+ Cache of TF dummy inference inputs keyed by TF input tensor name.\n
+ Used to propagate TF outputs between auto-split partitions.\n
+
  input_quant_dtype: Optional[str]
  Input dtypes when doing Full INT8 Quantization.\n
  "int8"(default) or "uint8" or "float32"
@@ -497,6 +852,15 @@ def convert(
  A value of 1 or more must be specified.\n
  Numerical values other than dynamic dimensions are ignored.
 
+ value_hints: Optional[List[str]]
+ Value hints for dummy inference input tensors.\n
+ The format is\n
+ ["input_name_1:value","input_name_2:value","*:default_value"].\n
+ "*" applies to all inputs not explicitly specified.\n
+ Values are scalar only.\n
+ e.g.\n
+ ['input0:0.5','mask:0','*:1.0']\n
+
  no_large_tensor: Optional[bool]
  Suppresses constant bloat caused by Tile OP when optimizing models in onnxsim.\n
  See: https://github.com/daquexian/onnx-simplifier/issues/178
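
Illustrative usage of value_hints (not part of the diff; assumes the package-level onnx2tf.convert entry point and a placeholder model path):

import onnx2tf

# Pin the dummy-inference value of 'mask' to 0 and default all other inputs to 1.0.
onnx2tf.convert(
    input_onnx_file_path='model.onnx',
    value_hints=['mask:0', '*:1.0'],
)
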
@@ -682,6 +1046,15 @@ def convert(
  This is now opt-in and requires explicitly enabling the feature.\n
  Default: False
 
+ enable_auto_split_model: Optional[bool]
+ Force auto split regardless of the ONNX file size.\n
+ The target size is controlled by auto_split_max_size_mb.\n
+ Default: False
+
+ auto_split_max_size_mb: Optional[int]
+ Target maximum size per partition in MB based on ONNX initializer sizes.\n
+ Default: 1024
+
  check_gpu_delegate_compatibility: Optional[bool]
  Run TFLite ModelAnalyzer on the generated Float16 tflite model\n
  to check if the model can be supported by GPU Delegate.
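
Illustrative usage of the auto-split options (not part of the diff; the file name is a placeholder):

import onnx2tf

# Force partitioning and aim for roughly 512 MB of initializer weights per partition.
onnx2tf.convert(
    input_onnx_file_path='big_model.onnx',
    enable_auto_split_model=True,
    auto_split_max_size_mb=512,
)
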
@@ -748,6 +1121,8 @@ def convert(
  if verbosity is None:
  verbosity = 'debug'
  set_log_level('error' if non_verbose else verbosity)
+ common_functions.set_dummy_shape_hints(shape_hints)
+ common_functions.set_dummy_value_hints(value_hints)
 
  # Either designation required
  if not input_onnx_file_path and not onnx_graph:
@@ -771,6 +1146,23 @@ def convert(
  f'input_onnx_file_path: {input_onnx_file_path}'
  )
  sys.exit(1)
+ auto_split_model = bool(enable_auto_split_model)
+ if auto_split_model:
+ info(
+ Color.GREEN('Auto split forced by --enable_auto_split_model. ') +
+ f'target={auto_split_max_size_mb} MB'
+ )
+ if onnx_graph is None and input_onnx_file_path and os.path.exists(input_onnx_file_path):
+ try:
+ onnx_file_size = os.path.getsize(input_onnx_file_path)
+ if not auto_split_model and onnx_file_size > 2 * 1024 * 1024 * 1024:
+ info(
+ Color.GREEN('ONNX file exceeds 2GB; switching to auto-split mode. ') +
+ f'size={onnx_file_size / (1024 * 1024 * 1024):.2f} GB'
+ )
+ auto_split_model = True
+ except Exception:
+ pass
 
  # Extracting onnx filenames
  output_file_name = ''
@@ -917,9 +1309,8 @@ def convert(
  exported_onnx_graph = gs.export_onnx(graph, do_type_check=False, **meta_data)
  if metadata_props is not None:
  exported_onnx_graph.metadata_props.extend(metadata_props)
- estimated_graph = onnx.shape_inference.infer_shapes(exported_onnx_graph)
- onnx.save(estimated_graph, f=input_onnx_file_path)
- del estimated_graph
+ onnx.save(exported_onnx_graph, f=input_onnx_file_path)
+ del exported_onnx_graph
  except:
  if tmp_graph is not None:
  del tmp_graph
@@ -948,6 +1339,10 @@ def convert(
  'Failed to optimize the onnx file.'
  )
 
+ has_external_data = False
+ if input_onnx_file_path and os.path.exists(input_onnx_file_path):
+ has_external_data = check_has_external_data(input_onnx_file_path)
+
  # Automatic generation of each OP name - sng4onnx
  if not not_use_opname_auto_generate:
  info('')
@@ -957,6 +1352,7 @@ def convert(
  input_onnx_file_path=f'{input_onnx_file_path}',
  onnx_graph=onnx_graph,
  output_onnx_file_path=f'{input_onnx_file_path}',
+ has_external_data=has_external_data,
  non_verbose=True,
  )
  info(Color.GREEN(f'Automatic generation of each OP name complete!'))
@@ -981,6 +1377,19 @@ def convert(
  if not onnx_graph:
  onnx_graph = onnx.load(input_onnx_file_path)
 
+ if not auto_split_model and onnx_graph is not None:
+ try:
+ initializer_sizes = _collect_initializer_sizes(onnx_graph)
+ total_init_bytes = sum(initializer_sizes.values())
+ if total_init_bytes > 2 * 1024 * 1024 * 1024:
+ info(
+ Color.GREEN('ONNX graph estimated initializer size exceeds 2GB; ') +
+ f'switching to auto-split mode. size={total_init_bytes / (1024 * 1024 * 1024):.2f} GB'
+ )
+ auto_split_model = True
+ except Exception:
+ pass
+
  domain: str = onnx_graph.domain
  ir_version: int = onnx_graph.ir_version
  meta_data = {'domain': domain, 'ir_version': ir_version}
@@ -990,6 +1399,523 @@ def convert(
  graph = gs.import_onnx(onnx_graph)
  fuse_expanded_qdq_to_qdq(graph=graph)
 
+ # Auto split model by estimated weight size
+ if auto_split_model:
+ if input_names_to_interrupt_model_conversion or output_names_to_interrupt_model_conversion:
+ error(
+ 'Auto split cannot be used together with input_names_to_interrupt_model_conversion '
+ 'or output_names_to_interrupt_model_conversion.'
+ )
+ sys.exit(1)
+ if auto_split_max_size_mb is None or auto_split_max_size_mb <= 0:
+ error(
+ f'auto_split_max_size_mb must be greater than 0. auto_split_max_size_mb: {auto_split_max_size_mb}'
+ )
+ sys.exit(1)
+ try:
+ import sne4onnx
+ except Exception:
+ error(
+ 'Auto split requires sne4onnx. pip install sne4onnx'
+ )
+ sys.exit(1)
+ try:
+ graph.toposort()
+ except Exception:
+ pass
+
+ onnx_graph_for_split = onnx_graph
+ try:
+ onnx_graph_for_split = gs.export_onnx(
+ graph=graph,
+ do_type_check=False,
+ **meta_data,
+ )
+ if metadata_props is not None:
+ onnx_graph_for_split.metadata_props.extend(metadata_props)
+ except Exception:
+ onnx_graph_for_split = onnx_graph
+
+ initializer_sizes = _collect_initializer_sizes(onnx_graph_for_split)
+ node_weight_keys, weight_sizes = _collect_node_weight_keys(
+ graph=graph,
+ initializer_sizes=initializer_sizes,
+ )
+ const_only_nodes = _collect_constant_only_node_indices(
+ graph,
+ initializer_names=set(initializer_sizes.keys()),
+ )
+ reachable_node_indices = _collect_reachable_node_indices(
+ graph,
+ initializer_names=set(initializer_sizes.keys()),
+ )
+ max_size_bytes = int(auto_split_max_size_mb) * 1024 * 1024
+ ranges = _auto_partition_ranges(
+ node_weight_keys=node_weight_keys,
+ weight_sizes=weight_sizes,
+ max_size_bytes=max_size_bytes,
+ reachable_node_indices=reachable_node_indices,
+ )
+ if len(ranges) > 1:
+ ranges, partitions = _merge_ranges_with_missing_io(
+ graph=graph,
+ ranges=ranges,
+ const_only_nodes=const_only_nodes,
+ )
+ if not partitions:
+ warn(
+ 'Auto split failed to determine partition boundaries. Proceeding without split.'
+ )
+ else:
+ if any([not p['inputs'] or not p['outputs'] for p in partitions]):
+ warn(
+ 'Auto split produced partitions with missing inputs or outputs. '
+ 'Some partitions may not be inferable.'
+ )
+ partition_sizes = _estimate_partition_weight_bytes(
+ ranges=ranges,
+ node_weight_keys=node_weight_keys,
+ weight_sizes=weight_sizes,
+ )
+ try:
+ op_type_list = list(set([node.op for node in graph.nodes]))
+ local_use_cuda = sum(
+ [1 if op_type in CUDA_ONLY_OPS else 0 for op_type in op_type_list]
+ ) > 0
+ except Exception:
+ local_use_cuda = False
+ info('')
+ info(Color.REVERSE(f'Auto model partitioning enabled'), '=' * 44)
+ info(
+ Color.GREEN(f'Target partition size (estimated weights): ') +
+ f'{auto_split_max_size_mb} MB'
+ )
+ for idx, part in enumerate(partitions):
+ size_mb = partition_sizes[idx] / (1024 * 1024)
+ info(
+ f' part {idx+1}: nodes={part["node_count"]}, '
+ f'est_weights={size_mb:.2f} MB, '
+ f'inputs={len(part["inputs"])}, outputs={len(part["outputs"])}'
+ )
+ info(
+ f' inputs: {", ".join(part["inputs"]) if part["inputs"] else "(none)"}'
+ )
+ info(
+ f' outputs: {", ".join(part["outputs"]) if part["outputs"] else "(none)"}'
+ )
+
+ split_input_cache: Dict[str, str] = {}
+ split_input_dir = tempfile.mkdtemp(prefix='onnx2tf_split_')
+ split_tf_input_cache: Dict[str, np.ndarray] = {}
+ split_output_layouts: Dict[str, bool] = {}
+
+ def _sanitize_tf_input_name(name: str) -> str:
+ if name is None:
+ return ''
+ sanitized = name.replace(':', '__')
+ if output_signaturedefs or output_integer_quantized_tflite:
+ sanitized = re.sub('^/', 'wa/', sanitized)
+ return f'{sanitized}:0'
+
+ def _normalize_onnx_output_name(name: str) -> str:
+ if name is None:
+ return ''
+ normalized = name.replace(':', '__')
+ if output_signaturedefs or output_integer_quantized_tflite:
+ normalized = re.sub('^/', '', normalized)
+ return normalized
+
+ def _common_layout_perms(rank: int) -> List[tuple]:
+ if rank == 3:
+ return [(0, 1, 2), (0, 2, 1)]
+ if rank == 4:
+ return [(0, 1, 2, 3), (0, 2, 3, 1), (0, 3, 1, 2)]
+ if rank == 5:
+ return [(0, 1, 2, 3, 4), (0, 2, 3, 4, 1), (0, 4, 1, 2, 3)]
+ return [tuple(range(rank))]
+
+ def _build_onnx_tf_output_map(
+ *,
+ onnx_output_names: List[str],
+ tf_output_tensors: List[tf.Tensor],
+ onnx_output_values: Optional[Dict[str, np.ndarray]] = None,
+ tf_output_values: Optional[Dict[str, np.ndarray]] = None,
+ ) -> Dict[str, tf.Tensor]:
+ tf_by_base = {t.name.split(':')[0]: t for t in tf_output_tensors}
+ tf_by_full = {t.name: t for t in tf_output_tensors}
+ mapping: Dict[str, tf.Tensor] = {}
+ used_tf: set = set()
+ missing: List[str] = []
+ for onnx_name in onnx_output_names:
+ normalized = _normalize_onnx_output_name(onnx_name)
+ candidates = [
+ normalized,
+ f'wa/{normalized}',
+ ]
+ tf_tensor = None
+ for cand in candidates:
+ if cand in tf_by_base:
+ tf_tensor = tf_by_base[cand]
+ break
+ full_name = f'{cand}:0'
+ if full_name in tf_by_full:
+ tf_tensor = tf_by_full[full_name]
+ break
+ if tf_tensor is None:
+ if onnx_name in tf_by_base:
+ tf_tensor = tf_by_base[onnx_name]
+ else:
+ full_name = f'{onnx_name}:0'
+ if full_name in tf_by_full:
+ tf_tensor = tf_by_full[full_name]
+ if tf_tensor is not None:
+ mapping[onnx_name] = tf_tensor
+ used_tf.add(tf_tensor.name)
+ else:
+ missing.append(onnx_name)
+
+ if onnx_output_values and tf_output_values and missing:
+ tf_candidates = []
+ for tf_tensor in tf_output_tensors:
+ tf_val = tf_output_values.get(tf_tensor.name)
+ if tf_val is None:
+ tf_val = tf_output_values.get(tf_tensor.name.split(':')[0])
+ if tf_val is not None:
+ tf_candidates.append((tf_tensor, tf_val))
+ still_missing = []
+ for onnx_name in list(missing):
+ onnx_val = onnx_output_values.get(onnx_name)
+ if onnx_val is None:
+ still_missing.append(onnx_name)
+ continue
+ best = None
+ best_err = None
+ for tf_tensor, tf_val in tf_candidates:
+ if tf_tensor.name in used_tf:
+ continue
+ if tf_val.shape != onnx_val.shape:
+ continue
+ err = np.max(np.abs(onnx_val - tf_val))
+ if best is None or err < best_err:
+ best = tf_tensor
+ best_err = err
+ if best is None:
+ for tf_tensor, tf_val in tf_candidates:
+ if tf_tensor.name in used_tf:
+ continue
+ if tf_val.ndim != onnx_val.ndim:
+ continue
+ for perm in _common_layout_perms(tf_val.ndim):
+ if tf_val.transpose(perm).shape != onnx_val.shape:
+ continue
+ err = np.max(np.abs(onnx_val - tf_val.transpose(perm)))
+ if best is None or err < best_err:
+ best = tf_tensor
+ best_err = err
+ if best is not None and best_err is not None and best_err <= 1e-3:
+ mapping[onnx_name] = best
+ used_tf.add(best.name)
+ else:
+ still_missing.append(onnx_name)
+ missing = still_missing
+ if missing:
+ warn(
+ 'Auto split output mapping failed for: ' +
+ ', '.join(missing) +
+ '. Output cache/layout may be incomplete.'
+ )
+ return mapping
+
+ def _onnx_output_shape_map(onnx_model: onnx.ModelProto) -> Dict[str, List[Optional[int]]]:
+ shape_map: Dict[str, List[Optional[int]]] = {}
+ try:
+ for out in onnx_model.graph.output:
+ dims: List[Optional[int]] = []
+ t = out.type.tensor_type
+ if t.HasField('shape'):
+ for d in t.shape.dim:
+ if d.dim_value > 0:
+ dims.append(int(d.dim_value))
+ elif d.dim_param:
+ dims.append(None)
+ else:
+ dims.append(None)
+ if dims:
+ shape_map[out.name] = dims
+ except Exception:
+ pass
+ return shape_map
+
+ def _infer_keep_shape(onnx_shape: List[Optional[int]], tf_shape: List[int]) -> Optional[bool]:
+ if not onnx_shape or any(d is None for d in onnx_shape):
+ return None
+ if list(onnx_shape) == list(tf_shape):
+ return True
+ rank = len(onnx_shape)
+ if rank == 3:
+ if list(tf_shape) == [onnx_shape[0], onnx_shape[2], onnx_shape[1]]:
+ return False
+ elif rank == 4:
+ if list(tf_shape) == [onnx_shape[0], onnx_shape[2], onnx_shape[3], onnx_shape[1]]:
+ return False
+ elif rank == 5:
+ if list(tf_shape) == [onnx_shape[0], onnx_shape[2], onnx_shape[3], onnx_shape[4], onnx_shape[1]]:
+ return False
+ return None
+
+ def _merge_custom_inputs(user_inputs, auto_inputs):
+ merged = []
+ seen = set()
+ if user_inputs:
+ for item in user_inputs:
+ if len(item) >= 2:
+ merged.append(item)
+ seen.add(item[0])
+ for item in auto_inputs:
+ if len(item) >= 2 and item[0] not in seen:
+ merged.append(item)
+ seen.add(item[0])
+ return merged
+
+ base_kwargs = {
+ 'input_onnx_file_path': input_onnx_file_path if input_onnx_file_path is not None else None,
+ 'onnx_graph': onnx_graph,
+ 'output_folder_path': output_folder_path,
+ 'output_signaturedefs': output_signaturedefs,
+ 'output_h5': output_h5,
+ 'output_keras_v3': output_keras_v3,
+ 'output_tfv1_pb': output_tfv1_pb,
+ 'output_weights': output_weights,
+ 'copy_onnx_input_output_names_to_tflite': copy_onnx_input_output_names_to_tflite,
+ 'output_dynamic_range_quantized_tflite': output_dynamic_range_quantized_tflite,
+ 'output_integer_quantized_tflite': output_integer_quantized_tflite,
+ 'quant_norm_mean': quant_norm_mean,
+ 'quant_norm_std': quant_norm_std,
+ 'quant_type': quant_type,
+ 'custom_input_op_name_np_data_path': custom_input_op_name_np_data_path,
+ 'tf_input_cache': split_tf_input_cache,
+ 'input_quant_dtype': input_quant_dtype,
+ 'output_quant_dtype': output_quant_dtype,
+ 'not_use_onnxsim': not_use_onnxsim,
+ 'not_use_opname_auto_generate': not_use_opname_auto_generate,
+ 'batch_size': batch_size,
+ 'overwrite_input_shape': overwrite_input_shape,
+ 'shape_hints': shape_hints,
+ 'value_hints': value_hints,
+ 'no_large_tensor': no_large_tensor,
+ 'output_nms_with_dynamic_tensor': output_nms_with_dynamic_tensor,
+ 'switch_nms_version': switch_nms_version,
+ 'keep_ncw_or_nchw_or_ncdhw_input_names': keep_ncw_or_nchw_or_ncdhw_input_names,
+ 'keep_nwc_or_nhwc_or_ndhwc_input_names': keep_nwc_or_nhwc_or_ndhwc_input_names,
+ 'keep_shape_absolutely_input_names': keep_shape_absolutely_input_names,
+ 'input_names_to_interrupt_model_conversion': None,
+ 'output_names_to_interrupt_model_conversion': None,
+ 'disable_group_convolution': disable_group_convolution,
+ 'enable_accumulation_type_float16': enable_accumulation_type_float16,
+ 'enable_batchmatmul_unfold': enable_batchmatmul_unfold,
+ 'enable_rnn_unroll': enable_rnn_unroll,
+ 'disable_suppression_flextranspose': disable_suppression_flextranspose,
+ 'disable_strict_mode': disable_strict_mode,
+ 'onnxruntime_output_memmap': onnxruntime_output_memmap,
+ 'onnxruntime_output_memmap_dir': onnxruntime_output_memmap_dir,
+ 'number_of_dimensions_after_flextranspose_compression': number_of_dimensions_after_flextranspose_compression,
+ 'disable_suppression_flexstridedslice': disable_suppression_flexstridedslice,
+ 'number_of_dimensions_after_flexstridedslice_compression': number_of_dimensions_after_flexstridedslice_compression,
+ 'optimization_for_gpu_delegate': optimization_for_gpu_delegate,
+ 'replace_argmax_to_reducemax_and_indices_is_int64': replace_argmax_to_reducemax_and_indices_is_int64,
+ 'replace_argmax_to_reducemax_and_indices_is_float32': replace_argmax_to_reducemax_and_indices_is_float32,
+ 'replace_argmax_to_fused_argmax_and_indices_is_int64': replace_argmax_to_fused_argmax_and_indices_is_int64,
+ 'replace_argmax_to_fused_argmax_and_indices_is_float32': replace_argmax_to_fused_argmax_and_indices_is_float32,
+ 'fused_argmax_scale_ratio': fused_argmax_scale_ratio,
+ 'replace_to_pseudo_operators': replace_to_pseudo_operators,
+ 'param_replacement_file': param_replacement_file,
+ 'auto_generate_json': auto_generate_json,
+ 'auto_generate_json_on_error': auto_generate_json_on_error,
+ 'enable_auto_split_model': False,
+ 'auto_split_max_size_mb': auto_split_max_size_mb,
+ 'check_gpu_delegate_compatibility': check_gpu_delegate_compatibility,
+ 'check_onnx_tf_outputs_elementwise_close': check_onnx_tf_outputs_elementwise_close,
+ 'check_onnx_tf_outputs_elementwise_close_full': check_onnx_tf_outputs_elementwise_close_full,
+ 'check_onnx_tf_outputs_sample_data_normalization': check_onnx_tf_outputs_sample_data_normalization,
+ 'check_onnx_tf_outputs_elementwise_close_rtol': check_onnx_tf_outputs_elementwise_close_rtol,
+ 'check_onnx_tf_outputs_elementwise_close_atol': check_onnx_tf_outputs_elementwise_close_atol,
+ 'mvn_epsilon': mvn_epsilon,
+ 'disable_model_save': disable_model_save,
+ 'non_verbose': non_verbose,
+ 'verbosity': verbosity,
+ }
+ base_kwargs['input_names_to_interrupt_model_conversion'] = None
+ base_kwargs['output_names_to_interrupt_model_conversion'] = None
+
+ model_ret = None
+ try:
+ for idx, part in enumerate(partitions):
+ part_output_values: Optional[Dict[str, np.ndarray]] = None
+ part_output_folder = os.path.join(
+ output_folder_path,
+ f'part_{idx+1:04d}',
+ )
+ base_name = os.path.splitext(os.path.basename(input_onnx_file_path))[0] \
+ if input_onnx_file_path else 'model'
+ os.makedirs(part_output_folder, exist_ok=True)
+ split_onnx_path = os.path.join(
+ part_output_folder,
+ f'{base_name}_part_{idx+1:04d}.onnx'
+ )
+ part_graph = sne4onnx.extraction(
+ input_op_names=part['inputs'],
+ output_op_names=part['outputs'],
+ onnx_graph=onnx_graph_for_split,
+ output_onnx_file_path=split_onnx_path,
+ has_external_data=has_external_data,
+ )
+ auto_custom_inputs = []
+ if split_input_cache:
+ for input_name in part['inputs']:
+ if input_name in split_input_cache:
+ auto_custom_inputs.append([
+ input_name,
+ split_input_cache[input_name],
+ ])
+ merged_custom_inputs = _merge_custom_inputs(
+ custom_input_op_name_np_data_path,
+ auto_custom_inputs,
+ )
+ # For the first partition, keep the same behavior as non-split conversion.
+ # Only user-provided custom inputs are used.
+ if idx == 0 and not auto_custom_inputs:
+ custom_inputs_for_part = merged_custom_inputs
+ else:
+ require_mean_std = bool(output_integer_quantized_tflite)
+ custom_inputs_for_part = _complete_custom_inputs_for_graph(
+ onnx_graph=part_graph,
+ custom_inputs=merged_custom_inputs,
+ output_dir=split_input_dir,
+ file_prefix=f'part_{idx+1:04d}',
+ shape_hints=shape_hints,
+ require_mean_std=require_mean_std,
+ )
+
+ # Propagate dummy outputs to next partitions
+ try:
+ has_inputs = len(part_graph.graph.input) > 0
+ has_outputs = len(part_graph.graph.output) > 0
+ if has_inputs and has_outputs:
+ part_input_datas = {}
+ part_outputs = dummy_onnx_inference(
+ onnx_graph=part_graph,
+ output_names=part['outputs'],
+ test_data_nhwc=None,
+ custom_input_op_name_np_data_path=custom_inputs_for_part,
+ tf_layers_dict={},
+ use_cuda=local_use_cuda,
+ disable_strict_mode=disable_strict_mode,
+ enable_ort_output_memmap=False,
+ ort_output_memmap_dir=None,
+ shape_hints=shape_hints,
+ input_datas_for_validation=part_input_datas,
+ )
+ for input_name, input_value in part_input_datas.items():
+ file_name = (
+ f'part_{idx+1:04d}_' +
+ f'{_sanitize_split_input_name(input_name)}.npy'
+ )
+ file_path = os.path.join(split_input_dir, file_name)
+ split_input_cache[input_name] = _write_memmap_array(
+ file_path,
+ input_value,
+ )
+ part_output_values = {
+ name: value for name, value in zip(part['outputs'], part_outputs)
+ }
+ for output_name, output_value in zip(part['outputs'], part_outputs):
+ file_name = (
+ f'part_{idx+1:04d}_' +
+ f'{_sanitize_split_input_name(output_name)}.npy'
+ )
+ file_path = os.path.join(split_input_dir, file_name)
+ split_input_cache[output_name] = _write_memmap_array(
+ file_path,
+ output_value,
+ )
+ else:
+ warn(
+ 'Auto split input propagation skipped for this partition '
+ 'because it has no inputs or outputs.'
+ )
+ except Exception as ex:
+ warn(
+ 'Auto split input propagation failed for this partition. '
+ 'Subsequent partitions may use default dummy inputs.'
+ )
+ warn(f'{ex}')
+
+ part_kwargs = dict(base_kwargs)
+ if split_output_layouts:
+ part_keep_shape_abs = set(keep_shape_absolutely_input_names or [])
+ for input_name in part['inputs']:
+ if split_output_layouts.get(input_name, False):
+ part_keep_shape_abs.add(input_name)
+ part_kwargs['keep_shape_absolutely_input_names'] = \
+ list(part_keep_shape_abs) if part_keep_shape_abs else None
+ part_kwargs['input_onnx_file_path'] = split_onnx_path
+ part_kwargs['onnx_graph'] = part_graph
+ part_kwargs['output_folder_path'] = part_output_folder
+ if custom_inputs_for_part:
+ part_kwargs['custom_input_op_name_np_data_path'] = custom_inputs_for_part
+ model_ret = convert(**part_kwargs)
+
+ if hasattr(model_ret, 'onnx_output_layouts') \
+ and isinstance(model_ret.onnx_output_layouts, dict):
+ for out_name in part['outputs']:
+ if out_name in model_ret.onnx_output_layouts:
+ split_output_layouts[out_name] = \
+ bool(model_ret.onnx_output_layouts[out_name])
+
+ # Cache TF outputs for the next partition's TF dummy inference.
+ try:
+ tf_outputs = dummy_tf_inference(
+ model=model_ret,
+ inputs=model_ret.inputs,
+ test_data_nhwc=None,
+ custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
+ prefilled_input_datas=split_tf_input_cache,
+ shape_hints=shape_hints,
+ keep_shape_absolutely_input_names=keep_shape_absolutely_input_names,
+ keep_ncw_or_nchw_or_ncdhw_input_names=keep_ncw_or_nchw_or_ncdhw_input_names,
+ keep_nwc_or_nhwc_or_ndhwc_input_names=keep_nwc_or_nhwc_or_ndhwc_input_names,
+ )
+ onnx_output_shapes = _onnx_output_shape_map(part_graph)
+ if model_ret.outputs and part['outputs']:
+ tf_output_map = _build_onnx_tf_output_map(
+ onnx_output_names=part['outputs'],
+ tf_output_tensors=model_ret.outputs,
+ onnx_output_values=part_output_values,
+ tf_output_values=tf_outputs,
+ )
+ for onnx_out, tf_tensor in tf_output_map.items():
+ tf_val = tf_outputs.get(tf_tensor.name)
+ if tf_val is None:
+ continue
+ # Store both full and base TF names to maximize cache hits.
+ split_tf_input_cache[tf_tensor.name] = tf_val
+ split_tf_input_cache[tf_tensor.name.split(':')[0]] = tf_val
+ # Keep legacy key for compatibility with existing lookups.
+ sanitized = _sanitize_tf_input_name(onnx_out)
+ split_tf_input_cache[sanitized] = tf_val
+ split_tf_input_cache[sanitized.split(':')[0]] = tf_val
+ keep_shape = _infer_keep_shape(
+ onnx_output_shapes.get(onnx_out),
+ list(tf_val.shape),
+ )
+ if keep_shape is not None:
+ split_output_layouts[onnx_out] = keep_shape
+ except Exception:
+ pass
+ finally:
+ shutil.rmtree(split_input_dir, ignore_errors=True)
+ return model_ret
+
  # Cut the ONNX graph when an input name is specified that interrupts the conversion
  if not input_names_to_interrupt_model_conversion:
  input_names = [
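
To make the layout bookkeeping in the hunk above concrete, here is how _infer_keep_shape classifies a rank-4 output (illustrative values only, not part of the diff):

_infer_keep_shape([1, 3, 224, 224], [1, 224, 224, 3])     # NCHW -> NHWC transpose detected: False
_infer_keep_shape([1, 3, 224, 224], [1, 3, 224, 224])     # shapes identical, layout preserved: True
_infer_keep_shape([None, 3, 224, 224], [1, 224, 224, 3])  # dynamic dimension, undecidable: None
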
@@ -1218,6 +2144,7 @@ def convert(
  'relu_relu6_merge_op_names': {},
  'mul_div_replace_op_names': {},
  'use_cuda': use_cuda,
+ 'tf_input_cache': tf_input_cache,
  }
 
  tf_layers_dict = {}
@@ -1291,30 +2218,29 @@ def convert(
  )
 
  # download test data
- all_four_dim = sum(
- [
- 1 for input in inputs \
- if len(input.shape) == 4 \
- and input.shape[0] is not None \
- and input.shape[0] <= 20 \
- and input.shape[-1] == 3 \
- and input.shape[1] is not None \
- and input.shape[2] is not None
- ]
- ) == len(inputs)
- same_batch_dim = False
- if all_four_dim:
- batch_size = inputs[0].shape[0]
- for input in inputs:
- same_batch_dim = batch_size == input.shape[0]
  test_data_nhwc = None
- if all_four_dim and same_batch_dim:
- test_data: np.ndarray = download_test_image_data()
- test_data_nhwc = test_data[:inputs[0].shape[0], ...]
- if check_onnx_tf_outputs_sample_data_normalization == "norm":
- pass
- elif check_onnx_tf_outputs_sample_data_normalization == "denorm":
- test_data_nhwc = test_data_nhwc * 255.0
+ if inputs:
+ all_four_dim = sum(
+ [
+ 1 for input in inputs \
+ if len(input.shape) == 4 \
+ and input.shape[0] is not None \
+ and input.shape[0] <= 20 \
+ and input.shape[-1] == 3 \
+ and input.shape[1] is not None \
+ and input.shape[2] is not None
+ ]
+ ) == len(inputs)
+ same_batch_dim = False
+ if all_four_dim:
+ batch_size = inputs[0].shape[0]
+ for input in inputs:
+ same_batch_dim = batch_size == input.shape[0]
+ if all_four_dim and same_batch_dim:
+ test_data: np.ndarray = download_test_image_data()
+ test_data_nhwc = test_data[:inputs[0].shape[0], ...]
+ if check_onnx_tf_outputs_sample_data_normalization == "denorm":
+ test_data_nhwc = test_data_nhwc * 255.0
 
  # ONNX dummy inference
  # Generate output for all OPs.
@@ -1400,7 +2326,10 @@ def convert(
  exported_onnx_graph = gs.export_onnx(graph, do_type_check=False, **meta_data)
  if metadata_props is not None:
  exported_onnx_graph.metadata_props.extend(metadata_props)
- estimated_graph = onnx.shape_inference.infer_shapes(exported_onnx_graph)
+ if not has_external_data:
+ estimated_graph = onnx.shape_inference.infer_shapes(exported_onnx_graph)
+ else:
+ estimated_graph = exported_onnx_graph
  if input_onnx_file_path is not None:
  onnx.save(estimated_graph, input_onnx_file_path)
  if not not_use_onnxsim:
@@ -1580,6 +2509,14 @@ def convert(
  outputs[oidx] = tf_keras.layers.Lambda(lambda x: tf.constant(y))(x)
 
  model = tf_keras.Model(inputs=inputs, outputs=outputs)
+ try:
+ onnx_output_layouts = {
+ name: tf_layers_dict.get(name, {}).get('nhwc', False)
+ for name in onnx_graph_output_names
+ }
+ model.onnx_output_layouts = onnx_output_layouts
+ except Exception:
+ pass
  debug('')
 
  # The process ends normally without saving the model.
@@ -2953,6 +3890,18 @@ def main():
  'Only applied to dynamic dimensions in inputs. \n' +
  'Only used when -cotof or -coto are specified.'
  )
+ parser.add_argument(
+ '-vh',
+ '--value_hints',
+ type=str,
+ nargs='+',
+ help=\
+ 'Value hints for dummy inference input tensors. \n' +
+ 'The format is\n' +
+ '"input_name_1:value" "input_name_2:value" "*:default_value". \n' +
+ '"*" applies to all inputs not explicitly specified. \n' +
+ 'Values are scalar only.'
+ )
  parser.add_argument(
  '-nlt',
  '--no_large_tensor',
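
A command-line counterpart of the value_hints example given earlier (assuming the usual -i input flag) would be: onnx2tf -i model.onnx -vh "mask:0" "*:1.0".
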
@@ -3050,6 +3999,23 @@ def main():
  'e.g. \n' +
  '--output_names_to_interrupt_model_conversion "output0" "output1" "output2"'
  )
+ parser.add_argument(
+ '-easm',
+ '--enable_auto_split_model',
+ action='store_true',
+ help=\
+ 'Force auto split regardless of the ONNX file size. \n' +
+ 'Uses --auto_split_max_size_mb as the target partition size.'
+ )
+ parser.add_argument(
+ '-asmsm',
+ '--auto_split_max_size_mb',
+ type=int,
+ default=1024,
+ help=\
+ 'Target maximum size per partition in MB based on ONNX initializer sizes. \n' +
+ 'Used when auto-split is triggered or forced.'
+ )
  parser.add_argument(
  '-dgc',
  '--disable_group_convolution',
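
Likewise, auto split can be forced from the CLI (again assuming the usual -i input flag): onnx2tf -i big_model.onnx -easm -asmsm 512.
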
@@ -3421,6 +4387,7 @@ def main():
  batch_size=args.batch_size,
  overwrite_input_shape=args.overwrite_input_shape,
  shape_hints=args.shape_hints,
+ value_hints=args.value_hints,
  no_large_tensor=args.no_large_tensor,
  output_nms_with_dynamic_tensor=args.output_nms_with_dynamic_tensor,
  switch_nms_version=args.switch_nms_version,
@@ -3450,6 +4417,8 @@ def main():
  param_replacement_file=args.param_replacement_file,
  auto_generate_json=args.auto_generate_json,
  auto_generate_json_on_error=args.auto_generate_json_on_error,
+ enable_auto_split_model=args.enable_auto_split_model,
+ auto_split_max_size_mb=args.auto_split_max_size_mb,
  check_gpu_delegate_compatibility=args.check_gpu_delegate_compatibility,
  check_onnx_tf_outputs_elementwise_close=args.check_onnx_tf_outputs_elementwise_close,
  check_onnx_tf_outputs_elementwise_close_full=args.check_onnx_tf_outputs_elementwise_close_full,