onnx2tf 1.29.19__py3-none-any.whl → 1.29.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/onnx2tf.py +967 -27
- onnx2tf/ops/GatherElements.py +25 -7
- onnx2tf/ops/GatherND.py +28 -1
- onnx2tf/ops/ScatterElements.py +25 -7
- onnx2tf/ops/ScatterND.py +45 -6
- onnx2tf/ops/TensorScatter.py +20 -6
- onnx2tf/utils/common_functions.py +99 -2
- {onnx2tf-1.29.19.dist-info → onnx2tf-1.29.20.dist-info}/METADATA +25 -3
- {onnx2tf-1.29.19.dist-info → onnx2tf-1.29.20.dist-info}/RECORD +12 -12
- {onnx2tf-1.29.19.dist-info → onnx2tf-1.29.20.dist-info}/WHEEL +0 -0
- {onnx2tf-1.29.19.dist-info → onnx2tf-1.29.20.dist-info}/entry_points.txt +0 -0
onnx2tf/onnx2tf.py
CHANGED
@@ -2,6 +2,8 @@
 
 import os
 import re
+import shutil
+import tempfile
 __path__ = (os.path.dirname(__file__), )
 with open(os.path.join(__path__[0], '__init__.py')) as f:
     init_text = f.read()
@@ -51,6 +53,7 @@ from onnx2tf.utils.common_functions import (
     get_tf_model_outputs,
     rewrite_tflite_inout_opname,
     check_cuda_enabled,
+    check_has_external_data,
 )
 from onnx2tf.utils.json_auto_generator import (
     generate_auto_replacement_json,
@@ -62,6 +65,349 @@ from onnx2tf.utils.enums import (
 from onnx2tf.utils.logging import *
 from sng4onnx import generate as op_name_auto_generate
 
+def _sanitize_split_input_name(name: str) -> str:
+    if not name:
+        return 'tensor'
+    return re.sub(r'[^0-9A-Za-z._-]+', '_', name)
+
+def _write_memmap_array(path: str, array: np.ndarray) -> str:
+    mm = np.lib.format.open_memmap(
+        path,
+        mode='w+',
+        dtype=array.dtype,
+        shape=array.shape,
+    )
+    mm[...] = array
+    mm.flush()
+    return path
+
+
+def _tensorproto_nbytes(tensor: onnx.TensorProto) -> int:
+    if tensor is None:
+        return 0
+    if tensor.HasField('raw_data'):
+        return len(tensor.raw_data)
+    try:
+        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor.data_type)
+    except Exception:
+        np_dtype = None
+    if np_dtype is None:
+        return 0
+    elem_size = np.dtype(np_dtype).itemsize
+    num_elems = int(np.prod(tensor.dims)) if len(tensor.dims) > 0 else 0
+    if num_elems == 0:
+        try:
+            field_name = onnx.helper.tensor_dtype_to_field(tensor.data_type)
+            if hasattr(tensor, field_name):
+                num_elems = len(getattr(tensor, field_name))
+        except Exception:
+            num_elems = 0
+    return num_elems * elem_size
+
+def _collect_initializer_sizes(onnx_graph: onnx.ModelProto) -> Dict[str, int]:
+    initializer_sizes: Dict[str, int] = {}
+    if onnx_graph is None:
+        return initializer_sizes
+    for initializer in onnx_graph.graph.initializer:
+        if not initializer.name:
+            continue
+        try:
+            initializer_sizes[initializer.name] = _tensorproto_nbytes(initializer)
+        except Exception:
+            initializer_sizes[initializer.name] = 0
+    return initializer_sizes
+
+def _collect_node_weight_keys(
+    *,
+    graph: gs.Graph,
+    initializer_sizes: Dict[str, int],
+) -> tuple[List[List[str]], Dict[str, int]]:
+    weight_sizes = dict(initializer_sizes)
+    node_weight_keys: List[List[str]] = []
+    for node in graph.nodes:
+        keys: List[str] = []
+        for inp in node.inputs:
+            if isinstance(inp, gs.Constant):
+                if isinstance(getattr(inp, 'values', None), np.ndarray):
+                    key = f'const:{id(inp)}'
+                    if key not in weight_sizes:
+                        weight_sizes[key] = int(inp.values.nbytes)
+                    keys.append(key)
+                continue
+            name = getattr(inp, 'name', '')
+            if name and name in initializer_sizes:
+                keys.append(name)
+        node_weight_keys.append(keys)
+    return node_weight_keys, weight_sizes
+
+def _auto_partition_ranges(
+    *,
+    node_weight_keys: List[List[str]],
+    weight_sizes: Dict[str, int],
+    max_size_bytes: int,
+    reachable_node_indices: Optional[set] = None,
+) -> List[tuple]:
+    ranges: List[tuple] = []
+    if max_size_bytes <= 0 or not node_weight_keys:
+        return ranges
+    current_keys: set = set()
+    current_bytes = 0
+    start_idx = 0
+    for idx, keys in enumerate(node_weight_keys):
+        new_bytes = 0
+        for key in keys:
+            if key not in current_keys:
+                new_bytes += weight_sizes.get(key, 0)
+                current_keys.add(key)
+        current_bytes += new_bytes
+        if current_bytes >= max_size_bytes and idx > start_idx:
+            if reachable_node_indices is not None and idx not in reachable_node_indices:
+                continue
+            ranges.append((start_idx, idx))
+            start_idx = idx + 1
+            current_keys = set()
+            current_bytes = 0
+    if start_idx <= len(node_weight_keys) - 1:
+        ranges.append((start_idx, len(node_weight_keys) - 1))
+    return ranges
+
+def _collect_reachable_node_indices(
+    graph: gs.Graph,
+    initializer_names: Optional[set] = None,
+) -> set:
+    reachable_nodes: set = set()
+    reachable_vars: set = set()
+    initializer_names = initializer_names or set()
+    for graph_input in graph.inputs:
+        name = getattr(graph_input, 'name', '')
+        if name and name not in initializer_names:
+            reachable_vars.add(name)
+    for idx, node in enumerate(graph.nodes):
+        is_reachable = False
+        for inp in node.inputs:
+            if isinstance(inp, gs.Variable):
+                name = getattr(inp, 'name', '')
+                if name in reachable_vars and name not in initializer_names:
+                    is_reachable = True
+                    break
+        if is_reachable:
+            reachable_nodes.add(idx)
+            for out in node.outputs:
+                name = getattr(out, 'name', '')
+                if name:
+                    reachable_vars.add(name)
+    return reachable_nodes
+
+def _collect_constant_only_node_indices(
+    graph: gs.Graph,
+    initializer_names: Optional[set] = None,
+) -> set:
+    initializer_names = initializer_names or set()
+    const_only_nodes: set = set()
+    for idx, node in enumerate(graph.nodes):
+        has_variable_input = False
+        for inp in node.inputs:
+            if isinstance(inp, gs.Constant):
+                continue
+            name = getattr(inp, 'name', '')
+            if name and name not in initializer_names:
+                has_variable_input = True
+                break
+        if not has_variable_input:
+            const_only_nodes.add(idx)
+    return const_only_nodes
+
+def _complete_custom_inputs_for_graph(
+    *,
+    onnx_graph: onnx.ModelProto,
+    custom_inputs: List[List[Any]],
+    output_dir: str,
+    file_prefix: str,
+    shape_hints: Optional[List[str]] = None,
+    require_mean_std: bool = False,
+) -> List[List[Any]]:
+    gs_graph = gs.import_onnx(onnx_graph)
+    input_names: List[str] = [inp.name for inp in gs_graph.inputs]
+    input_sizes: List[List[Any]] = [inp.shape for inp in gs_graph.inputs]
+    input_dtypes: List[Any] = [inp.dtype for inp in gs_graph.inputs]
+
+    if shape_hints is None:
+        new_input_sizes = []
+        for input_size in input_sizes:
+            new_input_size = []
+            for idx, dim in enumerate(input_size):
+                if idx == 0 and input_sizes and input_sizes[0][0] is not None \
+                    and not isinstance(input_sizes[0][0], str) \
+                    and len(input_sizes[0]) == len(input_size) \
+                    and (dim is None or isinstance(dim, str)):
+                    new_input_size.append(input_sizes[0][0])
+                elif dim is None or isinstance(dim, str):
+                    new_input_size.append(1)
+                else:
+                    new_input_size.append(dim)
+            new_input_sizes.append(new_input_size)
+        input_sizes = new_input_sizes
+    else:
+        shape_hints_dict = {}
+        for hint in shape_hints:
+            parts = hint.split(':')
+            if len(parts) == 2:
+                input_name = parts[0]
+                shape_values = [int(val) for val in parts[1].split(',')]
+                shape_hints_dict[input_name] = shape_values
+        for i, (input_name, original_shape) in enumerate(zip(input_names, input_sizes)):
+            if input_name in shape_hints_dict:
+                updated_shape = shape_hints_dict[input_name]
+                for j, (orig_dim, hint_dim) in enumerate(zip(original_shape, updated_shape)):
+                    if orig_dim is not None and not isinstance(orig_dim, str):
+                        updated_shape[j] = orig_dim
+                    else:
+                        updated_shape[j] = hint_dim
+                input_sizes[i] = updated_shape
+
+    custom_map = {}
+    for item in custom_inputs or []:
+        if len(item) >= 2:
+            custom_map[item[0]] = item
+
+    results: List[List[Any]] = []
+    for input_name, input_size, input_dtype in zip(input_names, input_sizes, input_dtypes):
+        if input_name in custom_map:
+            item = list(custom_map[input_name])
+            if require_mean_std and len(item) == 2:
+                item = [item[0], item[1], 0.0, 1.0]
+            results.append(item)
+            continue
+        dtype = input_dtype if input_dtype is not None else np.float32
+        file_name = f'{file_prefix}_{_sanitize_split_input_name(input_name)}.npy'
+        file_path = os.path.join(output_dir, file_name)
+        mm = np.lib.format.open_memmap(
+            file_path,
+            mode='w+',
+            dtype=dtype,
+            shape=tuple(input_size),
+        )
+        mm[...] = 1
+        mm.flush()
+        if require_mean_std:
+            results.append([input_name, file_path, 0.0, 1.0])
+        else:
+            results.append([input_name, file_path])
+    return results
+
+def _estimate_partition_weight_bytes(
+    *,
+    ranges: List[tuple],
+    node_weight_keys: List[List[str]],
+    weight_sizes: Dict[str, int],
+) -> List[int]:
+    partition_sizes: List[int] = []
+    for start_idx, end_idx in ranges:
+        seen: set = set()
+        total_bytes = 0
+        for idx in range(start_idx, end_idx + 1):
+            for key in node_weight_keys[idx]:
+                if key not in seen:
+                    total_bytes += weight_sizes.get(key, 0)
+                    seen.add(key)
+        partition_sizes.append(total_bytes)
+    return partition_sizes
+
+def _build_partition_io(
+    *,
+    graph: gs.Graph,
+    ranges: List[tuple],
+    const_only_nodes: Optional[set] = None,
+) -> List[Dict[str, Any]]:
+    if not ranges:
+        return []
+    const_only_nodes = const_only_nodes or set()
+    producer_by_tensor: Dict[str, int] = {}
+    consumers_by_tensor: Dict[str, set] = {}
+    graph_output_names = [o.name for o in graph.outputs if o.name]
+    for idx, node in enumerate(graph.nodes):
+        for out in node.outputs:
+            name = getattr(out, 'name', '')
+            if name:
+                producer_by_tensor[name] = idx
+        for inp in node.inputs:
+            if isinstance(inp, gs.Constant):
+                continue
+            name = getattr(inp, 'name', '')
+            if not name:
+                continue
+            consumers_by_tensor.setdefault(name, set()).add(idx)
+
+    partitions: List[Dict[str, Any]] = []
+    for start_idx, end_idx in ranges:
+        node_idx_set = set(range(start_idx, end_idx + 1))
+        part_inputs: set = set()
+        part_outputs: set = set()
+        for idx in node_idx_set:
+            node = graph.nodes[idx]
+            for inp in node.inputs:
+                if isinstance(inp, gs.Constant):
+                    continue
+                name = getattr(inp, 'name', '')
+                if not name:
+                    continue
+                producer_idx = producer_by_tensor.get(name)
+                if producer_idx is None or producer_idx not in node_idx_set:
+                    if producer_idx is not None and producer_idx in const_only_nodes:
+                        continue
+                    part_inputs.add(name)
+            for out in node.outputs:
+                name = getattr(out, 'name', '')
+                if not name:
+                    continue
+                consumers = consumers_by_tensor.get(name, set())
+                if name in graph_output_names or any(c not in node_idx_set for c in consumers):
+                    if idx in const_only_nodes and name not in graph_output_names:
+                        continue
+                    part_outputs.add(name)
+        partitions.append({
+            'inputs': sorted(part_inputs),
+            'outputs': sorted(part_outputs),
+            'node_count': end_idx - start_idx + 1,
+            'start_idx': start_idx,
+            'end_idx': end_idx,
+        })
+    return partitions
+
+def _merge_ranges_with_missing_io(
+    *,
+    graph: gs.Graph,
+    ranges: List[tuple],
+    const_only_nodes: Optional[set] = None,
+) -> tuple[List[tuple], List[Dict[str, Any]]]:
+    if not ranges:
+        return ranges, []
+    ranges = list(ranges)
+    const_only_nodes = const_only_nodes or set()
+    while True:
+        partitions = _build_partition_io(
+            graph=graph,
+            ranges=ranges,
+            const_only_nodes=const_only_nodes,
+        ) or []
+        if all(part['inputs'] and part['outputs'] for part in partitions):
+            return ranges, partitions
+        if len(ranges) <= 1:
+            return ranges, partitions
+        merged = False
+        for idx, part in enumerate(partitions):
+            if not part['inputs'] or not part['outputs']:
+                if idx > 0:
+                    ranges[idx - 1] = (ranges[idx - 1][0], ranges[idx][1])
+                    del ranges[idx]
+                else:
+                    ranges[idx] = (ranges[idx][0], ranges[idx + 1][1])
+                    del ranges[idx + 1]
+                merged = True
+                break
+        if not merged:
+            return ranges, partitions
+
 def fuse_expanded_qdq_to_qdq(
     *,
     graph: gs.Graph,
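The greedy partitioner above walks the toposorted node list once and cuts whenever the cumulative unique-weight byte count reaches the budget, deferring cuts until a reachable node. A minimal standalone sketch with hypothetical byte sizes (not part of the diff):

    node_weight_keys = [['w0'], ['w0', 'w1'], ['w2'], ['w3'], []]
    weight_sizes = {'w0': 600, 'w1': 500, 'w2': 300, 'w3': 200}
    ranges = _auto_partition_ranges(
        node_weight_keys=node_weight_keys,
        weight_sizes=weight_sizes,
        max_size_bytes=1000,
        reachable_node_indices={0, 1, 2, 3, 4},
    )
    # ranges == [(0, 1), (2, 4)]: node 1 lifts the unique-weight total to
    # 600 + 500 >= 1000, so the first cut lands there; nodes 2..4 carry the
    # remaining 500 bytes and become the final partition.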
@@ -285,6 +631,7 @@ def convert(
     quant_norm_std: Optional[str] = '[[[[0.229, 0.224, 0.225]]]]',
     quant_type: Optional[str] = 'per-channel',
     custom_input_op_name_np_data_path: Optional[List] = None,
+    tf_input_cache: Optional[Dict[str, np.ndarray]] = None,
     input_quant_dtype: Optional[str] = 'int8',
     output_quant_dtype: Optional[str] = 'int8',
     not_use_onnxsim: Optional[bool] = False,
@@ -321,6 +668,8 @@ def convert(
     param_replacement_file: Optional[str] = '',
     auto_generate_json: Optional[bool] = False,
     auto_generate_json_on_error: Optional[bool] = False,
+    enable_auto_split_model: Optional[bool] = False,
+    auto_split_max_size_mb: Optional[int] = 1024,
     check_gpu_delegate_compatibility: Optional[bool] = False,
     check_onnx_tf_outputs_elementwise_close: Optional[bool] = False,
     check_onnx_tf_outputs_elementwise_close_full: Optional[bool] = False,
@@ -451,6 +800,10 @@ def convert(
         ["input2","input2.npy",[0.3],[0.07]],\n
         ]
 
+    tf_input_cache: Optional[Dict[str, np.ndarray]]
+        Cache of TF dummy inference inputs keyed by TF input tensor name.\n
+        Used to propagate TF outputs between auto-split partitions.\n
+
     input_quant_dtype: Optional[str]
         Input dtypes when doing Full INT8 Quantization.\n
         "int8"(default) or "uint8" or "float32"
@@ -682,6 +1035,15 @@ def convert(
         This is now opt-in and requires explicitly enabling the feature.\n
         Default: False
 
+    enable_auto_split_model: Optional[bool]
+        Force auto split regardless of the ONNX file size.\n
+        The target size is controlled by auto_split_max_size_mb.\n
+        Default: False
+
+    auto_split_max_size_mb: Optional[int]
+        Target maximum size per partition in MB based on ONNX initializer sizes.\n
+        Default: 1024
+
     check_gpu_delegate_compatibility: Optional[bool]
         Run TFLite ModelAnalyzer on the generated Float16 tflite model\n
        to check if the model can be supported by GPU Delegate.
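A minimal Python-side sketch of the two new arguments documented above (file and folder names are placeholders):

    import onnx2tf

    model = onnx2tf.convert(
        input_onnx_file_path='large_model.onnx',
        output_folder_path='saved_model',
        enable_auto_split_model=True,  # force auto split even below the 2GB trigger
        auto_split_max_size_mb=512,    # target roughly 512 MB of weights per partition
    )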
@@ -771,6 +1133,23 @@ def convert(
             f'input_onnx_file_path: {input_onnx_file_path}'
         )
         sys.exit(1)
+    auto_split_model = bool(enable_auto_split_model)
+    if auto_split_model:
+        info(
+            Color.GREEN('Auto split forced by --enable_auto_split_model. ') +
+            f'target={auto_split_max_size_mb} MB'
+        )
+    if onnx_graph is None and input_onnx_file_path and os.path.exists(input_onnx_file_path):
+        try:
+            onnx_file_size = os.path.getsize(input_onnx_file_path)
+            if not auto_split_model and onnx_file_size > 2 * 1024 * 1024 * 1024:
+                info(
+                    Color.GREEN('ONNX file exceeds 2GB; switching to auto-split mode. ') +
+                    f'size={onnx_file_size / (1024 * 1024 * 1024):.2f} GB'
+                )
+                auto_split_model = True
+        except Exception:
+            pass
 
     # Extracting onnx filenames
     output_file_name = ''
@@ -917,9 +1296,8 @@ def convert(
         exported_onnx_graph = gs.export_onnx(graph, do_type_check=False, **meta_data)
         if metadata_props is not None:
             exported_onnx_graph.metadata_props.extend(metadata_props)
-
-
-        del estimated_graph
+        onnx.save(exported_onnx_graph, f=input_onnx_file_path)
+        del exported_onnx_graph
     except:
         if tmp_graph is not None:
             del tmp_graph
@@ -957,6 +1335,7 @@ def convert(
             input_onnx_file_path=f'{input_onnx_file_path}',
             onnx_graph=onnx_graph,
             output_onnx_file_path=f'{input_onnx_file_path}',
+            has_external_data=has_external_data,
             non_verbose=True,
         )
         info(Color.GREEN(f'Automatic generation of each OP name complete!'))
@@ -978,9 +1357,24 @@ def convert(
 
     # Loading Graphs
     # onnx_graph If specified, onnx_graph is processed first
+    has_external_data = False
     if not onnx_graph:
+        has_external_data = check_has_external_data(input_onnx_file_path)
         onnx_graph = onnx.load(input_onnx_file_path)
 
+    if not auto_split_model and onnx_graph is not None:
+        try:
+            initializer_sizes = _collect_initializer_sizes(onnx_graph)
+            total_init_bytes = sum(initializer_sizes.values())
+            if total_init_bytes > 2 * 1024 * 1024 * 1024:
+                info(
+                    Color.GREEN('ONNX graph estimated initializer size exceeds 2GB; ') +
+                    f'switching to auto-split mode. size={total_init_bytes / (1024 * 1024 * 1024):.2f} GB'
+                )
+                auto_split_model = True
+        except Exception:
+            pass
+
     domain: str = onnx_graph.domain
     ir_version: int = onnx_graph.ir_version
     meta_data = {'domain': domain, 'ir_version': ir_version}
@@ -990,6 +1384,522 @@ def convert(
     graph = gs.import_onnx(onnx_graph)
     fuse_expanded_qdq_to_qdq(graph=graph)
 
+    # Auto split model by estimated weight size
+    if auto_split_model:
+        if input_names_to_interrupt_model_conversion or output_names_to_interrupt_model_conversion:
+            error(
+                'Auto split cannot be used together with input_names_to_interrupt_model_conversion '
+                'or output_names_to_interrupt_model_conversion.'
+            )
+            sys.exit(1)
+        if auto_split_max_size_mb is None or auto_split_max_size_mb <= 0:
+            error(
+                f'auto_split_max_size_mb must be greater than 0. auto_split_max_size_mb: {auto_split_max_size_mb}'
+            )
+            sys.exit(1)
+        try:
+            import sne4onnx
+        except Exception:
+            error(
+                'Auto split requires sne4onnx. pip install sne4onnx'
+            )
+            sys.exit(1)
+        try:
+            graph.toposort()
+        except Exception:
+            pass
+
+        onnx_graph_for_split = onnx_graph
+        try:
+            onnx_graph_for_split = gs.export_onnx(
+                graph=graph,
+                do_type_check=False,
+                **meta_data,
+            )
+            if metadata_props is not None:
+                onnx_graph_for_split.metadata_props.extend(metadata_props)
+        except Exception:
+            onnx_graph_for_split = onnx_graph
+
+        initializer_sizes = _collect_initializer_sizes(onnx_graph_for_split)
+        node_weight_keys, weight_sizes = _collect_node_weight_keys(
+            graph=graph,
+            initializer_sizes=initializer_sizes,
+        )
+        const_only_nodes = _collect_constant_only_node_indices(
+            graph,
+            initializer_names=set(initializer_sizes.keys()),
+        )
+        reachable_node_indices = _collect_reachable_node_indices(
+            graph,
+            initializer_names=set(initializer_sizes.keys()),
+        )
+        max_size_bytes = int(auto_split_max_size_mb) * 1024 * 1024
+        ranges = _auto_partition_ranges(
+            node_weight_keys=node_weight_keys,
+            weight_sizes=weight_sizes,
+            max_size_bytes=max_size_bytes,
+            reachable_node_indices=reachable_node_indices,
+        )
+        if len(ranges) > 1:
+            ranges, partitions = _merge_ranges_with_missing_io(
+                graph=graph,
+                ranges=ranges,
+                const_only_nodes=const_only_nodes,
+            )
+            if not partitions:
+                warn(
+                    'Auto split failed to determine partition boundaries. Proceeding without split.'
+                )
+            else:
+                if any([not p['inputs'] or not p['outputs'] for p in partitions]):
+                    warn(
+                        'Auto split produced partitions with missing inputs or outputs. '
+                        'Some partitions may not be inferable.'
+                    )
+                partition_sizes = _estimate_partition_weight_bytes(
+                    ranges=ranges,
+                    node_weight_keys=node_weight_keys,
+                    weight_sizes=weight_sizes,
+                )
+                try:
+                    op_type_list = list(set([node.op for node in graph.nodes]))
+                    local_use_cuda = sum(
+                        [1 if op_type in CUDA_ONLY_OPS else 0 for op_type in op_type_list]
+                    ) > 0
+                except Exception:
+                    local_use_cuda = False
+                info('')
+                info(Color.REVERSE(f'Auto model partitioning enabled'), '=' * 44)
+                info(
+                    Color.GREEN(f'Target partition size (estimated weights): ') +
+                    f'{auto_split_max_size_mb} MB'
+                )
+                for idx, part in enumerate(partitions):
+                    size_mb = partition_sizes[idx] / (1024 * 1024)
+                    info(
+                        f'  part {idx+1}: nodes={part["node_count"]}, '
+                        f'est_weights={size_mb:.2f} MB, '
+                        f'inputs={len(part["inputs"])}, outputs={len(part["outputs"])}'
+                    )
+                    info(
+                        f'    inputs: {", ".join(part["inputs"]) if part["inputs"] else "(none)"}'
+                    )
+                    info(
+                        f'    outputs: {", ".join(part["outputs"]) if part["outputs"] else "(none)"}'
+                    )
+
+                split_input_cache: Dict[str, str] = {}
+                split_input_dir = tempfile.mkdtemp(prefix='onnx2tf_split_')
+                split_tf_input_cache: Dict[str, np.ndarray] = {}
+                split_output_layouts: Dict[str, bool] = {}
+
+                def _sanitize_tf_input_name(name: str) -> str:
+                    if name is None:
+                        return ''
+                    sanitized = name.replace(':', '__')
+                    if output_signaturedefs or output_integer_quantized_tflite:
+                        sanitized = re.sub('^/', 'wa/', sanitized)
+                    return f'{sanitized}:0'
+
+                def _normalize_onnx_output_name(name: str) -> str:
+                    if name is None:
+                        return ''
+                    normalized = name.replace(':', '__')
+                    if output_signaturedefs or output_integer_quantized_tflite:
+                        normalized = re.sub('^/', '', normalized)
+                    return normalized
+
+                def _common_layout_perms(rank: int) -> List[tuple]:
+                    if rank == 3:
+                        return [(0, 1, 2), (0, 2, 1)]
+                    if rank == 4:
+                        return [(0, 1, 2, 3), (0, 2, 3, 1), (0, 3, 1, 2)]
+                    if rank == 5:
+                        return [(0, 1, 2, 3, 4), (0, 2, 3, 4, 1), (0, 4, 1, 2, 3)]
+                    return [tuple(range(rank))]
+
+                def _build_onnx_tf_output_map(
+                    *,
+                    onnx_output_names: List[str],
+                    tf_output_tensors: List[tf.Tensor],
+                    onnx_output_values: Optional[Dict[str, np.ndarray]] = None,
+                    tf_output_values: Optional[Dict[str, np.ndarray]] = None,
+                ) -> Dict[str, tf.Tensor]:
+                    tf_by_base = {t.name.split(':')[0]: t for t in tf_output_tensors}
+                    tf_by_full = {t.name: t for t in tf_output_tensors}
+                    mapping: Dict[str, tf.Tensor] = {}
+                    used_tf: set = set()
+                    missing: List[str] = []
+                    for onnx_name in onnx_output_names:
+                        normalized = _normalize_onnx_output_name(onnx_name)
+                        candidates = [
+                            normalized,
+                            f'wa/{normalized}',
+                        ]
+                        tf_tensor = None
+                        for cand in candidates:
+                            if cand in tf_by_base:
+                                tf_tensor = tf_by_base[cand]
+                                break
+                            full_name = f'{cand}:0'
+                            if full_name in tf_by_full:
+                                tf_tensor = tf_by_full[full_name]
+                                break
+                        if tf_tensor is None:
+                            if onnx_name in tf_by_base:
+                                tf_tensor = tf_by_base[onnx_name]
+                            else:
+                                full_name = f'{onnx_name}:0'
+                                if full_name in tf_by_full:
+                                    tf_tensor = tf_by_full[full_name]
+                        if tf_tensor is not None:
+                            mapping[onnx_name] = tf_tensor
+                            used_tf.add(tf_tensor.name)
+                        else:
+                            missing.append(onnx_name)
+
+                    if onnx_output_values and tf_output_values and missing:
+                        tf_candidates = []
+                        for tf_tensor in tf_output_tensors:
+                            tf_val = tf_output_values.get(tf_tensor.name)
+                            if tf_val is None:
+                                tf_val = tf_output_values.get(tf_tensor.name.split(':')[0])
+                            if tf_val is not None:
+                                tf_candidates.append((tf_tensor, tf_val))
+                        still_missing = []
+                        for onnx_name in list(missing):
+                            onnx_val = onnx_output_values.get(onnx_name)
+                            if onnx_val is None:
+                                still_missing.append(onnx_name)
+                                continue
+                            best = None
+                            best_err = None
+                            for tf_tensor, tf_val in tf_candidates:
+                                if tf_tensor.name in used_tf:
+                                    continue
+                                if tf_val.shape != onnx_val.shape:
+                                    continue
+                                err = np.max(np.abs(onnx_val - tf_val))
+                                if best is None or err < best_err:
+                                    best = tf_tensor
+                                    best_err = err
+                            if best is None:
+                                for tf_tensor, tf_val in tf_candidates:
+                                    if tf_tensor.name in used_tf:
+                                        continue
+                                    if tf_val.ndim != onnx_val.ndim:
+                                        continue
+                                    for perm in _common_layout_perms(tf_val.ndim):
+                                        if tf_val.transpose(perm).shape != onnx_val.shape:
+                                            continue
+                                        err = np.max(np.abs(onnx_val - tf_val.transpose(perm)))
+                                        if best is None or err < best_err:
+                                            best = tf_tensor
+                                            best_err = err
+                            if best is not None and best_err is not None and best_err <= 1e-3:
+                                mapping[onnx_name] = best
+                                used_tf.add(best.name)
+                            else:
+                                still_missing.append(onnx_name)
+                        missing = still_missing
+                    if missing:
+                        warn(
+                            'Auto split output mapping failed for: ' +
+                            ', '.join(missing) +
+                            '. Output cache/layout may be incomplete.'
+                        )
+                    return mapping
+
+                def _onnx_output_shape_map(onnx_model: onnx.ModelProto) -> Dict[str, List[Optional[int]]]:
+                    shape_map: Dict[str, List[Optional[int]]] = {}
+                    try:
+                        for out in onnx_model.graph.output:
+                            dims: List[Optional[int]] = []
+                            t = out.type.tensor_type
+                            if t.HasField('shape'):
+                                for d in t.shape.dim:
+                                    if d.dim_value > 0:
+                                        dims.append(int(d.dim_value))
+                                    elif d.dim_param:
+                                        dims.append(None)
+                                    else:
+                                        dims.append(None)
+                            if dims:
+                                shape_map[out.name] = dims
+                    except Exception:
+                        pass
+                    return shape_map
+
+                def _infer_keep_shape(onnx_shape: List[Optional[int]], tf_shape: List[int]) -> Optional[bool]:
+                    if not onnx_shape or any(d is None for d in onnx_shape):
+                        return None
+                    if list(onnx_shape) == list(tf_shape):
+                        return True
+                    rank = len(onnx_shape)
+                    if rank == 3:
+                        if list(tf_shape) == [onnx_shape[0], onnx_shape[2], onnx_shape[1]]:
+                            return False
+                    elif rank == 4:
+                        if list(tf_shape) == [onnx_shape[0], onnx_shape[2], onnx_shape[3], onnx_shape[1]]:
+                            return False
+                    elif rank == 5:
+                        if list(tf_shape) == [onnx_shape[0], onnx_shape[2], onnx_shape[3], onnx_shape[4], onnx_shape[1]]:
+                            return False
+                    return None
+
+                def _merge_custom_inputs(user_inputs, auto_inputs):
+                    merged = []
+                    seen = set()
+                    if user_inputs:
+                        for item in user_inputs:
+                            if len(item) >= 2:
+                                merged.append(item)
+                                seen.add(item[0])
+                    for item in auto_inputs:
+                        if len(item) >= 2 and item[0] not in seen:
+                            merged.append(item)
+                            seen.add(item[0])
+                    return merged
+
+                base_kwargs = {
+                    'input_onnx_file_path': input_onnx_file_path if input_onnx_file_path is not None else None,
+                    'onnx_graph': onnx_graph,
+                    'output_folder_path': output_folder_path,
+                    'output_signaturedefs': output_signaturedefs,
+                    'output_h5': output_h5,
+                    'output_keras_v3': output_keras_v3,
+                    'output_tfv1_pb': output_tfv1_pb,
+                    'output_weights': output_weights,
+                    'copy_onnx_input_output_names_to_tflite': copy_onnx_input_output_names_to_tflite,
+                    'output_dynamic_range_quantized_tflite': output_dynamic_range_quantized_tflite,
+                    'output_integer_quantized_tflite': output_integer_quantized_tflite,
+                    'quant_norm_mean': quant_norm_mean,
+                    'quant_norm_std': quant_norm_std,
+                    'quant_type': quant_type,
+                    'custom_input_op_name_np_data_path': custom_input_op_name_np_data_path,
+                    'tf_input_cache': split_tf_input_cache,
+                    'input_quant_dtype': input_quant_dtype,
+                    'output_quant_dtype': output_quant_dtype,
+                    'not_use_onnxsim': not_use_onnxsim,
+                    'not_use_opname_auto_generate': not_use_opname_auto_generate,
+                    'batch_size': batch_size,
+                    'overwrite_input_shape': overwrite_input_shape,
+                    'shape_hints': shape_hints,
+                    'no_large_tensor': no_large_tensor,
+                    'output_nms_with_dynamic_tensor': output_nms_with_dynamic_tensor,
+                    'switch_nms_version': switch_nms_version,
+                    'keep_ncw_or_nchw_or_ncdhw_input_names': keep_ncw_or_nchw_or_ncdhw_input_names,
+                    'keep_nwc_or_nhwc_or_ndhwc_input_names': keep_nwc_or_nhwc_or_ndhwc_input_names,
+                    'keep_shape_absolutely_input_names': keep_shape_absolutely_input_names,
+                    'input_names_to_interrupt_model_conversion': None,
+                    'output_names_to_interrupt_model_conversion': None,
+                    'disable_group_convolution': disable_group_convolution,
+                    'enable_accumulation_type_float16': enable_accumulation_type_float16,
+                    'enable_batchmatmul_unfold': enable_batchmatmul_unfold,
+                    'enable_rnn_unroll': enable_rnn_unroll,
+                    'disable_suppression_flextranspose': disable_suppression_flextranspose,
+                    'disable_strict_mode': disable_strict_mode,
+                    'onnxruntime_output_memmap': onnxruntime_output_memmap,
+                    'onnxruntime_output_memmap_dir': onnxruntime_output_memmap_dir,
+                    'number_of_dimensions_after_flextranspose_compression': number_of_dimensions_after_flextranspose_compression,
+                    'disable_suppression_flexstridedslice': disable_suppression_flexstridedslice,
+                    'number_of_dimensions_after_flexstridedslice_compression': number_of_dimensions_after_flexstridedslice_compression,
+                    'optimization_for_gpu_delegate': optimization_for_gpu_delegate,
+                    'replace_argmax_to_reducemax_and_indices_is_int64': replace_argmax_to_reducemax_and_indices_is_int64,
+                    'replace_argmax_to_reducemax_and_indices_is_float32': replace_argmax_to_reducemax_and_indices_is_float32,
+                    'replace_argmax_to_fused_argmax_and_indices_is_int64': replace_argmax_to_fused_argmax_and_indices_is_int64,
+                    'replace_argmax_to_fused_argmax_and_indices_is_float32': replace_argmax_to_fused_argmax_and_indices_is_float32,
+                    'fused_argmax_scale_ratio': fused_argmax_scale_ratio,
+                    'replace_to_pseudo_operators': replace_to_pseudo_operators,
+                    'param_replacement_file': param_replacement_file,
+                    'auto_generate_json': auto_generate_json,
+                    'auto_generate_json_on_error': auto_generate_json_on_error,
+                    'enable_auto_split_model': False,
+                    'auto_split_max_size_mb': auto_split_max_size_mb,
+                    'check_gpu_delegate_compatibility': check_gpu_delegate_compatibility,
+                    'check_onnx_tf_outputs_elementwise_close': check_onnx_tf_outputs_elementwise_close,
+                    'check_onnx_tf_outputs_elementwise_close_full': check_onnx_tf_outputs_elementwise_close_full,
+                    'check_onnx_tf_outputs_sample_data_normalization': check_onnx_tf_outputs_sample_data_normalization,
+                    'check_onnx_tf_outputs_elementwise_close_rtol': check_onnx_tf_outputs_elementwise_close_rtol,
+                    'check_onnx_tf_outputs_elementwise_close_atol': check_onnx_tf_outputs_elementwise_close_atol,
+                    'mvn_epsilon': mvn_epsilon,
+                    'disable_model_save': disable_model_save,
+                    'non_verbose': non_verbose,
+                    'verbosity': verbosity,
+                }
+                base_kwargs['input_names_to_interrupt_model_conversion'] = None
+                base_kwargs['output_names_to_interrupt_model_conversion'] = None
+
+                model_ret = None
+                try:
+                    for idx, part in enumerate(partitions):
+                        part_output_values: Optional[Dict[str, np.ndarray]] = None
+                        part_output_folder = os.path.join(
+                            output_folder_path,
+                            f'part_{idx+1:04d}',
+                        )
+                        base_name = os.path.splitext(os.path.basename(input_onnx_file_path))[0] \
+                            if input_onnx_file_path else 'model'
+                        os.makedirs(part_output_folder, exist_ok=True)
+                        split_onnx_path = os.path.join(
+                            part_output_folder,
+                            f'{base_name}_part_{idx+1:04d}.onnx'
+                        )
+                        part_graph = sne4onnx.extraction(
+                            input_op_names=part['inputs'],
+                            output_op_names=part['outputs'],
+                            onnx_graph=onnx_graph_for_split,
+                            output_onnx_file_path=split_onnx_path,
+                            has_external_data=has_external_data,
+                        )
+                        auto_custom_inputs = []
+                        if split_input_cache:
+                            for input_name in part['inputs']:
+                                if input_name in split_input_cache:
+                                    auto_custom_inputs.append([
+                                        input_name,
+                                        split_input_cache[input_name],
+                                    ])
+                        merged_custom_inputs = _merge_custom_inputs(
+                            custom_input_op_name_np_data_path,
+                            auto_custom_inputs,
+                        )
+                        # For the first partition, keep the same behavior as non-split conversion.
+                        # Only user-provided custom inputs are used.
+                        if idx == 0 and not auto_custom_inputs:
+                            custom_inputs_for_part = merged_custom_inputs
+                        else:
+                            require_mean_std = bool(output_integer_quantized_tflite)
+                            custom_inputs_for_part = _complete_custom_inputs_for_graph(
+                                onnx_graph=part_graph,
+                                custom_inputs=merged_custom_inputs,
+                                output_dir=split_input_dir,
+                                file_prefix=f'part_{idx+1:04d}',
+                                shape_hints=shape_hints,
+                                require_mean_std=require_mean_std,
+                            )
+
+                        # Propagate dummy outputs to next partitions
+                        try:
+                            has_inputs = len(part_graph.graph.input) > 0
+                            has_outputs = len(part_graph.graph.output) > 0
+                            if has_inputs and has_outputs:
+                                part_input_datas = {}
+                                part_outputs = dummy_onnx_inference(
+                                    onnx_graph=part_graph,
+                                    output_names=part['outputs'],
+                                    test_data_nhwc=None,
+                                    custom_input_op_name_np_data_path=custom_inputs_for_part,
+                                    tf_layers_dict={},
+                                    use_cuda=local_use_cuda,
+                                    disable_strict_mode=disable_strict_mode,
+                                    enable_ort_output_memmap=False,
+                                    ort_output_memmap_dir=None,
+                                    shape_hints=shape_hints,
+                                    input_datas_for_validation=part_input_datas,
+                                )
+                                for input_name, input_value in part_input_datas.items():
+                                    file_name = (
+                                        f'part_{idx+1:04d}_' +
+                                        f'{_sanitize_split_input_name(input_name)}.npy'
+                                    )
+                                    file_path = os.path.join(split_input_dir, file_name)
+                                    split_input_cache[input_name] = _write_memmap_array(
+                                        file_path,
+                                        input_value,
+                                    )
+                                part_output_values = {
+                                    name: value for name, value in zip(part['outputs'], part_outputs)
+                                }
+                                for output_name, output_value in zip(part['outputs'], part_outputs):
+                                    file_name = (
+                                        f'part_{idx+1:04d}_' +
+                                        f'{_sanitize_split_input_name(output_name)}.npy'
+                                    )
+                                    file_path = os.path.join(split_input_dir, file_name)
+                                    split_input_cache[output_name] = _write_memmap_array(
+                                        file_path,
+                                        output_value,
+                                    )
+                            else:
+                                warn(
+                                    'Auto split input propagation skipped for this partition '
+                                    'because it has no inputs or outputs.'
+                                )
+                        except Exception as ex:
+                            warn(
+                                'Auto split input propagation failed for this partition. '
+                                'Subsequent partitions may use default dummy inputs.'
+                            )
+                            warn(f'{ex}')
+
+                        part_kwargs = dict(base_kwargs)
+                        if split_output_layouts:
+                            part_keep_shape_abs = set(keep_shape_absolutely_input_names or [])
+                            for input_name in part['inputs']:
+                                if split_output_layouts.get(input_name, False):
+                                    part_keep_shape_abs.add(input_name)
+                            part_kwargs['keep_shape_absolutely_input_names'] = \
+                                list(part_keep_shape_abs) if part_keep_shape_abs else None
+                        part_kwargs['input_onnx_file_path'] = split_onnx_path
+                        part_kwargs['onnx_graph'] = part_graph
+                        part_kwargs['output_folder_path'] = part_output_folder
+                        if custom_inputs_for_part:
+                            part_kwargs['custom_input_op_name_np_data_path'] = custom_inputs_for_part
+                        model_ret = convert(**part_kwargs)
+
+                        if hasattr(model_ret, 'onnx_output_layouts') \
+                            and isinstance(model_ret.onnx_output_layouts, dict):
+                            for out_name in part['outputs']:
+                                if out_name in model_ret.onnx_output_layouts:
+                                    split_output_layouts[out_name] = \
+                                        bool(model_ret.onnx_output_layouts[out_name])
+
+                        # Cache TF outputs for the next partition's TF dummy inference.
+                        try:
+                            tf_outputs = dummy_tf_inference(
+                                model=model_ret,
+                                inputs=model_ret.inputs,
+                                test_data_nhwc=None,
+                                custom_input_op_name_np_data_path=custom_input_op_name_np_data_path,
+                                prefilled_input_datas=split_tf_input_cache,
+                                shape_hints=shape_hints,
+                                keep_shape_absolutely_input_names=keep_shape_absolutely_input_names,
+                                keep_ncw_or_nchw_or_ncdhw_input_names=keep_ncw_or_nchw_or_ncdhw_input_names,
+                                keep_nwc_or_nhwc_or_ndhwc_input_names=keep_nwc_or_nhwc_or_ndhwc_input_names,
+                            )
+                            onnx_output_shapes = _onnx_output_shape_map(part_graph)
+                            if model_ret.outputs and part['outputs']:
+                                tf_output_map = _build_onnx_tf_output_map(
+                                    onnx_output_names=part['outputs'],
+                                    tf_output_tensors=model_ret.outputs,
+                                    onnx_output_values=part_output_values,
+                                    tf_output_values=tf_outputs,
+                                )
+                                for onnx_out, tf_tensor in tf_output_map.items():
+                                    tf_val = tf_outputs.get(tf_tensor.name)
+                                    if tf_val is None:
+                                        continue
+                                    # Store both full and base TF names to maximize cache hits.
+                                    split_tf_input_cache[tf_tensor.name] = tf_val
+                                    split_tf_input_cache[tf_tensor.name.split(':')[0]] = tf_val
+                                    # Keep legacy key for compatibility with existing lookups.
+                                    sanitized = _sanitize_tf_input_name(onnx_out)
+                                    split_tf_input_cache[sanitized] = tf_val
+                                    split_tf_input_cache[sanitized.split(':')[0]] = tf_val
+                                    keep_shape = _infer_keep_shape(
+                                        onnx_output_shapes.get(onnx_out),
+                                        list(tf_val.shape),
+                                    )
+                                    if keep_shape is not None:
+                                        split_output_layouts[onnx_out] = keep_shape
+                        except Exception:
+                            pass
+                finally:
+                    shutil.rmtree(split_input_dir, ignore_errors=True)
+                return model_ret
+
     # Cut the ONNX graph when an input name is specified that interrupts the conversion
     if not input_names_to_interrupt_model_conversion:
         input_names = [
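For reference, the layout probe `_infer_keep_shape` defined above only decides when every ONNX dim is static; illustrative calls of its contract (the helper is nested inside convert(), shown standalone here):

    _infer_keep_shape([1, 3, 224, 224], [1, 224, 224, 3])     # -> False: NCHW->NHWC transpose detected
    _infer_keep_shape([1, 3, 224, 224], [1, 3, 224, 224])     # -> True: layout preserved
    _infer_keep_shape([1, None, 224, 224], [1, 224, 224, 3])  # -> None: dynamic dim, undecidable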
@@ -1218,6 +2128,7 @@ def convert(
         'relu_relu6_merge_op_names': {},
         'mul_div_replace_op_names': {},
         'use_cuda': use_cuda,
+        'tf_input_cache': tf_input_cache,
     }
 
     tf_layers_dict = {}
@@ -1291,30 +2202,29 @@ def convert(
     )
 
     # download test data
-    all_four_dim = sum(
-        [
-            1 for input in inputs \
-            if len(input.shape) == 4 \
-            and input.shape[0] is not None \
-            and input.shape[0] <= 20 \
-            and input.shape[-1] == 3 \
-            and input.shape[1] is not None \
-            and input.shape[2] is not None
-        ]
-    ) == len(inputs)
-    same_batch_dim = False
-    if all_four_dim:
-        batch_size = inputs[0].shape[0]
-        for input in inputs:
-            same_batch_dim = batch_size == input.shape[0]
     test_data_nhwc = None
-    if all_four_dim and same_batch_dim:
-        test_data: np.ndarray = download_test_image_data()
-        test_data_nhwc = test_data[:inputs[0].shape[0], ...]
-        if check_onnx_tf_outputs_sample_data_normalization == "denorm":
-            test_data_nhwc = test_data_nhwc * 255.0
+    if inputs:
+        all_four_dim = sum(
+            [
+                1 for input in inputs \
+                if len(input.shape) == 4 \
+                and input.shape[0] is not None \
+                and input.shape[0] <= 20 \
+                and input.shape[-1] == 3 \
+                and input.shape[1] is not None \
+                and input.shape[2] is not None
+            ]
+        ) == len(inputs)
+        same_batch_dim = False
+        if all_four_dim:
+            batch_size = inputs[0].shape[0]
+            for input in inputs:
+                same_batch_dim = batch_size == input.shape[0]
+        if all_four_dim and same_batch_dim:
+            test_data: np.ndarray = download_test_image_data()
+            test_data_nhwc = test_data[:inputs[0].shape[0], ...]
+            if check_onnx_tf_outputs_sample_data_normalization == "denorm":
+                test_data_nhwc = test_data_nhwc * 255.0
 
     # ONNX dummy inference
     # Generate output for all OPs.
@@ -1400,7 +2310,10 @@ def convert(
         exported_onnx_graph = gs.export_onnx(graph, do_type_check=False, **meta_data)
         if metadata_props is not None:
             exported_onnx_graph.metadata_props.extend(metadata_props)
-        estimated_graph = onnx.shape_inference.infer_shapes(exported_onnx_graph)
+        if not has_external_data:
+            estimated_graph = onnx.shape_inference.infer_shapes(exported_onnx_graph)
+        else:
+            estimated_graph = exported_onnx_graph
         if input_onnx_file_path is not None:
             onnx.save(estimated_graph, input_onnx_file_path)
     if not not_use_onnxsim:
@@ -1580,6 +2493,14 @@ def convert(
                 outputs[oidx] = tf_keras.layers.Lambda(lambda x: tf.constant(y))(x)
 
     model = tf_keras.Model(inputs=inputs, outputs=outputs)
+    try:
+        onnx_output_layouts = {
+            name: tf_layers_dict.get(name, {}).get('nhwc', False)
+            for name in onnx_graph_output_names
+        }
+        model.onnx_output_layouts = onnx_output_layouts
+    except Exception:
+        pass
     debug('')
 
     # The process ends normally without saving the model.
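After conversion, each returned Keras model now carries a best-effort map from ONNX output name to its NHWC flag, attached in the try-block above; a minimal sketch of reading it (model path and output names are hypothetical):

    import onnx2tf

    model = onnx2tf.convert(input_onnx_file_path='large_model.onnx')
    print(getattr(model, 'onnx_output_layouts', {}))
    # e.g. {'logits': False, 'feature_map': True}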
@@ -3050,6 +3971,23 @@ def main():
         'e.g. \n' +
         '--output_names_to_interrupt_model_conversion "output0" "output1" "output2"'
     )
+    parser.add_argument(
+        '-easm',
+        '--enable_auto_split_model',
+        action='store_true',
+        help=\
+            'Force auto split regardless of the ONNX file size. \n' +
+            'Uses --auto_split_max_size_mb as the target partition size.'
+    )
+    parser.add_argument(
+        '-asmsm',
+        '--auto_split_max_size_mb',
+        type=int,
+        default=1024,
+        help=\
+            'Target maximum size per partition in MB based on ONNX initializer sizes. \n' +
+            'Used when auto-split is triggered or forced.'
+    )
     parser.add_argument(
         '-dgc',
         '--disable_group_convolution',
@@ -3450,6 +4388,8 @@ def main():
         param_replacement_file=args.param_replacement_file,
         auto_generate_json=args.auto_generate_json,
         auto_generate_json_on_error=args.auto_generate_json_on_error,
+        enable_auto_split_model=args.enable_auto_split_model,
+        auto_split_max_size_mb=args.auto_split_max_size_mb,
         check_gpu_delegate_compatibility=args.check_gpu_delegate_compatibility,
         check_onnx_tf_outputs_elementwise_close=args.check_onnx_tf_outputs_elementwise_close,
         check_onnx_tf_outputs_elementwise_close_full=args.check_onnx_tf_outputs_elementwise_close_full,
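On the CLI side, the same behavior is reachable through the two new arguments registered above; a typical invocation might look like this (model filename is a placeholder):

    onnx2tf -i large_model.onnx -easm -asmsm 512
    # equivalent long form:
    onnx2tf -i large_model.onnx --enable_auto_split_model --auto_split_max_size_mb 512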