onnx2tf 1.29.18__py3-none-any.whl → 1.29.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2tf/onnx2tf.py CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  import os
4
4
  import re
5
+ import shutil
6
+ import tempfile
5
7
  __path__ = (os.path.dirname(__file__), )
6
8
  with open(os.path.join(__path__[0], '__init__.py')) as f:
7
9
  init_text = f.read()
@@ -51,6 +53,7 @@ from onnx2tf.utils.common_functions import (
51
53
  get_tf_model_outputs,
52
54
  rewrite_tflite_inout_opname,
53
55
  check_cuda_enabled,
56
+ check_has_external_data,
54
57
  )
55
58
  from onnx2tf.utils.json_auto_generator import (
56
59
  generate_auto_replacement_json,
@@ -62,6 +65,349 @@ from onnx2tf.utils.enums import (
62
65
  from onnx2tf.utils.logging import *
63
66
  from sng4onnx import generate as op_name_auto_generate
64
67
 
68
+ def _sanitize_split_input_name(name: str) -> str:
69
+ if not name:
70
+ return 'tensor'
71
+ return re.sub(r'[^0-9A-Za-z._-]+', '_', name)
72
+
73
+ def _write_memmap_array(path: str, array: np.ndarray) -> str:
74
+ mm = np.lib.format.open_memmap(
75
+ path,
76
+ mode='w+',
77
+ dtype=array.dtype,
78
+ shape=array.shape,
79
+ )
80
+ mm[...] = array
81
+ mm.flush()
82
+ return path
83
+
84
+
85
def _tensorproto_nbytes(tensor: onnx.TensorProto) -> int:
    """Estimate the payload size of an ONNX TensorProto in bytes.

    Returns 0 for None tensors and for tensors whose dtype cannot be
    mapped to a NumPy dtype.
    """
    if tensor is None:
        return 0
    # raw_data, when present, already holds the exact serialized bytes.
    if tensor.HasField('raw_data'):
        return len(tensor.raw_data)
    try:
        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor.data_type)
    except Exception:
        np_dtype = None
    if np_dtype is None:
        return 0
    elem_size = np.dtype(np_dtype).itemsize
    num_elems = int(np.prod(tensor.dims)) if len(tensor.dims) > 0 else 0
    if num_elems == 0:
        # No dims recorded: fall back to counting entries in the typed
        # repeated field (float_data, int32_data, ...).
        try:
            field_name = onnx.helper.tensor_dtype_to_field(tensor.data_type)
            if hasattr(tensor, field_name):
                num_elems = len(getattr(tensor, field_name))
        except Exception:
            num_elems = 0
    return num_elems * elem_size
106
+
107
def _collect_initializer_sizes(onnx_graph: onnx.ModelProto) -> Dict[str, int]:
    """Map each named graph initializer to its estimated byte size."""
    sizes: Dict[str, int] = {}
    if onnx_graph is None:
        return sizes
    for init in onnx_graph.graph.initializer:
        if not init.name:
            continue
        try:
            sizes[init.name] = _tensorproto_nbytes(init)
        except Exception:
            # Best-effort sizing: an unreadable initializer counts as 0 bytes.
            sizes[init.name] = 0
    return sizes
119
+
120
def _collect_node_weight_keys(
    *,
    graph: gs.Graph,
    initializer_sizes: Dict[str, int],
) -> tuple[List[List[str]], Dict[str, int]]:
    """List, per node, the weight keys that node consumes.

    A key is either an ONNX initializer name or a synthetic 'const:<id>'
    key for inline gs.Constant values.  Returns the per-node key lists
    together with a size table covering both kinds of weights.
    """
    weight_sizes = dict(initializer_sizes)
    node_weight_keys: List[List[str]] = []
    for node in graph.nodes:
        keys: List[str] = []
        for inp in node.inputs:
            if isinstance(inp, gs.Constant):
                values = getattr(inp, 'values', None)
                if isinstance(values, np.ndarray):
                    # Inline constants have no stable name; key by object id.
                    key = f'const:{id(inp)}'
                    weight_sizes.setdefault(key, int(values.nbytes))
                    keys.append(key)
                continue
            name = getattr(inp, 'name', '')
            if name and name in initializer_sizes:
                keys.append(name)
        node_weight_keys.append(keys)
    return node_weight_keys, weight_sizes
142
+
143
+ def _auto_partition_ranges(
144
+ *,
145
+ node_weight_keys: List[List[str]],
146
+ weight_sizes: Dict[str, int],
147
+ max_size_bytes: int,
148
+ reachable_node_indices: Optional[set] = None,
149
+ ) -> List[tuple]:
150
+ ranges: List[tuple] = []
151
+ if max_size_bytes <= 0 or not node_weight_keys:
152
+ return ranges
153
+ current_keys: set = set()
154
+ current_bytes = 0
155
+ start_idx = 0
156
+ for idx, keys in enumerate(node_weight_keys):
157
+ new_bytes = 0
158
+ for key in keys:
159
+ if key not in current_keys:
160
+ new_bytes += weight_sizes.get(key, 0)
161
+ current_keys.add(key)
162
+ current_bytes += new_bytes
163
+ if current_bytes >= max_size_bytes and idx > start_idx:
164
+ if reachable_node_indices is not None and idx not in reachable_node_indices:
165
+ continue
166
+ ranges.append((start_idx, idx))
167
+ start_idx = idx + 1
168
+ current_keys = set()
169
+ current_bytes = 0
170
+ if start_idx <= len(node_weight_keys) - 1:
171
+ ranges.append((start_idx, len(node_weight_keys) - 1))
172
+ return ranges
173
+
174
def _collect_reachable_node_indices(
    graph: gs.Graph,
    initializer_names: Optional[set] = None,
) -> set:
    """Indices of nodes reachable from the graph's true (non-initializer) inputs.

    Walks nodes in graph order, propagating reachability through variable
    tensors: a node is reachable when any of its variable inputs is a graph
    input or was produced by a reachable node.
    """
    initializer_names = initializer_names or set()
    reachable_nodes: set = set()
    reachable_vars: set = set()
    for graph_input in graph.inputs:
        input_name = getattr(graph_input, 'name', '')
        if input_name and input_name not in initializer_names:
            reachable_vars.add(input_name)
    for node_idx, node in enumerate(graph.nodes):
        reached = False
        for inp in node.inputs:
            if not isinstance(inp, gs.Variable):
                continue
            var_name = getattr(inp, 'name', '')
            if var_name in reachable_vars and var_name not in initializer_names:
                reached = True
                break
        if not reached:
            continue
        reachable_nodes.add(node_idx)
        # Outputs of a reachable node seed reachability for later nodes.
        for out in node.outputs:
            out_name = getattr(out, 'name', '')
            if out_name:
                reachable_vars.add(out_name)
    return reachable_nodes
200
+
201
def _collect_constant_only_node_indices(
    graph: gs.Graph,
    initializer_names: Optional[set] = None,
) -> set:
    """Indices of nodes fed exclusively by constants and/or initializers."""
    initializer_names = initializer_names or set()
    const_only_nodes: set = set()
    for node_idx, node in enumerate(graph.nodes):
        # A node is constant-only when no input is a named, non-initializer
        # variable tensor (gs.Constant inputs never count as variables).
        has_variable = any(
            (not isinstance(inp, gs.Constant))
            and getattr(inp, 'name', '')
            and getattr(inp, 'name', '') not in initializer_names
            for inp in node.inputs
        )
        if not has_variable:
            const_only_nodes.add(node_idx)
    return const_only_nodes
219
+
220
def _complete_custom_inputs_for_graph(
    *,
    onnx_graph: onnx.ModelProto,
    custom_inputs: List[List[Any]],
    output_dir: str,
    file_prefix: str,
    shape_hints: Optional[List[str]] = None,
    require_mean_std: bool = False,
) -> List[List[Any]]:
    """Build a complete custom-input list for every input of *onnx_graph*.

    Inputs already covered by *custom_inputs* are passed through; every
    remaining graph input gets an auto-generated all-ones .npy file (written
    as a memmap under *output_dir*).  When *require_mean_std* is True, each
    entry is padded to [name, path, mean, std] with mean=0.0 / std=1.0.

    NOTE(review): dynamic dims are resolved to the first input's batch dim
    (when ranks match) or to 1 — presumably matching onnx2tf's dummy-input
    convention; confirm against the main conversion path.
    """
    gs_graph = gs.import_onnx(onnx_graph)
    input_names: List[str] = [inp.name for inp in gs_graph.inputs]
    input_sizes: List[List[Any]] = [inp.shape for inp in gs_graph.inputs]
    input_dtypes: List[Any] = [inp.dtype for inp in gs_graph.inputs]

    if shape_hints is None:
        # No hints: replace every dynamic (None / symbolic) dim with a
        # concrete value so memmap files can be allocated.
        new_input_sizes = []
        for input_size in input_sizes:
            new_input_size = []
            for idx, dim in enumerate(input_size):
                # Dynamic batch dim: reuse the first input's static batch
                # when both inputs have the same rank.
                if idx == 0 and input_sizes and input_sizes[0][0] is not None \
                    and not isinstance(input_sizes[0][0], str) \
                    and len(input_sizes[0]) == len(input_size) \
                    and (dim is None or isinstance(dim, str)):
                    new_input_size.append(input_sizes[0][0])
                elif dim is None or isinstance(dim, str):
                    new_input_size.append(1)
                else:
                    new_input_size.append(dim)
            new_input_sizes.append(new_input_size)
        input_sizes = new_input_sizes
    else:
        # Hints have the CLI form "name:d0,d1,...". Static dims from the
        # model win over hinted values; hints only fill dynamic dims.
        shape_hints_dict = {}
        for hint in shape_hints:
            parts = hint.split(':')
            if len(parts) == 2:
                input_name = parts[0]
                shape_values = [int(val) for val in parts[1].split(',')]
                shape_hints_dict[input_name] = shape_values
        for i, (input_name, original_shape) in enumerate(zip(input_names, input_sizes)):
            if input_name in shape_hints_dict:
                updated_shape = shape_hints_dict[input_name]
                for j, (orig_dim, hint_dim) in enumerate(zip(original_shape, updated_shape)):
                    if orig_dim is not None and not isinstance(orig_dim, str):
                        updated_shape[j] = orig_dim
                    else:
                        updated_shape[j] = hint_dim
                input_sizes[i] = updated_shape

    # Index user-provided entries by input name for O(1) lookup.
    custom_map = {}
    for item in custom_inputs or []:
        if len(item) >= 2:
            custom_map[item[0]] = item

    results: List[List[Any]] = []
    for input_name, input_size, input_dtype in zip(input_names, input_sizes, input_dtypes):
        if input_name in custom_map:
            item = list(custom_map[input_name])
            # Pad 2-element user entries with default mean/std for quantization.
            if require_mean_std and len(item) == 2:
                item = [item[0], item[1], 0.0, 1.0]
            results.append(item)
            continue
        # No user entry: synthesize an all-ones dummy input on disk.
        dtype = input_dtype if input_dtype is not None else np.float32
        file_name = f'{file_prefix}_{_sanitize_split_input_name(input_name)}.npy'
        file_path = os.path.join(output_dir, file_name)
        mm = np.lib.format.open_memmap(
            file_path,
            mode='w+',
            dtype=dtype,
            shape=tuple(input_size),
        )
        mm[...] = 1
        mm.flush()
        if require_mean_std:
            results.append([input_name, file_path, 0.0, 1.0])
        else:
            results.append([input_name, file_path])
    return results
297
+
298
+ def _estimate_partition_weight_bytes(
299
+ *,
300
+ ranges: List[tuple],
301
+ node_weight_keys: List[List[str]],
302
+ weight_sizes: Dict[str, int],
303
+ ) -> List[int]:
304
+ partition_sizes: List[int] = []
305
+ for start_idx, end_idx in ranges:
306
+ seen: set = set()
307
+ total_bytes = 0
308
+ for idx in range(start_idx, end_idx + 1):
309
+ for key in node_weight_keys[idx]:
310
+ if key not in seen:
311
+ total_bytes += weight_sizes.get(key, 0)
312
+ seen.add(key)
313
+ partition_sizes.append(total_bytes)
314
+ return partition_sizes
315
+
316
def _build_partition_io(
    *,
    graph: gs.Graph,
    ranges: List[tuple],
    const_only_nodes: Optional[set] = None,
) -> List[Dict[str, Any]]:
    """Compute boundary inputs/outputs for each inclusive node-index range.

    A tensor is a partition input when its producer lies outside the range
    (or it has no producer, i.e. it is a graph input), and a partition
    output when a consumer lies outside the range or it is a graph output.
    Tensors produced by constant-only nodes are excluded from both sides
    (except graph outputs), since those nodes are replicated per partition.
    """
    if not ranges:
        return []
    const_only_nodes = const_only_nodes or set()
    # Global producer/consumer indices over the whole graph, built once.
    producer_by_tensor: Dict[str, int] = {}
    consumers_by_tensor: Dict[str, set] = {}
    graph_output_names = [o.name for o in graph.outputs if o.name]
    for idx, node in enumerate(graph.nodes):
        for out in node.outputs:
            name = getattr(out, 'name', '')
            if name:
                producer_by_tensor[name] = idx
        for inp in node.inputs:
            if isinstance(inp, gs.Constant):
                continue
            name = getattr(inp, 'name', '')
            if not name:
                continue
            consumers_by_tensor.setdefault(name, set()).add(idx)

    partitions: List[Dict[str, Any]] = []
    for start_idx, end_idx in ranges:
        node_idx_set = set(range(start_idx, end_idx + 1))
        part_inputs: set = set()
        part_outputs: set = set()
        for idx in node_idx_set:
            node = graph.nodes[idx]
            for inp in node.inputs:
                if isinstance(inp, gs.Constant):
                    continue
                name = getattr(inp, 'name', '')
                if not name:
                    continue
                producer_idx = producer_by_tensor.get(name)
                if producer_idx is None or producer_idx not in node_idx_set:
                    # Constant-only producers are not real boundary inputs.
                    if producer_idx is not None and producer_idx in const_only_nodes:
                        continue
                    part_inputs.add(name)
            for out in node.outputs:
                name = getattr(out, 'name', '')
                if not name:
                    continue
                consumers = consumers_by_tensor.get(name, set())
                if name in graph_output_names or any(c not in node_idx_set for c in consumers):
                    # Constant-only nodes' outputs are only kept when they
                    # are actual graph outputs.
                    if idx in const_only_nodes and name not in graph_output_names:
                        continue
                    part_outputs.add(name)
        partitions.append({
            'inputs': sorted(part_inputs),
            'outputs': sorted(part_outputs),
            'node_count': end_idx - start_idx + 1,
            'start_idx': start_idx,
            'end_idx': end_idx,
        })
    return partitions
376
+
377
def _merge_ranges_with_missing_io(
    *,
    graph: gs.Graph,
    ranges: List[tuple],
    const_only_nodes: Optional[set] = None,
) -> tuple[List[tuple], List[Dict[str, Any]]]:
    """Merge partitions lacking inputs or outputs into a neighboring one.

    Repeats until every partition has at least one input and one output,
    or only a single partition remains.  Returns the adjusted ranges and
    their recomputed boundary I/O.
    """
    if not ranges:
        return ranges, []
    working = list(ranges)
    const_only_nodes = const_only_nodes or set()
    while True:
        partitions = _build_partition_io(
            graph=graph,
            ranges=working,
            const_only_nodes=const_only_nodes,
        ) or []
        defective = [
            i for i, part in enumerate(partitions)
            if not part['inputs'] or not part['outputs']
        ]
        if not defective or len(working) <= 1:
            return working, partitions
        i = defective[0]
        if i > 0:
            # Fold the defective partition into its predecessor.
            working[i - 1] = (working[i - 1][0], working[i][1])
            del working[i]
        else:
            # First partition is defective: fold it into its successor.
            working[0] = (working[0][0], working[1][1])
            del working[1]
410
+
65
411
  def fuse_expanded_qdq_to_qdq(
66
412
  *,
67
413
  graph: gs.Graph,
@@ -285,6 +631,7 @@ def convert(
285
631
  quant_norm_std: Optional[str] = '[[[[0.229, 0.224, 0.225]]]]',
286
632
  quant_type: Optional[str] = 'per-channel',
287
633
  custom_input_op_name_np_data_path: Optional[List] = None,
634
+ tf_input_cache: Optional[Dict[str, np.ndarray]] = None,
288
635
  input_quant_dtype: Optional[str] = 'int8',
289
636
  output_quant_dtype: Optional[str] = 'int8',
290
637
  not_use_onnxsim: Optional[bool] = False,
@@ -321,6 +668,8 @@ def convert(
321
668
  param_replacement_file: Optional[str] = '',
322
669
  auto_generate_json: Optional[bool] = False,
323
670
  auto_generate_json_on_error: Optional[bool] = False,
671
+ enable_auto_split_model: Optional[bool] = False,
672
+ auto_split_max_size_mb: Optional[int] = 1024,
324
673
  check_gpu_delegate_compatibility: Optional[bool] = False,
325
674
  check_onnx_tf_outputs_elementwise_close: Optional[bool] = False,
326
675
  check_onnx_tf_outputs_elementwise_close_full: Optional[bool] = False,
@@ -451,6 +800,10 @@ def convert(
451
800
  ["input2","input2.npy",[0.3],[0.07]],\n
452
801
  ]
453
802
 
803
+ tf_input_cache: Optional[Dict[str, np.ndarray]]
804
+ Cache of TF dummy inference inputs keyed by TF input tensor name.\n
805
+ Used to propagate TF outputs between auto-split partitions.\n
806
+
454
807
  input_quant_dtype: Optional[str]
455
808
  Input dtypes when doing Full INT8 Quantization.\n
456
809
  "int8"(default) or "uint8" or "float32"
@@ -682,6 +1035,15 @@ def convert(
682
1035
  This is now opt-in and requires explicitly enabling the feature.\n
683
1036
  Default: False
684
1037
 
1038
+ enable_auto_split_model: Optional[bool]
1039
+ Force auto split regardless of the ONNX file size.\n
1040
+ The target size is controlled by auto_split_max_size_mb.\n
1041
+ Default: False
1042
+
1043
+ auto_split_max_size_mb: Optional[int]
1044
+ Target maximum size per partition in MB based on ONNX initializer sizes.\n
1045
+ Default: 1024
1046
+
685
1047
  check_gpu_delegate_compatibility: Optional[bool]
686
1048
  Run TFLite ModelAnalyzer on the generated Float16 tflite model\n
687
1049
  to check if the model can be supported by GPU Delegate.
@@ -771,6 +1133,23 @@ def convert(
771
1133
  f'input_onnx_file_path: {input_onnx_file_path}'
772
1134
  )
773
1135
  sys.exit(1)
1136
+ auto_split_model = bool(enable_auto_split_model)
1137
+ if auto_split_model:
1138
+ info(
1139
+ Color.GREEN('Auto split forced by --enable_auto_split_model. ') +
1140
+ f'target={auto_split_max_size_mb} MB'
1141
+ )
1142
+ if onnx_graph is None and input_onnx_file_path and os.path.exists(input_onnx_file_path):
1143
+ try:
1144
+ onnx_file_size = os.path.getsize(input_onnx_file_path)
1145
+ if not auto_split_model and onnx_file_size > 2 * 1024 * 1024 * 1024:
1146
+ info(
1147
+ Color.GREEN('ONNX file exceeds 2GB; switching to auto-split mode. ') +
1148
+ f'size={onnx_file_size / (1024 * 1024 * 1024):.2f} GB'
1149
+ )
1150
+ auto_split_model = True
1151
+ except Exception:
1152
+ pass
774
1153
 
775
1154
  # Extracting onnx filenames
776
1155
  output_file_name = ''
@@ -917,9 +1296,8 @@ def convert(
917
1296
  exported_onnx_graph = gs.export_onnx(graph, do_type_check=False, **meta_data)
918
1297
  if metadata_props is not None:
919
1298
  exported_onnx_graph.metadata_props.extend(metadata_props)
920
- estimated_graph = onnx.shape_inference.infer_shapes(exported_onnx_graph)
921
- onnx.save(estimated_graph, f=input_onnx_file_path)
922
- del estimated_graph
1299
+ onnx.save(exported_onnx_graph, f=input_onnx_file_path)
1300
+ del exported_onnx_graph
923
1301
  except:
924
1302
  if tmp_graph is not None:
925
1303
  del tmp_graph
@@ -957,6 +1335,7 @@ def convert(
957
1335
  input_onnx_file_path=f'{input_onnx_file_path}',
958
1336
  onnx_graph=onnx_graph,
959
1337
  output_onnx_file_path=f'{input_onnx_file_path}',
1338
+ has_external_data=has_external_data,
960
1339
  non_verbose=True,
961
1340
  )
962
1341
  info(Color.GREEN(f'Automatic generation of each OP name complete!'))
@@ -978,9 +1357,24 @@ def convert(
978
1357
 
979
1358
  # Loading Graphs
980
1359
  # onnx_graph If specified, onnx_graph is processed first
1360
+ has_external_data = False
981
1361
  if not onnx_graph:
1362
+ has_external_data = check_has_external_data(input_onnx_file_path)
982
1363
  onnx_graph = onnx.load(input_onnx_file_path)
983
1364
 
1365
+ if not auto_split_model and onnx_graph is not None:
1366
+ try:
1367
+ initializer_sizes = _collect_initializer_sizes(onnx_graph)
1368
+ total_init_bytes = sum(initializer_sizes.values())
1369
+ if total_init_bytes > 2 * 1024 * 1024 * 1024:
1370
+ info(
1371
+ Color.GREEN('ONNX graph estimated initializer size exceeds 2GB; ') +
1372
+ f'switching to auto-split mode. size={total_init_bytes / (1024 * 1024 * 1024):.2f} GB'
1373
+ )
1374
+ auto_split_model = True
1375
+ except Exception:
1376
+ pass
1377
+
984
1378
  domain: str = onnx_graph.domain
985
1379
  ir_version: int = onnx_graph.ir_version
986
1380
  meta_data = {'domain': domain, 'ir_version': ir_version}
@@ -990,6 +1384,522 @@ def convert(
990
1384
  graph = gs.import_onnx(onnx_graph)
991
1385
  fuse_expanded_qdq_to_qdq(graph=graph)
992
1386
 
1387
+ # Auto split model by estimated weight size
1388
+ if auto_split_model:
1389
+ if input_names_to_interrupt_model_conversion or output_names_to_interrupt_model_conversion:
1390
+ error(
1391
+ 'Auto split cannot be used together with input_names_to_interrupt_model_conversion '
1392
+ 'or output_names_to_interrupt_model_conversion.'
1393
+ )
1394
+ sys.exit(1)
1395
+ if auto_split_max_size_mb is None or auto_split_max_size_mb <= 0:
1396
+ error(
1397
+ f'auto_split_max_size_mb must be greater than 0. auto_split_max_size_mb: {auto_split_max_size_mb}'
1398
+ )
1399
+ sys.exit(1)
1400
+ try:
1401
+ import sne4onnx
1402
+ except Exception:
1403
+ error(
1404
+ 'Auto split requires sne4onnx. pip install sne4onnx'
1405
+ )
1406
+ sys.exit(1)
1407
+ try:
1408
+ graph.toposort()
1409
+ except Exception:
1410
+ pass
1411
+
1412
+ onnx_graph_for_split = onnx_graph
1413
+ try:
1414
+ onnx_graph_for_split = gs.export_onnx(
1415
+ graph=graph,
1416
+ do_type_check=False,
1417
+ **meta_data,
1418
+ )
1419
+ if metadata_props is not None:
1420
+ onnx_graph_for_split.metadata_props.extend(metadata_props)
1421
+ except Exception:
1422
+ onnx_graph_for_split = onnx_graph
1423
+
1424
+ initializer_sizes = _collect_initializer_sizes(onnx_graph_for_split)
1425
+ node_weight_keys, weight_sizes = _collect_node_weight_keys(
1426
+ graph=graph,
1427
+ initializer_sizes=initializer_sizes,
1428
+ )
1429
+ const_only_nodes = _collect_constant_only_node_indices(
1430
+ graph,
1431
+ initializer_names=set(initializer_sizes.keys()),
1432
+ )
1433
+ reachable_node_indices = _collect_reachable_node_indices(
1434
+ graph,
1435
+ initializer_names=set(initializer_sizes.keys()),
1436
+ )
1437
+ max_size_bytes = int(auto_split_max_size_mb) * 1024 * 1024
1438
+ ranges = _auto_partition_ranges(
1439
+ node_weight_keys=node_weight_keys,
1440
+ weight_sizes=weight_sizes,
1441
+ max_size_bytes=max_size_bytes,
1442
+ reachable_node_indices=reachable_node_indices,
1443
+ )
1444
+ if len(ranges) > 1:
1445
+ ranges, partitions = _merge_ranges_with_missing_io(
1446
+ graph=graph,
1447
+ ranges=ranges,
1448
+ const_only_nodes=const_only_nodes,
1449
+ )
1450
+ if not partitions:
1451
+ warn(
1452
+ 'Auto split failed to determine partition boundaries. Proceeding without split.'
1453
+ )
1454
+ else:
1455
+ if any([not p['inputs'] or not p['outputs'] for p in partitions]):
1456
+ warn(
1457
+ 'Auto split produced partitions with missing inputs or outputs. '
1458
+ 'Some partitions may not be inferable.'
1459
+ )
1460
+ partition_sizes = _estimate_partition_weight_bytes(
1461
+ ranges=ranges,
1462
+ node_weight_keys=node_weight_keys,
1463
+ weight_sizes=weight_sizes,
1464
+ )
1465
+ try:
1466
+ op_type_list = list(set([node.op for node in graph.nodes]))
1467
+ local_use_cuda = sum(
1468
+ [1 if op_type in CUDA_ONLY_OPS else 0 for op_type in op_type_list]
1469
+ ) > 0
1470
+ except Exception:
1471
+ local_use_cuda = False
1472
+ info('')
1473
+ info(Color.REVERSE(f'Auto model partitioning enabled'), '=' * 44)
1474
+ info(
1475
+ Color.GREEN(f'Target partition size (estimated weights): ') +
1476
+ f'{auto_split_max_size_mb} MB'
1477
+ )
1478
+ for idx, part in enumerate(partitions):
1479
+ size_mb = partition_sizes[idx] / (1024 * 1024)
1480
+ info(
1481
+ f' part {idx+1}: nodes={part["node_count"]}, '
1482
+ f'est_weights={size_mb:.2f} MB, '
1483
+ f'inputs={len(part["inputs"])}, outputs={len(part["outputs"])}'
1484
+ )
1485
+ info(
1486
+ f' inputs: {", ".join(part["inputs"]) if part["inputs"] else "(none)"}'
1487
+ )
1488
+ info(
1489
+ f' outputs: {", ".join(part["outputs"]) if part["outputs"] else "(none)"}'
1490
+ )
1491
+
1492
+ split_input_cache: Dict[str, str] = {}
1493
+ split_input_dir = tempfile.mkdtemp(prefix='onnx2tf_split_')
1494
+ split_tf_input_cache: Dict[str, np.ndarray] = {}
1495
+ split_output_layouts: Dict[str, bool] = {}
1496
+
1497
def _sanitize_tf_input_name(name: str) -> str:
    """Map an ONNX tensor name to its TF tensor-name ('<name>:0') form."""
    if name is None:
        return ''
    tf_name = name.replace(':', '__')
    # signaturedef / quantized builds rename a leading '/' to a 'wa/' prefix.
    if output_signaturedefs or output_integer_quantized_tflite:
        tf_name = re.sub('^/', 'wa/', tf_name)
    return f'{tf_name}:0'
1504
+
1505
def _normalize_onnx_output_name(name: str) -> str:
    """Normalize an ONNX output name to its TF-side base-name form."""
    if name is None:
        return ''
    base_name = name.replace(':', '__')
    # signaturedef / quantized builds strip a leading '/' from output names.
    if output_signaturedefs or output_integer_quantized_tflite:
        base_name = re.sub('^/', '', base_name)
    return base_name
1512
+
1513
+ def _common_layout_perms(rank: int) -> List[tuple]:
1514
+ if rank == 3:
1515
+ return [(0, 1, 2), (0, 2, 1)]
1516
+ if rank == 4:
1517
+ return [(0, 1, 2, 3), (0, 2, 3, 1), (0, 3, 1, 2)]
1518
+ if rank == 5:
1519
+ return [(0, 1, 2, 3, 4), (0, 2, 3, 4, 1), (0, 4, 1, 2, 3)]
1520
+ return [tuple(range(rank))]
1521
+
1522
def _build_onnx_tf_output_map(
    *,
    onnx_output_names: List[str],
    tf_output_tensors: List[tf.Tensor],
    onnx_output_values: Optional[Dict[str, np.ndarray]] = None,
    tf_output_values: Optional[Dict[str, np.ndarray]] = None,
) -> Dict[str, tf.Tensor]:
    """Match each ONNX output name to the corresponding TF output tensor.

    First tries name matching (normalized name, 'wa/' prefix variant, raw
    name, with/without ':0' suffix).  For names that still fail, falls back
    to value matching against dummy-inference results: exact-shape first,
    then common layout permutations, accepting a match only when the max
    element error is <= 1e-3.  Warns about any names left unmatched.
    """
    # Index TF tensors by base name ('x') and by full name ('x:0').
    tf_by_base = {t.name.split(':')[0]: t for t in tf_output_tensors}
    tf_by_full = {t.name: t for t in tf_output_tensors}
    mapping: Dict[str, tf.Tensor] = {}
    used_tf: set = set()
    missing: List[str] = []
    for onnx_name in onnx_output_names:
        normalized = _normalize_onnx_output_name(onnx_name)
        candidates = [
            normalized,
            f'wa/{normalized}',
        ]
        tf_tensor = None
        for cand in candidates:
            if cand in tf_by_base:
                tf_tensor = tf_by_base[cand]
                break
            full_name = f'{cand}:0'
            if full_name in tf_by_full:
                tf_tensor = tf_by_full[full_name]
                break
        if tf_tensor is None:
            # Last name-based attempt: the raw (un-normalized) ONNX name.
            if onnx_name in tf_by_base:
                tf_tensor = tf_by_base[onnx_name]
            else:
                full_name = f'{onnx_name}:0'
                if full_name in tf_by_full:
                    tf_tensor = tf_by_full[full_name]
        if tf_tensor is not None:
            mapping[onnx_name] = tf_tensor
            used_tf.add(tf_tensor.name)
        else:
            missing.append(onnx_name)

    # Value-based fallback, only possible when dummy-inference values exist.
    if onnx_output_values and tf_output_values and missing:
        tf_candidates = []
        for tf_tensor in tf_output_tensors:
            tf_val = tf_output_values.get(tf_tensor.name)
            if tf_val is None:
                tf_val = tf_output_values.get(tf_tensor.name.split(':')[0])
            if tf_val is not None:
                tf_candidates.append((tf_tensor, tf_val))
        still_missing = []
        for onnx_name in list(missing):
            onnx_val = onnx_output_values.get(onnx_name)
            if onnx_val is None:
                still_missing.append(onnx_name)
                continue
            best = None
            best_err = None
            # Pass 1: identical shapes, pick the smallest max-abs error.
            for tf_tensor, tf_val in tf_candidates:
                if tf_tensor.name in used_tf:
                    continue
                if tf_val.shape != onnx_val.shape:
                    continue
                err = np.max(np.abs(onnx_val - tf_val))
                if best is None or err < best_err:
                    best = tf_tensor
                    best_err = err
            if best is None:
                # Pass 2: same rank under a common layout permutation.
                for tf_tensor, tf_val in tf_candidates:
                    if tf_tensor.name in used_tf:
                        continue
                    if tf_val.ndim != onnx_val.ndim:
                        continue
                    for perm in _common_layout_perms(tf_val.ndim):
                        if tf_val.transpose(perm).shape != onnx_val.shape:
                            continue
                        err = np.max(np.abs(onnx_val - tf_val.transpose(perm)))
                        if best is None or err < best_err:
                            best = tf_tensor
                            best_err = err
            # Accept only near-identical values to avoid bogus pairings.
            if best is not None and best_err is not None and best_err <= 1e-3:
                mapping[onnx_name] = best
                used_tf.add(best.name)
            else:
                still_missing.append(onnx_name)
        missing = still_missing
    if missing:
        warn(
            'Auto split output mapping failed for: ' +
            ', '.join(missing) +
            '. Output cache/layout may be incomplete.'
        )
    return mapping
1613
+
1614
def _onnx_output_shape_map(onnx_model: onnx.ModelProto) -> Dict[str, List[Optional[int]]]:
    """Best-effort map of graph output name -> shape (None for dynamic dims).

    Outputs without shape metadata (or with zero dims) are omitted; any
    protobuf access error yields whatever was collected so far.
    """
    shape_map: Dict[str, List[Optional[int]]] = {}
    try:
        for out in onnx_model.graph.output:
            tensor_type = out.type.tensor_type
            if not tensor_type.HasField('shape'):
                continue
            # dim_value > 0 is a static size; dim_param / unset means dynamic.
            dims: List[Optional[int]] = [
                int(d.dim_value) if d.dim_value > 0 else None
                for d in tensor_type.shape.dim
            ]
            if dims:
                shape_map[out.name] = dims
    except Exception:
        pass
    return shape_map
1633
+
1634
+ def _infer_keep_shape(onnx_shape: List[Optional[int]], tf_shape: List[int]) -> Optional[bool]:
1635
+ if not onnx_shape or any(d is None for d in onnx_shape):
1636
+ return None
1637
+ if list(onnx_shape) == list(tf_shape):
1638
+ return True
1639
+ rank = len(onnx_shape)
1640
+ if rank == 3:
1641
+ if list(tf_shape) == [onnx_shape[0], onnx_shape[2], onnx_shape[1]]:
1642
+ return False
1643
+ elif rank == 4:
1644
+ if list(tf_shape) == [onnx_shape[0], onnx_shape[2], onnx_shape[3], onnx_shape[1]]:
1645
+ return False
1646
+ elif rank == 5:
1647
+ if list(tf_shape) == [onnx_shape[0], onnx_shape[2], onnx_shape[3], onnx_shape[4], onnx_shape[1]]:
1648
+ return False
1649
+ return None
1650
+
1651
+ def _merge_custom_inputs(user_inputs, auto_inputs):
1652
+ merged = []
1653
+ seen = set()
1654
+ if user_inputs:
1655
+ for item in user_inputs:
1656
+ if len(item) >= 2:
1657
+ merged.append(item)
1658
+ seen.add(item[0])
1659
+ for item in auto_inputs:
1660
+ if len(item) >= 2 and item[0] not in seen:
1661
+ merged.append(item)
1662
+ seen.add(item[0])
1663
+ return merged
1664
+
1665
+ base_kwargs = {
1666
+ 'input_onnx_file_path': input_onnx_file_path if input_onnx_file_path is not None else None,
1667
+ 'onnx_graph': onnx_graph,
1668
+ 'output_folder_path': output_folder_path,
1669
+ 'output_signaturedefs': output_signaturedefs,
1670
+ 'output_h5': output_h5,
1671
+ 'output_keras_v3': output_keras_v3,
1672
+ 'output_tfv1_pb': output_tfv1_pb,
1673
+ 'output_weights': output_weights,
1674
+ 'copy_onnx_input_output_names_to_tflite': copy_onnx_input_output_names_to_tflite,
1675
+ 'output_dynamic_range_quantized_tflite': output_dynamic_range_quantized_tflite,
1676
+ 'output_integer_quantized_tflite': output_integer_quantized_tflite,
1677
+ 'quant_norm_mean': quant_norm_mean,
1678
+ 'quant_norm_std': quant_norm_std,
1679
+ 'quant_type': quant_type,
1680
+ 'custom_input_op_name_np_data_path': custom_input_op_name_np_data_path,
1681
+ 'tf_input_cache': split_tf_input_cache,
1682
+ 'input_quant_dtype': input_quant_dtype,
1683
+ 'output_quant_dtype': output_quant_dtype,
1684
+ 'not_use_onnxsim': not_use_onnxsim,
1685
+ 'not_use_opname_auto_generate': not_use_opname_auto_generate,
1686
+ 'batch_size': batch_size,
1687
+ 'overwrite_input_shape': overwrite_input_shape,
1688
+ 'shape_hints': shape_hints,
1689
+ 'no_large_tensor': no_large_tensor,
1690
+ 'output_nms_with_dynamic_tensor': output_nms_with_dynamic_tensor,
1691
+ 'switch_nms_version': switch_nms_version,
1692
+ 'keep_ncw_or_nchw_or_ncdhw_input_names': keep_ncw_or_nchw_or_ncdhw_input_names,
1693
+ 'keep_nwc_or_nhwc_or_ndhwc_input_names': keep_nwc_or_nhwc_or_ndhwc_input_names,
1694
+ 'keep_shape_absolutely_input_names': keep_shape_absolutely_input_names,
1695
+ 'input_names_to_interrupt_model_conversion': None,
1696
+ 'output_names_to_interrupt_model_conversion': None,
1697
+ 'disable_group_convolution': disable_group_convolution,
1698
+ 'enable_accumulation_type_float16': enable_accumulation_type_float16,
1699
+ 'enable_batchmatmul_unfold': enable_batchmatmul_unfold,
1700
+ 'enable_rnn_unroll': enable_rnn_unroll,
1701
+ 'disable_suppression_flextranspose': disable_suppression_flextranspose,
1702
+ 'disable_strict_mode': disable_strict_mode,
1703
+ 'onnxruntime_output_memmap': onnxruntime_output_memmap,
1704
+ 'onnxruntime_output_memmap_dir': onnxruntime_output_memmap_dir,
1705
+ 'number_of_dimensions_after_flextranspose_compression': number_of_dimensions_after_flextranspose_compression,
1706
+ 'disable_suppression_flexstridedslice': disable_suppression_flexstridedslice,
1707
+ 'number_of_dimensions_after_flexstridedslice_compression': number_of_dimensions_after_flexstridedslice_compression,
1708
+ 'optimization_for_gpu_delegate': optimization_for_gpu_delegate,
1709
+ 'replace_argmax_to_reducemax_and_indices_is_int64': replace_argmax_to_reducemax_and_indices_is_int64,
1710
+ 'replace_argmax_to_reducemax_and_indices_is_float32': replace_argmax_to_reducemax_and_indices_is_float32,
1711
+ 'replace_argmax_to_fused_argmax_and_indices_is_int64': replace_argmax_to_fused_argmax_and_indices_is_int64,
1712
+ 'replace_argmax_to_fused_argmax_and_indices_is_float32': replace_argmax_to_fused_argmax_and_indices_is_float32,
1713
+ 'fused_argmax_scale_ratio': fused_argmax_scale_ratio,
1714
+ 'replace_to_pseudo_operators': replace_to_pseudo_operators,
1715
+ 'param_replacement_file': param_replacement_file,
1716
+ 'auto_generate_json': auto_generate_json,
1717
+ 'auto_generate_json_on_error': auto_generate_json_on_error,
1718
+ 'enable_auto_split_model': False,
1719
+ 'auto_split_max_size_mb': auto_split_max_size_mb,
1720
+ 'check_gpu_delegate_compatibility': check_gpu_delegate_compatibility,
1721
+ 'check_onnx_tf_outputs_elementwise_close': check_onnx_tf_outputs_elementwise_close,
1722
+ 'check_onnx_tf_outputs_elementwise_close_full': check_onnx_tf_outputs_elementwise_close_full,
1723
+ 'check_onnx_tf_outputs_sample_data_normalization': check_onnx_tf_outputs_sample_data_normalization,
1724
+ 'check_onnx_tf_outputs_elementwise_close_rtol': check_onnx_tf_outputs_elementwise_close_rtol,
1725
+ 'check_onnx_tf_outputs_elementwise_close_atol': check_onnx_tf_outputs_elementwise_close_atol,
1726
+ 'mvn_epsilon': mvn_epsilon,
1727
+ 'disable_model_save': disable_model_save,
1728
+ 'non_verbose': non_verbose,
1729
+ 'verbosity': verbosity,
1730
+ }
1731
+ base_kwargs['input_names_to_interrupt_model_conversion'] = None
1732
+ base_kwargs['output_names_to_interrupt_model_conversion'] = None
1733
+
1734
+ model_ret = None
1735
+ try:
1736
+ for idx, part in enumerate(partitions):
1737
+ part_output_values: Optional[Dict[str, np.ndarray]] = None
1738
+ part_output_folder = os.path.join(
1739
+ output_folder_path,
1740
+ f'part_{idx+1:04d}',
1741
+ )
1742
+ base_name = os.path.splitext(os.path.basename(input_onnx_file_path))[0] \
1743
+ if input_onnx_file_path else 'model'
1744
+ os.makedirs(part_output_folder, exist_ok=True)
1745
+ split_onnx_path = os.path.join(
1746
+ part_output_folder,
1747
+ f'{base_name}_part_{idx+1:04d}.onnx'
1748
+ )
1749
+ part_graph = sne4onnx.extraction(
1750
+ input_op_names=part['inputs'],
1751
+ output_op_names=part['outputs'],
1752
+ onnx_graph=onnx_graph_for_split,
1753
+ output_onnx_file_path=split_onnx_path,
1754
+ has_external_data=has_external_data,
1755
+ )
1756
+ auto_custom_inputs = []
1757
+ if split_input_cache:
1758
+ for input_name in part['inputs']:
1759
+ if input_name in split_input_cache:
1760
+ auto_custom_inputs.append([
1761
+ input_name,
1762
+ split_input_cache[input_name],
1763
+ ])
1764
+ merged_custom_inputs = _merge_custom_inputs(
1765
+ custom_input_op_name_np_data_path,
1766
+ auto_custom_inputs,
1767
+ )
1768
+ # For the first partition, keep the same behavior as non-split conversion.
1769
+ # Only user-provided custom inputs are used.
1770
+ if idx == 0 and not auto_custom_inputs:
1771
+ custom_inputs_for_part = merged_custom_inputs
1772
+ else:
1773
+ require_mean_std = bool(output_integer_quantized_tflite)
1774
+ custom_inputs_for_part = _complete_custom_inputs_for_graph(
1775
+ onnx_graph=part_graph,
1776
+ custom_inputs=merged_custom_inputs,
1777
+ output_dir=split_input_dir,
1778
+ file_prefix=f'part_{idx+1:04d}',
1779
+ shape_hints=shape_hints,
1780
+ require_mean_std=require_mean_std,
1781
+ )
1782
+
1783
+ # Propagate dummy outputs to next partitions
1784
+ try:
1785
+ has_inputs = len(part_graph.graph.input) > 0
1786
+ has_outputs = len(part_graph.graph.output) > 0
1787
+ if has_inputs and has_outputs:
1788
+ part_input_datas = {}
1789
+ part_outputs = dummy_onnx_inference(
1790
+ onnx_graph=part_graph,
1791
+ output_names=part['outputs'],
1792
+ test_data_nhwc=None,
1793
+ custom_input_op_name_np_data_path=custom_inputs_for_part,
1794
+ tf_layers_dict={},
1795
+ use_cuda=local_use_cuda,
1796
+ disable_strict_mode=disable_strict_mode,
1797
+ enable_ort_output_memmap=False,
1798
+ ort_output_memmap_dir=None,
1799
+ shape_hints=shape_hints,
1800
+ input_datas_for_validation=part_input_datas,
1801
+ )
1802
+ for input_name, input_value in part_input_datas.items():
1803
+ file_name = (
1804
+ f'part_{idx+1:04d}_' +
1805
+ f'{_sanitize_split_input_name(input_name)}.npy'
1806
+ )
1807
+ file_path = os.path.join(split_input_dir, file_name)
1808
+ split_input_cache[input_name] = _write_memmap_array(
1809
+ file_path,
1810
+ input_value,
1811
+ )
1812
+ part_output_values = {
1813
+ name: value for name, value in zip(part['outputs'], part_outputs)
1814
+ }
1815
+ for output_name, output_value in zip(part['outputs'], part_outputs):
1816
+ file_name = (
1817
+ f'part_{idx+1:04d}_' +
1818
+ f'{_sanitize_split_input_name(output_name)}.npy'
1819
+ )
1820
+ file_path = os.path.join(split_input_dir, file_name)
1821
+ split_input_cache[output_name] = _write_memmap_array(
1822
+ file_path,
1823
+ output_value,
1824
+ )
1825
+ else:
1826
+ warn(
1827
+ 'Auto split input propagation skipped for this partition '
1828
+ 'because it has no inputs or outputs.'
1829
+ )
1830
+ except Exception as ex:
1831
+ warn(
1832
+ 'Auto split input propagation failed for this partition. '
1833
+ 'Subsequent partitions may use default dummy inputs.'
1834
+ )
1835
+ warn(f'{ex}')
1836
+
1837
+ part_kwargs = dict(base_kwargs)
1838
+ if split_output_layouts:
1839
+ part_keep_shape_abs = set(keep_shape_absolutely_input_names or [])
1840
+ for input_name in part['inputs']:
1841
+ if split_output_layouts.get(input_name, False):
1842
+ part_keep_shape_abs.add(input_name)
1843
+ part_kwargs['keep_shape_absolutely_input_names'] = \
1844
+ list(part_keep_shape_abs) if part_keep_shape_abs else None
1845
+ part_kwargs['input_onnx_file_path'] = split_onnx_path
1846
+ part_kwargs['onnx_graph'] = part_graph
1847
+ part_kwargs['output_folder_path'] = part_output_folder
1848
+ if custom_inputs_for_part:
1849
+ part_kwargs['custom_input_op_name_np_data_path'] = custom_inputs_for_part
1850
+ model_ret = convert(**part_kwargs)
1851
+
1852
+ if hasattr(model_ret, 'onnx_output_layouts') \
1853
+ and isinstance(model_ret.onnx_output_layouts, dict):
1854
+ for out_name in part['outputs']:
1855
+ if out_name in model_ret.onnx_output_layouts:
1856
+ split_output_layouts[out_name] = \
1857
+ bool(model_ret.onnx_output_layouts[out_name])
1858
+
1859
+ # Cache TF outputs for the next partition's TF dummy inference.
1860
+ try:
1861
+ tf_outputs = dummy_tf_inference(
1862
+ model=model_ret,
1863
+ inputs=model_ret.inputs,
1864
+ test_data_nhwc=None,
1865
+ custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
1866
+ prefilled_input_datas=split_tf_input_cache,
1867
+ shape_hints=shape_hints,
1868
+ keep_shape_absolutely_input_names=keep_shape_absolutely_input_names,
1869
+ keep_ncw_or_nchw_or_ncdhw_input_names=keep_ncw_or_nchw_or_ncdhw_input_names,
1870
+ keep_nwc_or_nhwc_or_ndhwc_input_names=keep_nwc_or_nhwc_or_ndhwc_input_names,
1871
+ )
1872
+ onnx_output_shapes = _onnx_output_shape_map(part_graph)
1873
+ if model_ret.outputs and part['outputs']:
1874
+ tf_output_map = _build_onnx_tf_output_map(
1875
+ onnx_output_names=part['outputs'],
1876
+ tf_output_tensors=model_ret.outputs,
1877
+ onnx_output_values=part_output_values,
1878
+ tf_output_values=tf_outputs,
1879
+ )
1880
+ for onnx_out, tf_tensor in tf_output_map.items():
1881
+ tf_val = tf_outputs.get(tf_tensor.name)
1882
+ if tf_val is None:
1883
+ continue
1884
+ # Store both full and base TF names to maximize cache hits.
1885
+ split_tf_input_cache[tf_tensor.name] = tf_val
1886
+ split_tf_input_cache[tf_tensor.name.split(':')[0]] = tf_val
1887
+ # Keep legacy key for compatibility with existing lookups.
1888
+ sanitized = _sanitize_tf_input_name(onnx_out)
1889
+ split_tf_input_cache[sanitized] = tf_val
1890
+ split_tf_input_cache[sanitized.split(':')[0]] = tf_val
1891
+ keep_shape = _infer_keep_shape(
1892
+ onnx_output_shapes.get(onnx_out),
1893
+ list(tf_val.shape),
1894
+ )
1895
+ if keep_shape is not None:
1896
+ split_output_layouts[onnx_out] = keep_shape
1897
+ except Exception:
1898
+ pass
1899
+ finally:
1900
+ shutil.rmtree(split_input_dir, ignore_errors=True)
1901
+ return model_ret
1902
+
993
1903
  # Cut the ONNX graph when an input name is specified that interrupts the conversion
994
1904
  if not input_names_to_interrupt_model_conversion:
995
1905
  input_names = [
@@ -1218,6 +2128,7 @@ def convert(
1218
2128
  'relu_relu6_merge_op_names': {},
1219
2129
  'mul_div_replace_op_names': {},
1220
2130
  'use_cuda': use_cuda,
2131
+ 'tf_input_cache': tf_input_cache,
1221
2132
  }
1222
2133
 
1223
2134
  tf_layers_dict = {}
@@ -1291,30 +2202,29 @@ def convert(
1291
2202
  )
1292
2203
 
1293
2204
  # download test data
1294
- all_four_dim = sum(
1295
- [
1296
- 1 for input in inputs \
1297
- if len(input.shape) == 4 \
1298
- and input.shape[0] is not None \
1299
- and input.shape[0] <= 20 \
1300
- and input.shape[-1] == 3 \
1301
- and input.shape[1] is not None \
1302
- and input.shape[2] is not None
1303
- ]
1304
- ) == len(inputs)
1305
- same_batch_dim = False
1306
- if all_four_dim:
1307
- batch_size = inputs[0].shape[0]
1308
- for input in inputs:
1309
- same_batch_dim = batch_size == input.shape[0]
1310
2205
  test_data_nhwc = None
1311
- if all_four_dim and same_batch_dim:
1312
- test_data: np.ndarray = download_test_image_data()
1313
- test_data_nhwc = test_data[:inputs[0].shape[0], ...]
1314
- if check_onnx_tf_outputs_sample_data_normalization == "norm":
1315
- pass
1316
- elif check_onnx_tf_outputs_sample_data_normalization == "denorm":
1317
- test_data_nhwc = test_data_nhwc * 255.0
2206
+ if inputs:
2207
+ all_four_dim = sum(
2208
+ [
2209
+ 1 for input in inputs \
2210
+ if len(input.shape) == 4 \
2211
+ and input.shape[0] is not None \
2212
+ and input.shape[0] <= 20 \
2213
+ and input.shape[-1] == 3 \
2214
+ and input.shape[1] is not None \
2215
+ and input.shape[2] is not None
2216
+ ]
2217
+ ) == len(inputs)
2218
+ same_batch_dim = False
2219
+ if all_four_dim:
2220
+ batch_size = inputs[0].shape[0]
2221
+ for input in inputs:
2222
+ same_batch_dim = batch_size == input.shape[0]
2223
+ if all_four_dim and same_batch_dim:
2224
+ test_data: np.ndarray = download_test_image_data()
2225
+ test_data_nhwc = test_data[:inputs[0].shape[0], ...]
2226
+ if check_onnx_tf_outputs_sample_data_normalization == "denorm":
2227
+ test_data_nhwc = test_data_nhwc * 255.0
1318
2228
 
1319
2229
  # ONNX dummy inference
1320
2230
  # Generate output for all OPs.
@@ -1400,7 +2310,10 @@ def convert(
1400
2310
  exported_onnx_graph = gs.export_onnx(graph, do_type_check=False, **meta_data)
1401
2311
  if metadata_props is not None:
1402
2312
  exported_onnx_graph.metadata_props.extend(metadata_props)
1403
- estimated_graph = onnx.shape_inference.infer_shapes(exported_onnx_graph)
2313
+ if not has_external_data:
2314
+ estimated_graph = onnx.shape_inference.infer_shapes(exported_onnx_graph)
2315
+ else:
2316
+ estimated_graph = exported_onnx_graph
1404
2317
  if input_onnx_file_path is not None:
1405
2318
  onnx.save(estimated_graph, input_onnx_file_path)
1406
2319
  if not not_use_onnxsim:
@@ -1580,6 +2493,14 @@ def convert(
1580
2493
  outputs[oidx] = tf_keras.layers.Lambda(lambda x: tf.constant(y))(x)
1581
2494
 
1582
2495
  model = tf_keras.Model(inputs=inputs, outputs=outputs)
2496
+ try:
2497
+ onnx_output_layouts = {
2498
+ name: tf_layers_dict.get(name, {}).get('nhwc', False)
2499
+ for name in onnx_graph_output_names
2500
+ }
2501
+ model.onnx_output_layouts = onnx_output_layouts
2502
+ except Exception:
2503
+ pass
1583
2504
  debug('')
1584
2505
 
1585
2506
  # The process ends normally without saving the model.
@@ -3050,6 +3971,23 @@ def main():
3050
3971
  'e.g. \n' +
3051
3972
  '--output_names_to_interrupt_model_conversion "output0" "output1" "output2"'
3052
3973
  )
3974
+ parser.add_argument(
3975
+ '-easm',
3976
+ '--enable_auto_split_model',
3977
+ action='store_true',
3978
+ help=\
3979
+ 'Force auto split regardless of the ONNX file size. \n' +
3980
+ 'Uses --auto_split_max_size_mb as the target partition size.'
3981
+ )
3982
+ parser.add_argument(
3983
+ '-asmsm',
3984
+ '--auto_split_max_size_mb',
3985
+ type=int,
3986
+ default=1024,
3987
+ help=\
3988
+ 'Target maximum size per partition in MB based on ONNX initializer sizes. \n' +
3989
+ 'Used when auto-split is triggered or forced.'
3990
+ )
3053
3991
  parser.add_argument(
3054
3992
  '-dgc',
3055
3993
  '--disable_group_convolution',
@@ -3450,6 +4388,8 @@ def main():
3450
4388
  param_replacement_file=args.param_replacement_file,
3451
4389
  auto_generate_json=args.auto_generate_json,
3452
4390
  auto_generate_json_on_error=args.auto_generate_json_on_error,
4391
+ enable_auto_split_model=args.enable_auto_split_model,
4392
+ auto_split_max_size_mb=args.auto_split_max_size_mb,
3453
4393
  check_gpu_delegate_compatibility=args.check_gpu_delegate_compatibility,
3454
4394
  check_onnx_tf_outputs_elementwise_close=args.check_onnx_tf_outputs_elementwise_close,
3455
4395
  check_onnx_tf_outputs_elementwise_close_full=args.check_onnx_tf_outputs_elementwise_close_full,