onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +47 -10
  3. onnx_diagnostic/export/api.py +81 -50
  4. onnx_diagnostic/export/control_flow_research.py +10 -5
  5. onnx_diagnostic/export/onnx_plug.py +250 -61
  6. onnx_diagnostic/ext_test_case.py +99 -53
  7. onnx_diagnostic/helpers/dot_helper.py +37 -25
  8. onnx_diagnostic/helpers/helper.py +44 -38
  9. onnx_diagnostic/helpers/onnx_helper.py +441 -18
  10. onnx_diagnostic/helpers/ort_session.py +8 -8
  11. onnx_diagnostic/helpers/torch_helper.py +28 -2
  12. onnx_diagnostic/reference/ort_evaluator.py +6 -29
  13. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
  14. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
  15. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
  16. onnx_diagnostic/torch_models/code_sample.py +2 -1
  17. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  18. onnx_diagnostic/torch_models/validate.py +14 -1
  19. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  20. onnx_diagnostic/torch_onnx/sbs.py +11 -5
  21. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
  22. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/METADATA +1 -1
  23. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/RECORD +26 -26
  24. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/WHEEL +0 -0
  25. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/licenses/LICENSE.txt +0 -0
  26. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/onnx_helper.py

@@ -3,9 +3,20 @@ import json
 import os
 import sys
 import warnings
-from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 import numpy as np
-import numpy.typing as npt
 import onnx
 import onnx.helper as oh
 import onnx.numpy_helper as onh
@@ -21,6 +32,8 @@ from onnx import (
     load as onnx_load,
 )
 
+TensorLike = Union[np.ndarray, "torch.Tensor"]  # noqa: F821
+
 
 def _make_stat(init: TensorProto) -> Dict[str, float]:
     """
@@ -332,7 +345,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
         print(onnx_dtype_name(7))
     """
     for k in dir(TensorProto):
-        if k.upper() == k and k != "EXTERNAL":
+        if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}:
             v = getattr(TensorProto, k)
             if v == itype:
                 return k
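
For reference, a quick check of the patched lookup (a sketch; `onnx_dtype_name` lives in onnx_diagnostic/helpers/onnx_helper.py as shown above):

    from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name

    # TensorProto.INT64 == 7 and TensorProto.BFLOAT16 == 16.
    print(onnx_dtype_name(7))   # INT64
    print(onnx_dtype_name(16))  # BFLOAT16
    # "DESCRIPTOR" and "DEFAULT" are protobuf artifacts on TensorProto,
    # not element types; the new exclusion set skips them.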
@@ -478,7 +491,7 @@ def convert_endian(tensor: TensorProto) -> None:
     tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
 
 
-def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
+def from_array_ml_dtypes(arr: TensorLike, name: Optional[str] = None) -> TensorProto:
     """
     Converts a numpy array to a tensor def assuming the dtype
     is defined in ml_dtypes.
@@ -524,7 +537,7 @@ _STORAGE_TYPE = {
 }
 
 
-def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
+def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> TensorProto:
     """
     Converts an array into a :class:`onnx.TensorProto`.
 
@@ -591,7 +604,7 @@ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> Te
     return t
 
 
-def to_array_extended(proto: TensorProto) -> npt.ArrayLike:
+def to_array_extended(proto: TensorProto) -> TensorLike:
     """Converts :class:`onnx.TensorProto` into a numpy array."""
     arr = onh.to_array(proto)
     if proto.data_type >= onnx.TensorProto.BFLOAT16:
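
A minimal round trip through the two retyped converters (standard dtypes only; dtypes at or above BFLOAT16 additionally go through ml_dtypes):

    import numpy as np
    from onnx_diagnostic.helpers.onnx_helper import (
        from_array_extended,
        to_array_extended,
    )

    proto = from_array_extended(np.array([[0.5, -1.5]], dtype=np.float32), name="W")
    back = to_array_extended(proto)
    assert back.dtype == np.float32 and back.shape == (1, 2)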
@@ -1198,6 +1211,43 @@ def shadowing_names(
     return shadow, post_shadow, created
 
 
+def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
+    """
+    Returns the hidden inputs (inputs coming from an upper context)
+    used by a subgraph. It excludes empty names.
+    """
+    hidden = set()
+    memo = (
+        set(i.name for i in graph.initializer)
+        | set(i.name for i in graph.sparse_initializer)
+        | set(i.name for i in graph.input)
+    )
+    for node in graph.node:
+        for i in node.input:
+            if i and i not in memo:
+                hidden.add(i)
+        for att in node.attribute:
+            if att.type == onnx.AttributeProto.GRAPH and att.g:
+                hid = get_hidden_inputs(att.g)
+                less = set(h for h in hid if h not in memo)
+                hidden |= less
+        memo |= set(node.output)
+    return hidden
+
+
+def get_all_node_inputs(node: onnx.NodeProto) -> Set[str]:
+    """
+    Returns input and hidden inputs of a node.
+    See :func:`get_hidden_inputs`. It excludes empty names.
+    """
+    start = {i for i in node.input if i}
+    if node.op_type in {"Scan", "Loop", "If"}:
+        for att in node.attribute:
+            if att.type == onnx.AttributeProto.GRAPH:
+                start |= get_hidden_inputs(att.g)
+    return start
+
+
 def extract_subset_of_nodes(
     model: ModelProto,
     name: str,
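
A small sketch of what the new `get_hidden_inputs` reports: the subgraph below consumes `X` without declaring it as an input or initializer, so `X` comes back as a hidden input captured from the outer scope.

    import onnx
    import onnx.helper as oh
    from onnx_diagnostic.helpers.onnx_helper import get_hidden_inputs

    branch = oh.make_graph(
        [oh.make_node("Identity", ["X"], ["Y"])],
        "branch",
        [],  # no declared inputs: X is captured from the parent graph
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, None)],
    )
    assert get_hidden_inputs(branch) == {"X"}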
@@ -1219,11 +1269,14 @@ def extract_subset_of_nodes(
         if name in node.output:
             node_index = i
             break
-    assert (
-        node_index is not None
-        and node_index < len(model.graph.node)
-        and name in model.graph.node[node_index].output
-    ), f"node_index is still empty or wrong for result {name!r}"
+    assert node_index is not None and node_index < len(model.graph.node), (
+        f"node_index={node_index} (n_nodes={len(model.graph.node)}) "
+        f"is still empty or wrong for result {name!r}"
+    )
+    assert name in model.graph.node[node_index].output, (
+        f"Unable to find {name!r} in {model.graph.node[node_index].output}, "
+        f"node={pretty_onnx(model.graph.node[node_index])}"
+    )
     if cut_points is None:
         cut_points = {n.name for n in model.graph.input} | {
             n.name for n in model.graph.initializer
@@ -1236,21 +1289,46 @@ def extract_subset_of_nodes(
     current_node_index = node_index
     current_input_index = 0
     intermediate = {name}
+    cut_points -= {name}
+    cached: Dict[int, List[str]] = {}
     inputs = set(k for k in node.input if k)
     while not (inputs <= cut_points) and current_node_index >= 0:
         node = model.graph.node[current_node_index]
-        if current_input_index == 0:
+        # node inputs including hidden ones
+        if current_node_index in cached:
+            node_inputs = cached[current_node_index]
+        else:
+            set_inputs = set(i for i in node.input if i)
+            if node.op_type in {"Scan", "If", "Loop"}:
+                # there are hidden inputs
+                for att in node.attribute:
+                    if att.type == onnx.AttributeProto.GRAPH:
+                        set_inputs |= get_hidden_inputs(att.g)
+            node_inputs = list(set_inputs)
+            cached[current_node_index] = node_inputs
+        # processing
+        if current_input_index == 0 or not node_inputs:
             needs = [o for o in node.output if o in intermediate and o not in cut_points]
             if needs:
                 selected.add(current_node_index)
+                if not node_inputs:
+                    current_node_index -= 1
+                    current_input_index = 0
+                    continue
             else:
                 current_node_index -= 1
+                current_input_index = 0
                 continue
-        res = node.input[current_input_index]
+        # more intermediate results
+        assert current_input_index < len(node_inputs), (
+            f"current_input_index={current_input_index} but node_inputs={node_inputs}, "
+            f"node={pretty_onnx(node)}"
+        )
+        res = node_inputs[current_input_index]
         if res not in cut_points:
            intermediate.add(res)
         current_input_index += 1
-        if current_input_index >= len(node.input):
+        if current_input_index >= len(node_inputs):
             current_node_index -= 1
             current_input_index = 0
 
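The rewritten walk now counts subgraph captures among a node's inputs, the same notion `get_all_node_inputs` exposes directly. A sketch on an `If` node:

    import onnx
    import onnx.helper as oh
    from onnx_diagnostic.helpers.onnx_helper import get_all_node_inputs

    branch = oh.make_graph(
        [oh.make_node("Identity", ["outer"], ["out"])],
        "branch",
        [],
        [oh.make_tensor_value_info("out", onnx.TensorProto.FLOAT, None)],
    )
    node = oh.make_node("If", ["cond"], ["res"], then_branch=branch, else_branch=branch)
    # explicit input plus the value captured by both branches
    assert get_all_node_inputs(node) == {"cond", "outer"}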
@@ -1283,17 +1361,362 @@ def make_submodel(
 
     not_known: Set[str] = set()
     for node in nodes[::-1]:
-        not_known -= set(node.output)
-        not_known |= set(node.input)
+        not_known -= {o for o in node.output if o}
+        not_known |= {i for i in node.input if i}
+        if node.op_type in {"Scan", "If", "Loop"}:
+            # there are hidden inputs
+            for att in node.attribute:
+                if att.type == onnx.AttributeProto.GRAPH:
+                    not_known |= get_hidden_inputs(att.g)
 
     model = oh.make_model(
         oh.make_graph(
             nodes,
             "submodel",
-            [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
-            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known) if n],
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names) if n],
         ),
         ir_version=ir_version,
         opset_imports=opset_imports,
     )
     return model
+
+
+def get_tensor_shape(
+    obj: Union[onnx.ValueInfoProto, onnx.TypeProto, onnx.TensorProto],
+) -> Optional[List[Optional[Union[int, str]]]]:
+    """
+    Returns the shape if that makes sense for this object.
+    """
+    if isinstance(obj, ValueInfoProto):
+        return get_tensor_shape(obj.type)
+    elif not isinstance(obj, onnx.TypeProto):
+        raise TypeError(f"Unexpected type {type(obj)!r}.")
+    if not obj.tensor_type.HasField("shape"):
+        return None
+    shape = []
+    for d in obj.tensor_type.shape.dim:
+        v = d.dim_value if d.dim_value > 0 else d.dim_param
+        shape.append(v)
+    if not shape:
+        return shape
+    return [None if s in (0, "") else s for s in shape]
+
+
+def _enumerate_model_node_outputs(
+    model: ModelProto, add_node: bool = False, order: bool = False
+) -> Iterable[Union[str, Tuple[str, NodeProto]]]:
+    """
+    Enumerates all output names of a model.
+
+    :param model: :epkg:`ONNX` graph
+    :param add_node: if False, the function enumerates
+        all output names from every node, otherwise, it
+        enumerates tuples (output name, node)
+    :param order: goes through outputs following the graph order
+    :return: enumerator
+    """
+    assert hasattr(model, "graph"), f"Parameter model is not an ONNX model but {type(model)}"
+    if order:
+        edges = []
+        d_order = {}
+        node_names = {}
+        for inp in model.graph.input:
+            d_order[0, inp.name] = 0
+        for node in model.graph.node:
+            d_order[1, node.name] = 0
+            for i in node.input:
+                edges.append(("in", i, node.name))
+            for o in node.output:
+                edges.append(("out", o, node.name))
+                node_names[o] = node
+                d_order[0, o] = 0
+
+        modif = 1
+        n_iter = 0
+        while modif > 0 and n_iter <= len(model.graph.node):
+            modif = 0
+            n_iter += 1
+            for kind, data_name, node_name in edges:
+                if kind == "in":
+                    if (0, data_name) not in d_order:
+                        continue
+                    if d_order[0, data_name] + 1 > d_order[1, node_name]:
+                        modif += 1
+                        d_order[1, node_name] = d_order[0, data_name] + 1
+                else:
+                    if d_order[1, node_name] + 1 > d_order[0, data_name]:
+                        modif += 1
+                        d_order[0, data_name] = d_order[1, node_name] + 1
+
+        orders = [(v, k) for k, v in d_order.items()]
+        orders.sort()
+
+        for _, k in orders:
+            if k[0] == 1:
+                continue
+            out = k[1]
+            if out not in node_names:
+                continue
+            yield (out, node_names[out]) if add_node else out
+    else:
+        for node in model.graph.node:
+            for out in node.output:
+                yield (out, node) if add_node else out
+
+
+def onnx_remove_node_unused(
+    graph: Union[onnx.GraphProto, onnx.FunctionProto], recursive=True
+) -> Union[onnx.GraphProto, onnx.FunctionProto]:
+    """
+    Removes unused nodes of the graph. An unused node
+    is not involved in the output computation.
+
+    :param graph: onnx graph or function
+    :param recursive: looks into subgraphs
+    :return: new graph or function
+    """
+    is_function = isinstance(graph, FunctionProto)
+
+    # mark outputs
+    marked: Dict[str, Set[str]] = (
+        {o: set() for o in graph.output}
+        if is_function
+        else {o.name: set() for o in graph.output}
+    )
+    nodes = list(graph.node)
+
+    # mark node output
+    for node in reversed(nodes):
+        used = False
+        for o in node.output:
+            if o and o in marked:
+                for i in get_all_node_inputs(node):
+                    marked[o].add(i)
+                used = True
+        if used:
+            for i in get_all_node_inputs(node):
+                marked[i] = set()
+
+    # removed nodes
+    removed = set()
+    marked_set = set(marked)
+    for ind, node in enumerate(nodes):
+        if not ({o for o in node.output if o} & marked_set):
+            removed.add(ind)
+
+    if not is_function:
+        initializers = [i for i in graph.initializer if i.name in marked]
+        sparse_initializers = [i for i in graph.sparse_initializer if i.name in marked]
+    new_nodes = [node for i, node in enumerate(nodes) if i not in removed]
+
+    # Finally create the new graph.
+    if is_function:
+        return oh.make_function(
+            graph.domain,
+            graph.name,
+            graph.input,
+            graph.output,
+            new_nodes,
+            opset_imports=graph.opset_import,
+            attributes=graph.attribute,
+            doc_string=graph.doc_string,
+        )
+    new_graph = oh.make_graph(
+        new_nodes,
+        graph.name,
+        graph.input,
+        graph.output,
+        initializers,
+        sparse_initializers,
+    )
+    new_graph.value_info.extend(graph.value_info)
+    return new_graph
+
+
+def select_model_inputs_outputs(
+    model: ModelProto,
+    outputs: Optional[List[str]] = None,
+    inputs: Optional[List[str]] = None,
+    infer_shapes: bool = True,
+    overwrite: Optional[Dict[str, Any]] = None,
+    remove_unused: bool = True,
+    verbose: int = 0,
+):
+    """
+    Takes a model and changes its inputs and outputs.
+
+    :param model: :epkg:`ONNX` model
+    :param inputs: new inputs, same ones if None
+    :param outputs: new outputs, same ones if None
+    :param infer_shapes: infer inputs and outputs shapes
+    :param overwrite: overwrite type and shapes for
+        inputs or outputs, *overwrite* is a
+        dictionary `{'name': (numpy dtype, shape)}`
+    :param remove_unused: remove unused nodes from the graph
+    :param verbose: display information while converting
+    :return: modified model
+
+    The function removes unneeded nodes.
+
+    The following example shows how to change the inputs of a model
+    to bypass the first nodes. Shape inference fails to determine
+    the new input types. They need to be overwritten.
+    `verbose=1` shows the number of deleted nodes.
+
+    ::
+
+        import numpy
+        import onnx
+        from onnx_diagnostic.helpers.onnx_helper import select_model_inputs_outputs
+
+        onx = onnx.load(path)
+        onx2 = select_model_inputs_outputs(
+            onx, inputs=["a", "b"],
+            infer_shapes=True, verbose=1,
+            overwrite={'a': (numpy.int32, None), 'b': (numpy.int64, None)})
+        onnx.save(onx2, path2)
+    """
+    if not isinstance(model, ModelProto):
+        raise TypeError(f"Unexpected type {type(model)} for model.")
+    if inputs is not None and not isinstance(inputs, list):
+        inputs = [inputs]
+    if outputs is not None and not isinstance(outputs, list):
+        outputs = [outputs]
+    if inputs is None:
+        inputs = [i.name for i in model.graph.input]
+    if outputs is None:
+        outputs = [o.name for o in model.graph.output]
+
+    mark_var = {}
+    for out in _enumerate_model_node_outputs(model):
+        mark_var[out] = 0
+    for inp in inputs:
+        mark_var[inp] = 0
+    for out in outputs:
+        assert out in mark_var, f"Output {out!r} not found in model."
+        mark_var[out] = 1
+
+    nodes = list(model.graph.node[::-1])
+    mark_op = {}
+    for node in list(nodes):
+        mark_op[id(node)] = 0
+
+    # We mark all the nodes we need to keep.
+    nb = 1
+    while nb > 0:
+        nb = 0
+        for node in nodes:
+            if mark_op[id(node)] == 1:
+                continue
+            mod = False
+            for out in node.output:
+                if mark_var[out] == 1:
+                    mark_op[id(node)] = 1
+                    mod = True
+                    break
+            if not mod:
+                continue
+
+            node_inputs = get_all_node_inputs(node)
+
+            nb += 1
+            for inp in node_inputs:
+                if inp in inputs:
+                    continue
+                if mark_var.get(inp, 0) == 1:
+                    continue
+                mark_var[inp] = 1
+                nb += 1
+
+    # All kept nodes verify mark_op[id(node)] == 1.
+    keep_nodes = [node for node in nodes[::-1] if mark_op[id(node)] == 1]
+
+    known_shapes = {}
+    if infer_shapes:
+        shapes = onnx.shape_inference.infer_shapes(model)
+        for shape in shapes.graph.value_info:
+            known_shapes[shape.name] = shape.type
+        for shape in shapes.graph.input:
+            known_shapes[shape.name] = shape.type
+        for shape in shapes.graph.output:
+            known_shapes[shape.name] = shape.type
+    else:
+        for shape in model.graph.input:
+            known_shapes[shape.name] = shape.type
+        for shape in model.graph.output:
+            known_shapes[shape.name] = shape.type
+
+    var_in = []
+    existing = {i.name: i for i in model.graph.input}
+    for name in inputs:
+        if overwrite is not None and name in overwrite:
+            dtype, shape = overwrite[name]
+            proto_dtype = np_dtype_to_tensor_dtype(dtype)
+            value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        elif name in known_shapes:
+            info = known_shapes[name].tensor_type
+            proto_dtype = info.elem_type
+            if proto_dtype == 0:
+                value_info = ValueInfoProto()
+                value_info.name = name
+            else:
+                shape = get_tensor_shape(known_shapes[name])
+                value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        elif name in existing:
+            value_info = existing[name]
+        else:
+            value_info = ValueInfoProto()
+            value_info.name = name
+        var_in.append(value_info)
+
+    var_out = []
+    existing = {i.name: i for i in model.graph.output}
+    for name in outputs:
+        if overwrite is not None and name in overwrite:
+            dtype, shape = overwrite[name]
+            proto_dtype = np_dtype_to_tensor_dtype(dtype)
+            value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        elif name in known_shapes:
+            info = known_shapes[name].tensor_type
+            proto_dtype = info.elem_type
+            if proto_dtype == 0:
+                value_info = ValueInfoProto()
+                value_info.name = name
+            else:
+                shape = get_tensor_shape(known_shapes[name])
+                value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
+        elif name in existing:
+            value_info = existing[name]
+        else:
+            value_info = ValueInfoProto()
+            value_info.name = name
+        var_out.append(value_info)
+
+    graph = oh.make_graph(
+        keep_nodes,
+        model.graph.name,
+        var_in,
+        var_out,
+        model.graph.initializer,
+        sparse_initializer=model.graph.sparse_initializer,
+    )
+    if remove_unused:
+        graph = onnx_remove_node_unused(graph, recursive=False)
+    onnx_model = oh.make_model(graph, functions=model.functions)
+    onnx_model.ir_version = model.ir_version
+    onnx_model.producer_name = model.producer_name
+    onnx_model.producer_version = model.producer_version
+    onnx_model.domain = model.domain
+    onnx_model.model_version = model.model_version
+    onnx_model.doc_string = model.doc_string
+    if model.metadata_props:
+        values = {p.key: p.value for p in model.metadata_props}
+        oh.set_model_props(onnx_model, values)
+
+    del onnx_model.opset_import[:]
+    for oimp in model.opset_import:
+        op_set = onnx_model.opset_import.add()
+        op_set.domain = oimp.domain
+        op_set.version = oimp.version
+
+    return onnx_model
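
A short sketch of the new `get_tensor_shape` on hand-built value infos; dimensions with neither a `dim_value` nor a `dim_param` come back as None:

    import onnx
    import onnx.helper as oh
    from onnx_diagnostic.helpers.onnx_helper import get_tensor_shape

    vi = oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, ["batch", 4])
    print(get_tensor_shape(vi))  # ['batch', 4]

    vi0 = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [0, 8])
    print(get_tensor_shape(vi0))  # [None, 8]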
onnx_diagnostic/helpers/ort_session.py

@@ -1,7 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import onnx
 import numpy as np
-import numpy.typing as npt
 import torch
 from torch._C import _from_dlpack
 import onnxruntime
@@ -16,6 +15,7 @@ from .torch_helper import torch_dtype_to_onnx_dtype
 
 
 DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
+TensorLike = Union[np.ndarray, torch.Tensor]
 
 
 class _InferenceSession:
@@ -243,16 +243,16 @@ class InferenceSessionForNumpy(_InferenceSession):
         )
 
     def run(
-        self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
-    ) -> List[Optional[npt.ArrayLike]]:
+        self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
+    ) -> List[Optional[TensorLike]]:
         """Calls :meth:`onnxruntime.InferenceSession.run`."""
         # sess.run does not support bfloat16
         # res = self.sess.run(output_names, feeds)
         return self._post_process_inplace(list(self.run_dlpack(output_names, feeds)))
 
     def run_dlpack(
-        self, output_names: Optional[List[str]], feeds: Dict[str, npt.ArrayLike]
-    ) -> Tuple[Optional[npt.ArrayLike], ...]:
+        self, output_names: Optional[List[str]], feeds: Dict[str, TensorLike]
+    ) -> Tuple[Optional[TensorLike], ...]:
         """
         Same as :meth:`onnxruntime.InferenceSession.run` except that
         feeds is a dictionary of :class:`np.ndarray`.
@@ -289,13 +289,13 @@ class InferenceSessionForNumpy(_InferenceSession):
     def _ortvalues_to_numpy_tensor(
         self,
         ortvalues: Union[List[ORTC.OrtValue], ORTC.OrtValueVector],
-    ) -> Tuple[Optional[npt.ArrayLike], ...]:
+    ) -> Tuple[Optional[TensorLike], ...]:
         if len(ortvalues) == 0:
             return tuple()
 
         if self.nvtx:
             self.torch.cuda.nvtx.range_push("_ortvalues_to_numpy_tensor")
-        res: List[Optional[npt.ArrayLike]] = []  # noqa: F823
+        res: List[Optional[TensorLike]] = []  # noqa: F823
         for i in range(len(ortvalues)):
             if not ortvalues[i].has_value():
                 res.append(None)
@@ -556,7 +556,7 @@ def investigate_onnxruntime_issue(
         Union[str, Callable[[onnx.ModelProto], onnxruntime.InferenceSession]]
     ] = None,
     # if model needs to be run.
-    feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, npt.ArrayLike]]] = None,
+    feeds: Optional[Union[Dict[str, torch.Tensor], Dict[str, TensorLike]]] = None,
     verbose: int = 0,
     dump_filename: Optional[str] = None,
     infer_shapes: bool = True,
onnx_diagnostic/helpers/torch_helper.py

@@ -139,6 +139,15 @@ def onnx_dtype_to_torch_dtype(itype: int) -> torch.dtype:
     )
 
 
+_TYPENAME = dict(
+    FLOAT=onnx.TensorProto.FLOAT,
+    INT64=onnx.TensorProto.INT64,
+    INT32=onnx.TensorProto.INT32,
+    FLOAT16=onnx.TensorProto.FLOAT16,
+    BFLOAT16=onnx.TensorProto.BFLOAT16,
+)
+
+
 def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
     """
     Converts a torch dtype into an onnx element type.
@@ -182,7 +191,13 @@ def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int:
         return onnx.TensorProto.COMPLEX64
     if to == torch.complex128:
         return onnx.TensorProto.COMPLEX128
-    raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
+    # SymbolicTensor
+    sto = str(to)
+    if sto in _TYPENAME:
+        return _TYPENAME[sto]
+    raise NotImplementedError(
+        f"Unable to convert torch dtype {to!r} ({type(to)}) to onnx dtype."
+    )
 
 
 def _forward_(
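
The `_TYPENAME` fallback targets objects whose `str()` is a bare name like "INT64" (the "SymbolicTensor" case named in the comment); regular torch dtypes keep their explicit mapping. A quick check:

    import onnx
    import torch
    from onnx_diagnostic.helpers.torch_helper import torch_dtype_to_onnx_dtype

    assert torch_dtype_to_onnx_dtype(torch.float32) == onnx.TensorProto.FLOAT
    assert torch_dtype_to_onnx_dtype(torch.bfloat16) == onnx.TensorProto.BFLOAT16
    # str(torch.float32) == "torch.float32", so it never hits _TYPENAME;
    # only objects printing as "FLOAT", "INT64", ... take the new branch.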
@@ -811,7 +826,8 @@ def torch_deepcopy(value: Any) -> Any:
     if isinstance(value, tuple):
         return tuple(torch_deepcopy(v) for v in value)
     if isinstance(value, list):
-        return [torch_deepcopy(v) for v in value]
+        if type(value) is list:
+            return [torch_deepcopy(v) for v in value]
     if isinstance(value, set):
         return {torch_deepcopy(v) for v in value}
     if isinstance(value, dict):
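
Why the `type(value) is list` guard: `isinstance` is also True for list subclasses, and those should fall through to the later branches instead of being silently collapsed into plain lists. A sketch (the subclass name is hypothetical):

    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

    class TaggedList(list):  # hypothetical subclass carrying extra meaning
        pass

    assert type(torch_deepcopy([1, 2])) is list  # plain lists copy as before
    # TaggedList instances no longer match the guarded branch, so they are
    # handled (or rejected) by the remaining type checks in torch_deepcopy.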
@@ -1087,3 +1103,13 @@ def study_discrepancies(
     if name:
         fig.savefig(name)
     return ax
+
+
+def int_device_to_torch_device(device_id: int) -> torch.device:
+    """
+    Converts a device defined as an integer (coming from :meth:`torch.Tensor.get_device`)
+    into a ``torch.device``.
+    """
+    if device_id < 0:
+        return torch.device("cpu")
+    return torch.device("cuda", device_id)
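
Usage mirrors `torch.Tensor.get_device`, which returns -1 for CPU tensors:

    import torch
    from onnx_diagnostic.helpers.torch_helper import int_device_to_torch_device

    t = torch.zeros(2)  # CPU tensor
    assert t.get_device() == -1
    assert int_device_to_torch_device(t.get_device()) == torch.device("cpu")
    assert int_device_to_torch_device(0) == torch.device("cuda", 0)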