onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
|
@@ -3,9 +3,20 @@ import json
|
|
|
3
3
|
import os
|
|
4
4
|
import sys
|
|
5
5
|
import warnings
|
|
6
|
-
from typing import
|
|
6
|
+
from typing import (
|
|
7
|
+
Any,
|
|
8
|
+
Callable,
|
|
9
|
+
Dict,
|
|
10
|
+
Iterable,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Optional,
|
|
14
|
+
Sequence,
|
|
15
|
+
Set,
|
|
16
|
+
Tuple,
|
|
17
|
+
Union,
|
|
18
|
+
)
|
|
7
19
|
import numpy as np
|
|
8
|
-
import numpy.typing as npt
|
|
9
20
|
import onnx
|
|
10
21
|
import onnx.helper as oh
|
|
11
22
|
import onnx.numpy_helper as onh
|
|
@@ -15,11 +26,14 @@ from onnx import (
|
|
|
15
26
|
GraphProto,
|
|
16
27
|
ModelProto,
|
|
17
28
|
NodeProto,
|
|
29
|
+
OperatorSetIdProto,
|
|
18
30
|
TensorProto,
|
|
19
31
|
ValueInfoProto,
|
|
20
32
|
load as onnx_load,
|
|
21
33
|
)
|
|
22
34
|
|
|
35
|
+
TensorLike = Union[np.ndarray, "torch.Tensor"] # noqa: F821
|
|
36
|
+
|
|
23
37
|
|
|
24
38
|
def _make_stat(init: TensorProto) -> Dict[str, float]:
|
|
25
39
|
"""
|
|
@@ -331,7 +345,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
|
|
|
331
345
|
print(onnx_dtype_name(7))
|
|
332
346
|
"""
|
|
333
347
|
for k in dir(TensorProto):
|
|
334
|
-
if k.upper() == k and k
|
|
348
|
+
if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}:
|
|
335
349
|
v = getattr(TensorProto, k)
|
|
336
350
|
if v == itype:
|
|
337
351
|
return k
|
|
@@ -477,7 +491,7 @@ def convert_endian(tensor: TensorProto) -> None:
|
|
|
477
491
|
tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
|
|
478
492
|
|
|
479
493
|
|
|
480
|
-
def from_array_ml_dtypes(arr:
|
|
494
|
+
def from_array_ml_dtypes(arr: TensorLike, name: Optional[str] = None) -> TensorProto:
|
|
481
495
|
"""
|
|
482
496
|
Converts a numpy array to a tensor def assuming the dtype
|
|
483
497
|
is defined in ml_dtypes.
|
|
@@ -523,7 +537,7 @@ _STORAGE_TYPE = {
|
|
|
523
537
|
}
|
|
524
538
|
|
|
525
539
|
|
|
526
|
-
def from_array_extended(tensor:
|
|
540
|
+
def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> TensorProto:
|
|
527
541
|
"""
|
|
528
542
|
Converts an array into a :class:`onnx.TensorProto`.
|
|
529
543
|
|
|
@@ -590,7 +604,7 @@ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> Te
|
|
|
590
604
|
return t
|
|
591
605
|
|
|
592
606
|
|
|
593
|
-
def to_array_extended(proto: TensorProto) ->
|
|
607
|
+
def to_array_extended(proto: TensorProto) -> TensorLike:
|
|
594
608
|
"""Converts :class:`onnx.TensorProto` into a numpy array."""
|
|
595
609
|
arr = onh.to_array(proto)
|
|
596
610
|
if proto.data_type >= onnx.TensorProto.BFLOAT16:
|
|
@@ -1195,3 +1209,514 @@ def shadowing_names(
|
|
|
1195
1209
|
existing |= not_empty
|
|
1196
1210
|
created |= not_empty
|
|
1197
1211
|
return shadow, post_shadow, created
|
|
1212
|
+
|
|
1213
|
+
|
|
1214
|
+
def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
    """
    Returns the hidden inputs (inputs coming from an upper context)
    used by a subgraph. It excludes empty names.

    A name is hidden when a node consumes it although it is neither
    an initializer, a sparse initializer, an input of ``graph`` nor the output
    of one of the preceding nodes. Subgraphs stored in node attributes
    are explored recursively.

    :param graph: subgraph to inspect
    :return: names the subgraph reads from its enclosing scope
    """
    hidden: Set[str] = set()
    # Names defined locally so far. A SparseTensorProto has no ``name`` field,
    # its name is the one of its ``values`` tensor.
    known = (
        {init.name for init in graph.initializer}
        | {sparse.values.name for sparse in graph.sparse_initializer}
        | {inp.name for inp in graph.input}
    )
    for node in graph.node:
        hidden.update(name for name in node.input if name and name not in known)
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH and att.g:
                # What is hidden for the nested graph may be defined in this one.
                hidden.update(name for name in get_hidden_inputs(att.g) if name not in known)
        # The outputs of a node are only visible to the nodes following it.
        known.update(node.output)
    return hidden
|
|
1236
|
+
|
|
1237
|
+
|
|
1238
|
+
def get_all_node_inputs(node: onnx.NodeProto) -> Set[str]:
    """
    Collects every result a node depends on: its explicit inputs
    and, for nodes of type Scan, Loop or If, the hidden inputs of the
    subgraphs stored in its attributes (see :func:`get_hidden_inputs`).
    Empty names are skipped.

    :param node: node to inspect
    :return: set of result names
    """
    names = set(filter(None, node.input))
    if node.op_type not in {"Scan", "Loop", "If"}:
        return names
    for attribute in node.attribute:
        if attribute.type != onnx.AttributeProto.GRAPH:
            continue
        names.update(get_hidden_inputs(attribute.g))
    return names
|
|
1249
|
+
|
|
1250
|
+
|
|
1251
|
+
def extract_subset_of_nodes(
    model: ModelProto,
    name: str,
    node_index: Optional[int] = None,
    cut_points: Optional[Set[str]] = None,
) -> List[NodeProto]:
    """
    Extracts the minimal subgraphs which can produce the output ``name``
    knowing ``cut_points``.

    The nodes are walked backward starting from the node producing ``name``.
    A node is selected when one of its outputs is needed and is not
    a cut point. The initializers of the model are always considered
    as cut points. The set given by the caller is never modified.

    :param model: original model
    :param name: result name
    :param node_index: if the node index is known, otherwise searches for it
    :param cut_points: the known results or input name otherwise
    :return: minimal list of nodes
    """
    if node_index is None:
        for i, node in enumerate(model.graph.node):
            if name in node.output:
                node_index = i
                break
    assert node_index is not None and node_index < len(model.graph.node), (
        f"node_index={node_index} (n_nodes={len(model.graph.node)}) "
        f"is still empty or wrong for result {name!r}"
    )
    assert name in model.graph.node[node_index].output, (
        f"Unable to find {name!r} in {model.graph.node[node_index].output}, "
        f"node={pretty_onnx(model.graph.node[node_index])}"
    )
    if cut_points is None:
        cut_points = {n.name for n in model.graph.input}
    else:
        # Works on a copy, the set owned by the caller must stay untouched.
        cut_points = set(cut_points)
    cut_points.update(n.name for n in model.graph.initializer)
    # The requested result cannot be a cut point, otherwise nothing is selected.
    cut_points.discard(name)

    node = model.graph.node[node_index]
    selected = {node_index}
    current_node_index = node_index
    current_input_index = 0
    intermediate = {name}
    cached: Dict[int, List[str]] = {}
    inputs = set(k for k in node.input if k)
    while not (inputs <= cut_points) and current_node_index >= 0:
        node = model.graph.node[current_node_index]
        # node inputs including hidden ones
        if current_node_index not in cached:
            cached[current_node_index] = list(get_all_node_inputs(node))
        node_inputs = cached[current_node_index]
        # processing
        if current_input_index == 0 or not node_inputs:
            needs = [o for o in node.output if o in intermediate and o not in cut_points]
            if needs:
                selected.add(current_node_index)
            if not needs or not node_inputs:
                # Nothing to follow from this node, moves to the previous one.
                current_node_index -= 1
                current_input_index = 0
                continue
        # more intermediate results
        assert current_input_index < len(node_inputs), (
            f"current_input_index={current_input_index} but node_inputs={node_inputs}, "
            f"node={pretty_onnx(node)}"
        )
        res = node_inputs[current_input_index]
        if res not in cut_points:
            intermediate.add(res)
        current_input_index += 1
        if current_input_index >= len(node_inputs):
            current_node_index -= 1
            current_input_index = 0

    return [model.graph.node[i] for i in sorted(selected)]
|
|
1336
|
+
|
|
1337
|
+
|
|
1338
|
+
def make_submodel(
    nodes: List[NodeProto],
    ir_version: int,
    opset_imports: List[OperatorSetIdProto],
    output_names: List[str],
    type_rank_fn: Callable[[str], Tuple[int, int]],
) -> ModelProto:
    """
    Builds a model named ``submodel`` out of a list of nodes.
    The inputs of the model are the results the nodes consume
    (hidden inputs of Scan, If, Loop subgraphs included) which are not
    produced by one of the preceding nodes.
    The function assumes the nodes are sorted.

    :param nodes: list of nodes
    :param ir_version: ir version
    :param opset_imports: opset import
    :param output_names: desired outputs
    :param type_rank_fn: function returning the type and the rank of a result
    :return: model proto
    """
    missing: Set[str] = set()
    # Backward walk: a result produced by a node is not needed
    # from outside anymore, what the node consumes is.
    for node in reversed(nodes):
        missing.difference_update(o for o in node.output if o)
        missing.update(get_all_node_inputs(node))

    def _value_info(result_name, elem_type, rank):
        # Every dimension is dynamic and gets its own symbolic name.
        dims = [f"{result_name}_d{i}" for i in range(rank)]
        return oh.make_tensor_value_info(result_name, elem_type, dims)

    graph_inputs = [_value_info(n, *type_rank_fn(n)) for n in sorted(missing) if n]
    graph_outputs = [_value_info(n, *type_rank_fn(n)) for n in sorted(output_names) if n]
    graph = oh.make_graph(nodes, "submodel", graph_inputs, graph_outputs)
    return oh.make_model(graph, ir_version=ir_version, opset_imports=opset_imports)
|
|
1383
|
+
|
|
1384
|
+
|
|
1385
|
+
def get_tensor_shape(
    obj: Union[onnx.ValueInfoProto, onnx.TypeProto, onnx.TensorProto],
) -> Optional[List[Optional[Union[int, str]]]]:
    """
    Returns the shape if that makes sense for this object.

    :param obj: a :class:`onnx.ValueInfoProto`, a :class:`onnx.TypeProto`
        or a :class:`onnx.TensorProto`
    :return: None if the type holds no shape, otherwise a list with one item
        per dimension: a positive integer, a dimension name,
        or None when the dimension has neither
    :raises TypeError: if ``obj`` is none of the supported types
    """
    if isinstance(obj, TensorProto):
        # A tensor stores its shape in ``dims``.
        return list(obj.dims)
    if isinstance(obj, ValueInfoProto):
        return get_tensor_shape(obj.type)
    if not isinstance(obj, onnx.TypeProto):
        raise TypeError(f"Unexpected type {type(obj)!r}.")
    if not obj.tensor_type.HasField("shape"):
        return None
    # A dimension without a positive value falls back on its name,
    # an empty name means the dimension is unknown.
    return [
        d.dim_value if d.dim_value > 0 else (d.dim_param or None)
        for d in obj.tensor_type.shape.dim
    ]
|
|
1404
|
+
|
|
1405
|
+
|
|
1406
|
+
def _enumerate_model_node_outputs(
    model: ModelProto, add_node: bool = False, order: bool = False
) -> Iterable[Union[str, Tuple[str, NodeProto]]]:
    """
    Enumerates all the nodes of a model.

    :param model: :epkg:`ONNX` graph
    :param add_node: if False, the function enumerates
        all output names from every node, otherwise, it
        enumerates tuple (output name, node)
    :param order: goes through outputs following the graph order,
        results are sorted by their depth in the graph, then by name
    :return: enumerator
    """
    assert hasattr(model, "graph"), f"Parameter model is not an ONNX model but {type(model)}"
    if not order:
        for node in model.graph.node:
            for out in node.output:
                yield (out, node) if add_node else out
        return

    # Keys of d_order are (0, result name) or (1, node index).
    # Nodes are identified by their index and not by their name
    # because a node name is optional and may be empty or duplicated.
    edges = []
    d_order = {}
    producers = {}
    for inp in model.graph.input:
        d_order[0, inp.name] = 0
    for node_id, node in enumerate(model.graph.node):
        d_order[1, node_id] = 0
        for i in node.input:
            edges.append(("in", i, node_id))
        for o in node.output:
            edges.append(("out", o, node_id))
            producers[o] = node
            d_order[0, o] = 0

    # Relaxation: a node comes after its inputs, a result after its producer.
    # The number of iterations is bounded by the number of nodes.
    modif = 1
    n_iter = 0
    while modif > 0 and n_iter <= len(model.graph.node):
        modif = 0
        n_iter += 1
        for kind, data_name, node_id in edges:
            if kind == "in":
                if (0, data_name) not in d_order:
                    # Neither a graph input nor a node output (an initializer for example).
                    continue
                if d_order[0, data_name] + 1 > d_order[1, node_id]:
                    modif += 1
                    d_order[1, node_id] = d_order[0, data_name] + 1
            elif d_order[1, node_id] + 1 > d_order[0, data_name]:
                modif += 1
                d_order[0, data_name] = d_order[1, node_id] + 1

    for _, (kind, key) in sorted((v, k) for k, v in d_order.items()):
        # Only the results produced by a node are enumerated.
        if kind == 0 and key in producers:
            yield (key, producers[key]) if add_node else key
|
|
1466
|
+
|
|
1467
|
+
|
|
1468
|
+
def onnx_remove_node_unused(
    graph: Union[onnx.GraphProto, onnx.FunctionProto], recursive=True
) -> Union[onnx.GraphProto, onnx.FunctionProto]:
    """
    Removes unused nodes of the graph. An unused node
    is not involved in the output computation.
    Unused initializers are removed as well when *graph* is a graph.

    :param graph: graph or function to clean
    :param recursive: looks into subgraphs, the current implementation
        does not use this parameter, the nodes of the subgraphs are left unchanged
    :return: new graph or new function, the input is not modified
    """
    is_function = isinstance(graph, FunctionProto)

    # The outputs of a function are names, the outputs of a graph are ValueInfoProto.
    needed: Set[str] = set(graph.output) if is_function else {o.name for o in graph.output}
    nodes = list(graph.node)

    # Backward walk: the inputs (hidden ones included) of a node producing
    # a needed result are needed as well.
    for node in reversed(nodes):
        if any(o and o in needed for o in node.output):
            needed |= get_all_node_inputs(node)

    new_nodes = [node for node in nodes if any(o and o in needed for o in node.output)]

    # Finally create the new graph.
    if is_function:
        return oh.make_function(
            graph.domain,
            graph.name,
            graph.input,
            graph.output,
            new_nodes,
            opset_imports=graph.opset_import,
            attributes=graph.attribute,
            doc_string=graph.doc_string,
        )
    # Keyword arguments are mandatory here: the sixth positional parameter
    # of make_graph is doc_string, not sparse_initializer.
    # The name of a sparse initializer is stored in its ``values`` tensor.
    new_graph = oh.make_graph(
        new_nodes,
        graph.name,
        graph.input,
        graph.output,
        initializer=[i for i in graph.initializer if i.name in needed],
        doc_string=graph.doc_string,
        sparse_initializer=[i for i in graph.sparse_initializer if i.values.name in needed],
    )
    new_graph.value_info.extend(graph.value_info)
    return new_graph
|
|
1535
|
+
|
|
1536
|
+
|
|
1537
|
+
def select_model_inputs_outputs(
    model: ModelProto,
    outputs: Optional[List[str]] = None,
    inputs: Optional[List[str]] = None,
    infer_shapes: bool = True,
    overwrite: Optional[Dict[str, Any]] = None,
    remove_unused: bool = True,
    verbose: int = 0,
) -> ModelProto:
    """
    Takes a model and changes its outputs.

    :param model: :epkg:`ONNX` model
    :param inputs: new inputs, same ones if None
    :param outputs: new outputs, same ones if None
    :param infer_shapes: infer inputs and outputs shapes
    :param overwrite: overwrite type and shapes for
        inputs or outputs, *overwrite* is a
        dictionary `{'name': (numpy dtype, shape)}`
    :param remove_unused: remove unused nodes from the graph
    :param verbose: display information while converting
    :return: modified model

    The function removes unneeded nodes.

    The following example shows how to change the inputs of model
    to bypass the first nodes. Shape inferences fails to determine
    the new inputs type. They need to be overwritten.
    `verbose=1` shows the number of deleted nodes.

    ::

        import numpy as np
        import onnx
        from onnx_diagnostic.helpers.onnx_helper import select_model_inputs_outputs

        onx = onnx.load(path)
        onx2 = select_model_inputs_outputs(
            onx, inputs=["a", "b"],
            infer_shapes=True, verbose=1,
            overwrite={'a': (np.int32, None), 'b': (np.int64, None)})
        onnx.save(onx2, path2)
    """
    if not isinstance(model, ModelProto):
        raise TypeError(f"Unexpected type {type(model)} for model.")
    if inputs is not None and not isinstance(inputs, list):
        inputs = [inputs]
    if outputs is not None and not isinstance(outputs, list):
        outputs = [outputs]
    if inputs is None:
        inputs = [i.name for i in model.graph.input]
    if outputs is None:
        outputs = [o.name for o in model.graph.output]

    input_names = set(inputs)
    available = set(_enumerate_model_node_outputs(model)) | input_names
    for out in outputs:
        assert out in available, f"Output {out!r} not found in model."

    # We mark all the nodes we need to keep: a node is kept if it produces
    # a needed result, its inputs become needed unless they are new inputs.
    # The walk is repeated until no new node gets marked.
    needed = set(outputs)
    nodes = list(model.graph.node)
    kept = [False] * len(nodes)
    changed = True
    while changed:
        changed = False
        for index in range(len(nodes) - 1, -1, -1):
            if kept[index]:
                continue
            node = nodes[index]
            if not any(out in needed for out in node.output):
                continue
            kept[index] = True
            changed = True
            needed.update(inp for inp in get_all_node_inputs(node) if inp not in input_names)

    keep_nodes = [node for node, keep in zip(nodes, kept) if keep]

    # Types known for every result, value_info is only available after shape inference.
    known_shapes = {}
    if infer_shapes:
        shapes = onnx.shape_inference.infer_shapes(model)
        for shape in shapes.graph.value_info:
            known_shapes[shape.name] = shape.type
        typed_graph = shapes.graph
    else:
        typed_graph = model.graph
    for shape in typed_graph.input:
        known_shapes[shape.name] = shape.type
    for shape in typed_graph.output:
        known_shapes[shape.name] = shape.type

    def _value_info(name, existing):
        # Priority: user definition, known type, original definition, untyped.
        if overwrite is not None and name in overwrite:
            dtype, shape = overwrite[name]
            return oh.make_tensor_value_info(name, np_dtype_to_tensor_dtype(dtype), shape)
        if name in known_shapes:
            elem_type = known_shapes[name].tensor_type.elem_type
            if elem_type != 0:
                return oh.make_tensor_value_info(
                    name, elem_type, get_tensor_shape(known_shapes[name])
                )
            # elem_type == 0 means the type is undefined, the result stays untyped.
        elif name in existing:
            return existing[name]
        untyped = ValueInfoProto()
        untyped.name = name
        return untyped

    existing_inputs = {i.name: i for i in model.graph.input}
    var_in = [_value_info(name, existing_inputs) for name in inputs]
    existing_outputs = {o.name: o for o in model.graph.output}
    var_out = [_value_info(name, existing_outputs) for name in outputs]

    graph = oh.make_graph(
        keep_nodes,
        model.graph.name,
        var_in,
        var_out,
        model.graph.initializer,
        sparse_initializer=model.graph.sparse_initializer,
    )
    if remove_unused:
        graph = onnx_remove_node_unused(graph, recursive=False)
    if verbose:
        print(
            f"[select_model_inputs_outputs] kept {len(graph.node)} nodes "
            f"out of {len(model.graph.node)}"
        )
    onnx_model = oh.make_model(graph, functions=model.functions)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if model.metadata_props:
        values = {p.key: p.value for p in model.metadata_props}
        oh.set_model_props(onnx_model, values)

    # make_model adds a default opset, it is replaced by the original ones.
    del onnx_model.opset_import[:]
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    return onnx_model
|