onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +47 -10
- onnx_diagnostic/export/api.py +81 -50
- onnx_diagnostic/export/control_flow_research.py +10 -5
- onnx_diagnostic/export/onnx_plug.py +250 -61
- onnx_diagnostic/ext_test_case.py +99 -53
- onnx_diagnostic/helpers/dot_helper.py +37 -25
- onnx_diagnostic/helpers/helper.py +44 -38
- onnx_diagnostic/helpers/onnx_helper.py +441 -18
- onnx_diagnostic/helpers/ort_session.py +8 -8
- onnx_diagnostic/helpers/torch_helper.py +28 -2
- onnx_diagnostic/reference/ort_evaluator.py +6 -29
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +14 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +11 -5
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/RECORD +26 -26
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/onnx_helper.py

@@ -3,9 +3,20 @@ import json
 import os
 import sys
 import warnings
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 import numpy as np
-import numpy.typing as npt
 import onnx
 import onnx.helper as oh
 import onnx.numpy_helper as onh
@@ -21,6 +32,8 @@ from onnx import (
     load as onnx_load,
 )
 
+TensorLike = Union[np.ndarray, "torch.Tensor"]  # noqa: F821
+
 
 def _make_stat(init: TensorProto) -> Dict[str, float]:
     """
@@ -332,7 +345,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
         print(onnx_dtype_name(7))
     """
     for k in dir(TensorProto):
-        if k.upper() == k and k
+        if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}:
            v = getattr(TensorProto, k)
            if v == itype:
                return k
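The tightened filter above skips protobuf metadata attributes (`DESCRIPTOR`, `EXTERNAL`, `DEFAULT`) while still matching the real element types. A quick sketch of the lookup, assuming the import path from the file list above::

    from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name

    print(onnx_dtype_name(7))  # 'INT64' (TensorProto.INT64 == 7)
    print(onnx_dtype_name(1))  # 'FLOAT'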
@@ -478,7 +491,7 @@ def convert_endian(tensor: TensorProto) -> None:
     tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
 
 
-def from_array_ml_dtypes(arr:
+def from_array_ml_dtypes(arr: TensorLike, name: Optional[str] = None) -> TensorProto:
     """
     Converts a numpy array to a tensor def assuming the dtype
     is defined in ml_dtypes.
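With the `TensorLike` annotation, `from_array_ml_dtypes` keeps converting arrays whose dtype comes from `ml_dtypes`. A minimal sketch, with illustrative values (the expected `data_type` of 16 assumes the usual bfloat16 mapping)::

    import numpy as np
    import ml_dtypes
    from onnx_diagnostic.helpers.onnx_helper import from_array_ml_dtypes

    arr = np.array([1.5, -2.0], dtype=ml_dtypes.bfloat16)
    proto = from_array_ml_dtypes(arr, name="w")
    print(proto.data_type)  # 16 == onnx.TensorProto.BFLOAT16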
@@ -524,7 +537,7 @@ _STORAGE_TYPE = {
 }
 
 
-def from_array_extended(tensor:
+def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> TensorProto:
     """
     Converts an array into a :class:`onnx.TensorProto`.
 
@@ -591,7 +604,7 @@ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
     return t
 
 
-def to_array_extended(proto: TensorProto) ->
+def to_array_extended(proto: TensorProto) -> TensorLike:
     """Converts :class:`onnx.TensorProto` into a numpy array."""
     arr = onh.to_array(proto)
     if proto.data_type >= onnx.TensorProto.BFLOAT16:
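`from_array_extended` and `to_array_extended` are meant to round-trip, including dtypes that `onnx.numpy_helper` does not handle natively. A sketch with a plain float32 array::

    import numpy as np
    from onnx_diagnostic.helpers.onnx_helper import (
        from_array_extended,
        to_array_extended,
    )

    proto = from_array_extended(np.array([1.0, 2.0], dtype=np.float32), name="w")
    back = to_array_extended(proto)
    assert np.array_equal(back, np.array([1.0, 2.0], dtype=np.float32))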
@@ -1198,6 +1211,43 @@ def shadowing_names(
     return shadow, post_shadow, created
 
 
+def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
+    """
+    Returns the hidden inputs (inputs coming from an upper context)
+    used by a subgraph. It excludes empty names.
+    """
+    hidden = set()
+    memo = (
+        set(i.name for i in graph.initializer)
+        | set(i.name for i in graph.sparse_initializer)
+        | set(i.name for i in graph.input)
+    )
+    for node in graph.node:
+        for i in node.input:
+            if i and i not in memo:
+                hidden.add(i)
+        for att in node.attribute:
+            if att.type == onnx.AttributeProto.GRAPH and att.g:
+                hid = get_hidden_inputs(att.g)
+                less = set(h for h in hid if h not in memo)
+                hidden |= less
+        memo |= set(node.output)
+    return hidden
+
+
+def get_all_node_inputs(node: onnx.NodeProto) -> Set[str]:
+    """
+    Returns input and hidden inputs of a node.
+    See :func:`get_hidden_inputs`. It excludes empty names.
+    """
+    start = {i for i in node.input if i}
+    if node.op_type in {"Scan", "Loop", "If"}:
+        for att in node.attribute:
+            if att.type == onnx.AttributeProto.GRAPH:
+                start |= get_hidden_inputs(att.g)
+    return start
+
+
 def extract_subset_of_nodes(
     model: ModelProto,
     name: str,
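`get_hidden_inputs` is the building block the later hunks rely on: it reports every name a subgraph reads from its outer scope. A small sketch on a hypothetical branch graph (names are illustrative)::

    import onnx.helper as oh
    from onnx import TensorProto
    from onnx_diagnostic.helpers.onnx_helper import get_hidden_inputs

    # "outer" is consumed by the subgraph but defined nowhere inside it
    branch = oh.make_graph(
        [oh.make_node("Identity", ["outer"], ["y"])],
        "branch",
        [],
        [oh.make_tensor_value_info("y", TensorProto.FLOAT, [1])],
    )
    print(get_hidden_inputs(branch))  # {'outer'}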
@@ -1219,11 +1269,14 @@ def extract_subset_of_nodes(
         if name in node.output:
             node_index = i
             break
-    assert (
-        node_index
-
-
-
+    assert node_index is not None and node_index < len(model.graph.node), (
+        f"node_index={node_index} (n_nodes={len(model.graph.node)}) "
+        f"is still empty or wrong for result {name!r}"
+    )
+    assert name in model.graph.node[node_index].output, (
+        f"Unable to find {name!r} in {model.graph.node[node_index].output}, "
+        f"node={pretty_onnx(model.graph.node[node_index])}"
+    )
     if cut_points is None:
         cut_points = {n.name for n in model.graph.input} | {
             n.name for n in model.graph.initializer
@@ -1236,21 +1289,46 @@ def extract_subset_of_nodes(
     current_node_index = node_index
     current_input_index = 0
     intermediate = {name}
+    cut_points -= {name}
+    cached: Dict[int, List[str]] = {}
     inputs = set(k for k in node.input if k)
     while not (inputs <= cut_points) and current_node_index >= 0:
         node = model.graph.node[current_node_index]
-
+        # node inputs including hidden ones
+        if current_node_index in cached:
+            node_inputs = cached[current_node_index]
+        else:
+            set_inputs = set(i for i in node.input if i)
+            if node.op_type in {"Scan", "If", "Loop"}:
+                # there are hidden inputs
+                for att in node.attribute:
+                    if att.type == onnx.AttributeProto.GRAPH:
+                        set_inputs |= get_hidden_inputs(att.g)
+            node_inputs = list(set_inputs)
+            cached[current_node_index] = node_inputs
+        # processing
+        if current_input_index == 0 or not node_inputs:
             needs = [o for o in node.output if o in intermediate and o not in cut_points]
             if needs:
                 selected.add(current_node_index)
+                if not node_inputs:
+                    current_node_index -= 1
+                    current_input_index = 0
+                    continue
             else:
                 current_node_index -= 1
+                current_input_index = 0
                 continue
-
+        # more intermediate results
+        assert current_input_index < len(node_inputs), (
+            f"current_input_index={current_input_index} but node_inputs={node_inputs}, "
+            f"node={pretty_onnx(node)}"
+        )
+        res = node_inputs[current_input_index]
         if res not in cut_points:
             intermediate.add(res)
         current_input_index += 1
-        if current_input_index >= len(
+        if current_input_index >= len(node_inputs):
             current_node_index -= 1
             current_input_index = 0
 
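The cache added above matters because control-flow nodes make input discovery recursive: `get_all_node_inputs` merges a node's explicit inputs with the hidden inputs of its `Scan`/`Loop`/`If` subgraphs. A sketch on a hypothetical `If` node (set order may vary)::

    import onnx.helper as oh
    from onnx import TensorProto
    from onnx_diagnostic.helpers.onnx_helper import get_all_node_inputs

    branch = oh.make_graph(
        [oh.make_node("Identity", ["hidden_x"], ["y"])],
        "branch",
        [],
        [oh.make_tensor_value_info("y", TensorProto.FLOAT, [1])],
    )
    node = oh.make_node("If", ["cond"], ["y"], then_branch=branch, else_branch=branch)
    print(get_all_node_inputs(node))  # {'cond', 'hidden_x'}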
@@ -1283,17 +1361,362 @@ def make_submodel(
 
     not_known: Set[str] = set()
     for node in nodes[::-1]:
-        not_known -=
-        not_known |=
+        not_known -= {o for o in node.output if o}
+        not_known |= {i for i in node.input if i}
+        if node.op_type in {"Scan", "If", "Loop"}:
+            # there are hidden inputs
+            for att in node.attribute:
+                if att.type == onnx.AttributeProto.GRAPH:
+                    not_known |= get_hidden_inputs(att.g)
 
     model = oh.make_model(
         oh.make_graph(
             nodes,
             "submodel",
-            [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
-            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known) if n],
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names) if n],
         ),
         ir_version=ir_version,
         opset_imports=opset_imports,
     )
     return model
+
+
+def get_tensor_shape(
+    obj: Union[onnx.ValueInfoProto, onnx.TypeProto, onnx.TensorProto],
+) -> Optional[List[Optional[Union[int, str]]]]:
+    """
+    Returns the shape if that makes sense for this object.
+    """
+    if isinstance(obj, ValueInfoProto):
+        return get_tensor_shape(obj.type)
+    elif not isinstance(obj, onnx.TypeProto):
+        raise TypeError(f"Unexpected type {type(obj)!r}.")
+    if not obj.tensor_type.HasField("shape"):
+        return None
+    shape = []
+    for d in obj.tensor_type.shape.dim:
+        v = d.dim_value if d.dim_value > 0 else d.dim_param
+        shape.append(v)
+    if not shape:
+        return shape
+    return [None if s in (0, "") else s for s in shape]
+
+
+def _enumerate_model_node_outputs(
+    model: ModelProto, add_node: bool = False, order: bool = False
+) -> Iterable[Union[str, Tuple[str, NodeProto]]]:
+    """
+    Enumerates all the nodes of a model.
+
+    :param model: :epkg:`ONNX` graph
+    :param add_node: if False, the function enumerates
+        all output names from every node, otherwise, it
+        enumerates tuple (output name, node)
+    :param order: goes through outputs following the graph order
+    :return: enumerator
+    """
+    assert hasattr(model, "graph"), "Parameter model is not an ONNX model but {type(model)}"
+    if order:
+        edges = []
+        d_order = {}
+        node_names = {}
+        for inp in model.graph.input:
+            d_order[0, inp.name] = 0
+        for node in model.graph.node:
+            d_order[1, node.name] = 0
+            for i in node.input:
+                edges.append(("in", i, node.name))
+            for o in node.output:
+                edges.append(("out", o, node.name))
+                node_names[o] = node
+                d_order[0, o] = 0
+
+        modif = 1
+        n_iter = 0
+        while modif > 0 and n_iter <= len(model.graph.node):
+            modif = 0
+            n_iter += 1
+            for kind, data_name, node_name in edges:
+                if kind == "in":
+                    if (0, data_name) not in d_order:
+                        continue
+                    if d_order[0, data_name] + 1 > d_order[1, node_name]:
+                        modif += 1
+                        d_order[1, node_name] = d_order[0, data_name] + 1
+                else:
+                    if d_order[1, node_name] + 1 > d_order[0, data_name]:
+                        modif += 1
+                        d_order[0, data_name] = d_order[1, node_name] + 1
+
+        orders = [(v, k) for k, v in d_order.items()]
+        orders.sort()
+
+        for _, k in orders:
+            if k[0] == 1:
+                continue
+            out = k[1]
+            if out not in node_names:
+                continue
+            yield (out, node_names[out]) if add_node else out
+    else:
+        for node in model.graph.node:
+            for out in node.output:
+                yield (out, node) if add_node else out
+
+
+def onnx_remove_node_unused(
+    graph: Union[onnx.GraphProto, onnx.FunctionProto], recursive=True
+) -> Union[onnx.GraphProto, onnx.FunctionProto]:
+    """
+    Removes unused nodes of the graph. An unused node
+    is not involved in the output computation.
+
+    :param onnx_model: onnx model
+    :param recursive: looks into subgraphs
+    :return: new Graph
+    """
+    is_function = isinstance(graph, FunctionProto)
+
+    # mark outputs
+    marked: Dict[str, Set[str]] = (
+        {o: set() for o in graph.output}
+        if is_function
+        else {o.name: set() for o in graph.output}
+    )
+    nodes = list(graph.node)
+
+    # mark node output
+    for node in reversed(nodes):
+        used = False
+        for o in node.output:
+            if o and o in marked:
+                for i in get_all_node_inputs(node):
+                    marked[o].add(i)
+                used = True
+        if used:
+            for i in get_all_node_inputs(node):
+                marked[i] = set()
+
+    # removed nodes
+    removed = set()
+    marked_set = set(marked)
+    for ind, node in enumerate(nodes):
+        if not ({o for o in node.output if o} & marked_set):
+            removed.add(ind)
+
+    if not is_function:
+        initializers = [i for i in graph.initializer if i.name in marked]
+        sparse_initializers = [i for i in graph.sparse_initializer if i.name in marked]
+    new_nodes = [node for i, node in enumerate(nodes) if i not in removed]
+
+    # Finally create the new graph.
+    if is_function:
+        return oh.make_function(
+            graph.domain,
+            graph.name,
+            graph.input,
+            graph.output,
+            new_nodes,
+            opset_imports=graph.opset_import,
+            attributes=graph.attribute,
+            doc_string=graph.doc_string,
+        )
+    new_graph = oh.make_graph(
+        new_nodes,
+        graph.name,
+        graph.input,
+        graph.output,
+        initializers,
+        sparse_initializers,
+    )
+    new_graph.value_info.extend(graph.value_info)
+    return new_graph
+
+
+def select_model_inputs_outputs(
+    model: ModelProto,
+    outputs: Optional[List[str]] = None,
+    inputs: Optional[List[str]] = None,
+    infer_shapes: bool = True,
+    overwrite: Optional[Dict[str, Any]] = None,
+    remove_unused: bool = True,
+    verbose: int = 0,
+):
+    """
+    Takes a model and changes its outputs.
+
+    :param model: :epkg:`ONNX` model
+    :param inputs: new inputs, same ones if None
+    :param outputs: new outputs, same ones if None
+    :param infer_shapes: infer inputs and outputs shapes
+    :param overwrite: overwrite type and shapes for
+        inputs or outputs, *overwrite* is a
+        dictionary `{'name': (numpy dtype, shape)}`
+    :param remove_unused: remove unused nodes from the graph
+    :param verbose: display information while converting
+    :return: modified model
+
+    The function removes unneeded nodes.
+
+    The following example shows how to change the inputs of model
+    to bypass the first nodes. Shape inferences fails to determine
+    the new inputs type. They need to be overwritten.
+    `verbose=1` shows the number of deleted nodes.
+
+    ::
+
+        import onnx
+        from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs
+
+        onx = onnx.load(path)
+        onx2 = select_model_inputs_outputs(
+            onx, inputs=["a", "b"],
+            infer_shapes=True, verbose=1,
+            overwrite={'a': (numpy.int32, None), 'b': (numpy.int64, None)})
+        onnx.save(onx2, path2)
+    """
+    if not isinstance(model, ModelProto):
+        raise TypeError(f"Unexpected type {type(model)} for model.")
+    if inputs is not None and not isinstance(inputs, list):
+        inputs = [inputs]
+    if outputs is not None and not isinstance(outputs, list):
+        outputs = [outputs]
+    if inputs is None:
+        inputs = [i.name for i in model.graph.input]
+    if outputs is None:
+        outputs = [o.name for o in model.graph.output]
+
+    mark_var = {}
+    for out in _enumerate_model_node_outputs(model):
+        mark_var[out] = 0
+    for inp in inputs:
+        mark_var[inp] = 0
+    for out in outputs:
+        assert out in mark_var, f"Output {out!r} not found in model."
+        mark_var[out] = 1
+
+    nodes = list(model.graph.node[::-1])
+    mark_op = {}
+    for node in list(nodes):
+        mark_op[id(node)] = 0
+
+    # We mark all the nodes we need to keep.
+    nb = 1
+    while nb > 0:
+        nb = 0
+        for node in nodes:
+            if mark_op[id(node)] == 1:
+                continue
+            mod = False
+            for out in node.output:
+                if mark_var[out] == 1:
+                    mark_op[id(node)] = 1
+                    mod = True
+                    break
+            if not mod:
+                continue
+
+            node_inputs = get_all_node_inputs(node)
+
+            nb += 1
+            for inp in node_inputs:
+                if inp in inputs:
+                    continue
+                if mark_var.get(inp, 0) == 1:
+                    continue
+                mark_var[inp] = 1
+                nb += 1
+
+    # All nodes verifies mark_op[node.name] == 1
+    keep_nodes = [node for node in nodes[::-1] if mark_op[id(node)] == 1]
+
+    known_shapes = {}
+    if infer_shapes:
+        shapes = onnx.shape_inference.infer_shapes(model)
+        for shape in shapes.graph.value_info:
+            known_shapes[shape.name] = shape.type
+        for shape in shapes.graph.input:
+            known_shapes[shape.name] = shape.type
+        for shape in shapes.graph.output:
+            known_shapes[shape.name] = shape.type
+    else:
+        for shape in model.graph.input:
+            known_shapes[shape.name] = shape.type
+        for shape in model.graph.output:
+            known_shapes[shape.name] = shape.type
+
+    var_in = []
+    existing = {i.name: i for i in model.graph.input}
+    for name in inputs:
+        if overwrite is not None and name in overwrite:
+            dtype, shape = overwrite[name]
+            proto_dtype = np_dtype_to_tensor_dtype(dtype)
+            value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        elif name in known_shapes:
+            info = known_shapes[name].tensor_type
+            proto_dtype = info.elem_type
+            if proto_dtype == 0:
+                value_info = ValueInfoProto()
+                value_info.name = name
+            else:
+                shape = get_tensor_shape(known_shapes[name])
+                value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        elif name in existing:
+            value_info = existing[name]
+        else:
+            value_info = ValueInfoProto()
+            value_info.name = name
+        var_in.append(value_info)
+
+    var_out = []
+    existing = {i.name: i for i in model.graph.output}
+    for name in outputs:
+        if overwrite is not None and name in overwrite:
+            dtype, shape = overwrite[name]
+            proto_dtype = np_dtype_to_tensor_dtype(dtype)
+            value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        elif name in known_shapes:
+            info = known_shapes[name].tensor_type
+            proto_dtype = info.elem_type
+            if proto_dtype == 0:
+                value_info = ValueInfoProto()
+                value_info.name = name
+            else:
+                shape = get_tensor_shape(known_shapes[name])
+                value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        elif name in existing:
+            value_info = existing[name]
+        else:
+            value_info = ValueInfoProto()
+            value_info.name = name
+        var_out.append(value_info)
+
+    graph = oh.make_graph(
+        keep_nodes,
+        model.graph.name,
+        var_in,
+        var_out,
+        model.graph.initializer,
+        sparse_initializer=model.graph.sparse_initializer,
+    )
+    if remove_unused:
+        graph = onnx_remove_node_unused(graph, recursive=False)
+    onnx_model = oh.make_model(graph, functions=model.functions)
+    onnx_model.ir_version = model.ir_version
+    onnx_model.producer_name = model.producer_name
+    onnx_model.producer_version = model.producer_version
+    onnx_model.domain = model.domain
+    onnx_model.model_version = model.model_version
+    onnx_model.doc_string = model.doc_string
+    if model.metadata_props:
+        values = {p.key: p.value for p in model.metadata_props}
+        oh.set_model_props(onnx_model, values)
+
+    del onnx_model.opset_import[:]
+    for oimp in model.opset_import:
+        op_set = onnx_model.opset_import.add()
+        op_set.domain = oimp.domain
+        op_set.version = oimp.version
+
+    return onnx_model
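Among the helpers added above, `get_tensor_shape` is the easiest to check in isolation: dynamic dimensions come back as strings, unnamed ones as None. A short sketch::

    import onnx.helper as oh
    from onnx import TensorProto
    from onnx_diagnostic.helpers.onnx_helper import get_tensor_shape

    vi = oh.make_tensor_value_info("x", TensorProto.FLOAT, ["batch", 8])
    print(get_tensor_shape(vi))  # ['batch', 8]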
onnx_diagnostic/helpers/ort_session.py

@@ -1,7 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import onnx
 import numpy as np
-import numpy.typing as npt
 import torch
 from torch._C import _from_dlpack
 import onnxruntime

@@ -16,6 +15,7 @@ from .torch_helper import torch_dtype_to_onnx_dtype
 
 
 DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
+TensorLike = Union[np.ndarray, torch.Tensor]
 
 
 class _InferenceSession:

@@ -243,16 +243,16 @@ class InferenceSessionForNumpy(_InferenceSession):
         )
 
     def run(
-        self, output_names: Optional[List[str]], feeds: Dict[str,
-    ) -> List[Optional[
+        self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
+    ) -> List[Optional[TensorLike]]:
         """Calls :meth:`onnxruntime.InferenceSession.run`."""
         # sess.run does not support blfoat16
         # res = self.sess.run(output_names, feeds)
         return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))
 
     def run_dlpack(
-        self, output_names: Optional[List[str]], feeds: Dict[str,
-    ) -> Tuple[Optional[
+        self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
+    ) -> Tuple[Optional[TensorLike], ...]:
         """
         Same as :meth:`onnxruntime.InferenceSession.run` except that
         feeds is a dictionary of :class:`np.ndarray`.

@@ -289,13 +289,13 @@ class InferenceSessionForNumpy(_InferenceSession):
     def _ortvalues_to_numpy_tensor(
         self,
         ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
-    ) -> Tuple[Optional[
+    ) -> Tuple[Optional[TensorLike], ...]:
         if len(ortvalues) == 0:
             return tuple()
 
         if self.nvtx:
             self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor")
-        res: List[Optional[
+        res: List[Optional[TensorLike]] = []  # noqa: F823
         for i in range(len(ortvalues)):
             if not ortvalues[i].has_value():
                 res.append(None)

@@ -556,7 +556,7 @@ def investigate_onnxruntime_issue(
         Union[str, Callable[[onnx.ModelProto], onnxruntime.InferenceSession]]
     ] = None,
     # if model needs to be run.
-    feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str,
+    feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, TensorLike]]] = None,
     verbose: int = 0,
     dump_filename: Optional[str] = None,
     infer_shapes: bool = True,
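After this change, `InferenceSessionForNumpy.run` is annotated to accept and return either numpy arrays or torch tensors. A hedged sketch; the constructor argument (a model path) and the input name "x" are assumptions, not shown in this diff::

    import numpy as np
    from onnx_diagnostic.helpers.ort_session import InferenceSessionForNumpy

    sess = InferenceSessionForNumpy("model.onnx")  # hypothetical model file
    outs = sess.run(None, {"x": np.zeros((2, 3), dtype=np.float32)})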
onnx_diagnostic/helpers/torch_helper.py

@@ -139,6 +139,15 @@ def onnx_dtype_to_torch_dtype(itype: int) -> torch.dtype:
     )
 
 
+_TYPENAME = dict(
+    FLOAT=onnx.TensorProto.FLOAT,
+    INT64=onnx.TensorProto.INT64,
+    INT32=onnx.TensorProto.INT32,
+    FLOAT16=onnx.TensorProto.FLOAT16,
+    BFLOAT16=onnx.TensorProto.BFLOAT16,
+)
+
+
 def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
     """
     Converts a torch dtype into a onnx element type.

@@ -182,7 +191,13 @@ def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
         return onnx.TensorProto.COMPLEX64
     if to == torch.complex128:
         return onnx.TensorProto.COMPLEX128
-
+    # SymbolicTensor
+    sto = str(to)
+    if sto in _TYPENAME:
+        return _TYPENAME[sto]
+    raise NotImplementedError(
+        f"Unable to convert torch dtype {to!r} ({type(to)}) to onnx dtype."
+    )
 
 
 def _forward_(
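The `_TYPENAME` fallback covers dtypes whose `str()` is a bare ONNX type name (the `# SymbolicTensor` case); regular torch dtypes still take the explicit branches above. A quick check of the usual path, assuming those earlier branches handle the standard dtypes::

    import onnx
    import torch
    from onnx_diagnostic.helpers.torch_helper import torch_dtype_to_onnx_dtype

    assert torch_dtype_to_onnx_dtype(torch.float32) == onnx.TensorProto.FLOAT
    assert torch_dtype_to_onnx_dtype(torch.bfloat16) == onnx.TensorProto.BFLOAT16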
@@ -811,7 +826,8 @@ def torch_deepcopy(value: Any) -> Any:
     if isinstance(value, tuple):
         return tuple(torch_deepcopy(v) for v in value)
     if isinstance(value, list):
-
+        if type(value) is list:
+            return [torch_deepcopy(v) for v in value]
     if isinstance(value, set):
         return {torch_deepcopy(v) for v in value}
     if isinstance(value, dict):

@@ -1087,3 +1103,13 @@ def study_discrepancies(
     if name:
         fig.savefig(name)
     return ax
+
+
+def int_device_to_torch_device(device_id: int) -> torch.device:
+    """
+    Converts a device defined as an integer (coming from :meth:`torch.Tensor.get_device`)
+    into a ``torch.device``.
+    """
+    if device_id < 0:
+        return torch.device("cpu")
+    return torch.device("cuda", device_id)
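`int_device_to_torch_device` mirrors the convention of `torch.Tensor.get_device`, which returns -1 for CPU tensors. A short sketch::

    import torch
    from onnx_diagnostic.helpers.torch_helper import int_device_to_torch_device

    t = torch.zeros(2)
    print(int_device_to_torch_device(t.get_device()))  # device(type='cpu')
    print(int_device_to_torch_device(0))               # device(type='cuda', index=0)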