onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
@@ -3,9 +3,20 @@ import json
3
3
  import os
4
4
  import sys
5
5
  import warnings
6
- from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Dict,
10
+ Iterable,
11
+ Iterator,
12
+ List,
13
+ Optional,
14
+ Sequence,
15
+ Set,
16
+ Tuple,
17
+ Union,
18
+ )
7
19
  import numpy as np
8
- import numpy.typing as npt
9
20
  import onnx
10
21
  import onnx.helper as oh
11
22
  import onnx.numpy_helper as onh
@@ -15,11 +26,14 @@ from onnx import (
15
26
  GraphProto,
16
27
  ModelProto,
17
28
  NodeProto,
29
+ OperatorSetIdProto,
18
30
  TensorProto,
19
31
  ValueInfoProto,
20
32
  load as onnx_load,
21
33
  )
22
34
 
35
+ TensorLike = Union[np.ndarray, "torch.Tensor"] # noqa: F821
36
+
23
37
 
24
38
  def _make_stat(init: TensorProto) -> Dict[str, float]:
25
39
  """
@@ -331,7 +345,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
331
345
  print(onnx_dtype_name(7))
332
346
  """
333
347
  for k in dir(TensorProto):
334
- if k.upper() == k and k != "EXTERNAL":
348
+ if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}:
335
349
  v = getattr(TensorProto, k)
336
350
  if v == itype:
337
351
  return k
@@ -477,7 +491,7 @@ def convert_endian(tensor: TensorProto) -> None:
477
491
  tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
478
492
 
479
493
 
480
- def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
494
+ def from_array_ml_dtypes(arr: TensorLike, name: Optional[str] = None) -> TensorProto:
481
495
  """
482
496
  Converts a numpy array to a tensor def assuming the dtype
483
497
  is defined in ml_dtypes.
@@ -523,7 +537,7 @@ _STORAGE_TYPE = {
523
537
  }
524
538
 
525
539
 
526
- def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
540
+ def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> TensorProto:
527
541
  """
528
542
  Converts an array into a :class:`onnx.TensorProto`.
529
543
 
@@ -590,7 +604,7 @@ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> Te
590
604
  return t
591
605
 
592
606
 
593
- def to_array_extended(proto: TensorProto) -> npt.ArrayLike:
607
+ def to_array_extended(proto: TensorProto) -> TensorLike:
594
608
  """Converts :class:`onnx.TensorProto` into a numpy array."""
595
609
  arr = onh.to_array(proto)
596
610
  if proto.data_type >= onnx.TensorProto.BFLOAT16:
@@ -1195,3 +1209,514 @@ def shadowing_names(
1195
1209
  existing |= not_empty
1196
1210
  created |= not_empty
1197
1211
  return shadow, post_shadow, created
1212
+
1213
+
1214
def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
    """
    Returns the hidden inputs (inputs coming from an upper context)
    used by a subgraph. It excludes empty names.
    """
    # Names already known inside this graph: initializers and declared inputs.
    known = {init.name for init in graph.initializer}
    known |= {init.name for init in graph.sparse_initializer}
    known |= {inp.name for inp in graph.input}
    found: Set[str] = set()
    for node in graph.node:
        # Any non-empty input not produced or declared so far comes from above.
        found.update(name for name in node.input if name and name not in known)
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH and att.g:
                # Recurse into subgraphs; keep only names this level cannot resolve.
                found.update(h for h in get_hidden_inputs(att.g) if h not in known)
        # Outputs become known for the following nodes.
        known.update(node.output)
    return found
1236
+
1237
+
1238
def get_all_node_inputs(node: onnx.NodeProto) -> Set[str]:
    """
    Returns input and hidden inputs of a node.
    See :func:`get_hidden_inputs`. It excludes empty names.
    """
    names = {name for name in node.input if name}
    # Only control-flow operators carry subgraphs with hidden inputs.
    if node.op_type not in {"Scan", "Loop", "If"}:
        return names
    for att in node.attribute:
        if att.type == onnx.AttributeProto.GRAPH:
            names |= get_hidden_inputs(att.g)
    return names
1249
+
1250
+
1251
def extract_subset_of_nodes(
    model: ModelProto,
    name: str,
    node_index: Optional[int] = None,
    cut_points: Optional[Set[str]] = None,
) -> List[NodeProto]:
    """
    Extracts the minimal subgraphs which can produce the output ``name``
    knowing ``cut_points``.

    :param model: original model
    :param name: result name
    :param node_index: if the node index is known, otherwise searches for it
    :param cut_points: the known results or input name otherwise
    :return: minimal list of nodes

    NOTE(review): the backward walk below assumes the nodes of the model are
    topologically sorted (producers appear before consumers) -- confirm.
    """
    # Locate the node producing `name` when the caller did not supply it.
    if node_index is None:
        for i, node in enumerate(model.graph.node):
            if name in node.output:
                node_index = i
                break
    assert node_index is not None and node_index < len(model.graph.node), (
        f"node_index={node_index} (n_nodes={len(model.graph.node)}) "
        f"is still empty or wrong for result {name!r}"
    )
    assert name in model.graph.node[node_index].output, (
        f"Unable to find {name!r} in {model.graph.node[node_index].output}, "
        f"node={pretty_onnx(model.graph.node[node_index])}"
    )
    # Graph inputs and initializers are always considered known results.
    if cut_points is None:
        cut_points = {n.name for n in model.graph.input} | {
            n.name for n in model.graph.initializer
        }
    elif model.graph.initializer:
        cut_points = cut_points | {n.name for n in model.graph.initializer}

    node = model.graph.node[node_index]
    selected = {node_index}
    # State machine: walk backwards from `node_index`; `current_input_index`
    # enumerates the current node's (possibly hidden) inputs one at a time.
    current_node_index = node_index
    current_input_index = 0
    # `intermediate` is the set of results still to be produced by the subset.
    intermediate = {name}
    cut_points -= {name}
    # Cache of the resolved input list (including hidden inputs) per node index.
    cached: Dict[int, List[str]] = {}
    inputs = set(k for k in node.input if k)
    while not (inputs <= cut_points) and current_node_index >= 0:
        node = model.graph.node[current_node_index]
        # node inputs including hidden ones
        if current_node_index in cached:
            node_inputs = cached[current_node_index]
        else:
            set_inputs = set(i for i in node.input if i)
            if node.op_type in {"Scan", "If", "Loop"}:
                # there are hidden inputs
                for att in node.attribute:
                    if att.type == onnx.AttributeProto.GRAPH:
                        set_inputs |= get_hidden_inputs(att.g)
            node_inputs = list(set_inputs)
            cached[current_node_index] = node_inputs
        # processing
        if current_input_index == 0 or not node_inputs:
            # A node is selected when it produces a needed intermediate result
            # that is not already a cut point.
            needs = [o for o in node.output if o in intermediate and o not in cut_points]
            if needs:
                selected.add(current_node_index)
                if not node_inputs:
                    # Nothing to enumerate (e.g. Constant): move to previous node.
                    current_node_index -= 1
                    current_input_index = 0
                    continue
            else:
                # Node does not contribute to the wanted output: skip it.
                current_node_index -= 1
                current_input_index = 0
                continue
        # more intermediate results
        assert current_input_index < len(node_inputs), (
            f"current_input_index={current_input_index} but node_inputs={node_inputs}, "
            f"node={pretty_onnx(node)}"
        )
        res = node_inputs[current_input_index]
        if res not in cut_points:
            # This input must in turn be produced by an earlier node.
            intermediate.add(res)
        current_input_index += 1
        if current_input_index >= len(node_inputs):
            # All inputs of the current node processed: step backwards.
            current_node_index -= 1
            current_input_index = 0

    # Selected indices are sorted to preserve the original topological order.
    return [model.graph.node[i] for i in sorted(selected)]
1336
+
1337
+
1338
def make_submodel(
    nodes: List[NodeProto],
    ir_version: int,
    opset_imports: List[OperatorSetIdProto],
    output_names: List[str],
    type_rank_fn: Callable[[str], Tuple[int, int]],
) -> ModelProto:
    """
    Creates a model with the given list of nodes.
    It computes the minimum list of inputs needed for this model.
    The function assumes the nodes are sorted.
    It does not handle yet subgraphs.

    :param nodes: list of nodes
    :param ir_version: ir version
    :param opset_imports: opset import
    :param output_names: desired outputs
    :param type_rank_fn: function returning the type and the rank of a result
    :return: model proto
    """

    def _mkv_(name, itype, irank):
        # Builds a value info with symbolic dimensions "<name>_d0", "<name>_d1", ...
        return oh.make_tensor_value_info(name, itype, [f"{name}_d{i}" for i in range(irank)])

    # Walk the nodes backwards: a name is an input of the submodel when it is
    # consumed without being produced by one of the given nodes.
    not_known: Set[str] = set()
    for node in nodes[::-1]:
        not_known -= {o for o in node.output if o}
        not_known |= {i for i in node.input if i}
        if node.op_type in {"Scan", "If", "Loop"}:
            # there are hidden inputs
            for att in node.attribute:
                if att.type == onnx.AttributeProto.GRAPH:
                    not_known |= get_hidden_inputs(att.g)

    model = oh.make_model(
        oh.make_graph(
            nodes,
            "submodel",
            [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known) if n],
            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names) if n],
        ),
        ir_version=ir_version,
        opset_imports=opset_imports,
    )
    return model
1383
+
1384
+
1385
def get_tensor_shape(
    obj: Union[onnx.ValueInfoProto, onnx.TypeProto, onnx.TensorProto],
) -> Optional[List[Optional[Union[int, str]]]]:
    """
    Returns the shape if that makes sense for this object.

    :param obj: a value info, a type or a tensor proto
    :return: ``None`` when no shape is defined, otherwise a list of
        dimensions where ``None`` stands for an unknown dimension,
        an ``int`` for a static one and a ``str`` for a named dynamic one
    :raises TypeError: if *obj* has an unexpected type
    """
    if isinstance(obj, ValueInfoProto):
        return get_tensor_shape(obj.type)
    if isinstance(obj, TensorProto):
        # A tensor always carries a fully static, integer shape.
        return list(obj.dims)
    if not isinstance(obj, onnx.TypeProto):
        raise TypeError(f"Unexpected type {type(obj)!r}.")
    if not obj.tensor_type.HasField("shape"):
        return None
    shape = []
    for d in obj.tensor_type.shape.dim:
        # dim_value > 0 means a static dimension, otherwise fall back
        # on the (possibly empty) symbolic name.
        v = d.dim_value if d.dim_value > 0 else d.dim_param
        shape.append(v)
    if not shape:
        return shape
    # 0 or "" means the dimension is unknown.
    return [None if s in (0, "") else s for s in shape]
1404
+
1405
+
1406
def _enumerate_model_node_outputs(
    model: ModelProto, add_node: bool = False, order: bool = False
) -> Iterable[Union[str, Tuple[str, NodeProto]]]:
    """
    Enumerates all the nodes of a model.

    :param model: :epkg:`ONNX` graph
    :param add_node: if False, the function enumerates
        all output names from every node, otherwise, it
        enumerates tuple (output name, node)
    :param order: goes through outputs following the graph order
    :return: enumerator
    """
    # Fix: the original message was missing the f-prefix and printed the
    # literal "{type(model)}".
    assert hasattr(model, "graph"), f"Parameter model is not an ONNX model but {type(model)}"
    if order:
        # Compute a layering of the graph by fixed-point relaxation:
        # d_order[(0, name)] ranks results, d_order[(1, name)] ranks nodes.
        # NOTE(review): nodes are keyed by node.name -- unnamed nodes (empty
        # name) would collide in d_order; confirm models here always name nodes.
        edges = []
        d_order = {}
        node_names = {}
        for inp in model.graph.input:
            d_order[0, inp.name] = 0
        for node in model.graph.node:
            d_order[1, node.name] = 0
            for i in node.input:
                edges.append(("in", i, node.name))
            for o in node.output:
                edges.append(("out", o, node.name))
                node_names[o] = node
                d_order[0, o] = 0

        # Relax until stable; the rank of a node exceeds the ranks of its
        # inputs, the rank of a result exceeds the rank of its producer.
        modif = 1
        n_iter = 0
        while modif > 0 and n_iter <= len(model.graph.node):
            modif = 0
            n_iter += 1
            for kind, data_name, node_name in edges:
                if kind == "in":
                    if (0, data_name) not in d_order:
                        continue
                    if d_order[0, data_name] + 1 > d_order[1, node_name]:
                        modif += 1
                        d_order[1, node_name] = d_order[0, data_name] + 1
                else:
                    if d_order[1, node_name] + 1 > d_order[0, data_name]:
                        modif += 1
                        d_order[0, data_name] = d_order[1, node_name] + 1

        orders = [(v, k) for k, v in d_order.items()]
        orders.sort()

        for _, k in orders:
            if k[0] == 1:
                # Skip node entries, only result names are yielded.
                continue
            out = k[1]
            if out not in node_names:
                # Graph inputs have no producing node.
                continue
            yield (out, node_names[out]) if add_node else out
    else:
        # Plain enumeration following the declaration order of the nodes.
        for node in model.graph.node:
            for out in node.output:
                yield (out, node) if add_node else out
1466
+
1467
+
1468
def onnx_remove_node_unused(
    graph: Union[onnx.GraphProto, onnx.FunctionProto], recursive=True
) -> Union[onnx.GraphProto, onnx.FunctionProto]:
    """
    Removes unused nodes of the graph. An unused node
    is not involved in the output computation.

    :param graph: onnx graph or function
    :param recursive: looks into subgraphs
        (NOTE(review): this flag is not used by the body below,
        subgraphs inside kept nodes are preserved as-is -- confirm intent)
    :return: new Graph
    """
    is_function = isinstance(graph, FunctionProto)

    # mark outputs
    # `marked` acts as the set of live result names; FunctionProto outputs
    # are plain strings while GraphProto outputs are ValueInfoProto.
    marked: Dict[str, Set[str]] = (
        {o: set() for o in graph.output}
        if is_function
        else {o.name: set() for o in graph.output}
    )
    nodes = list(graph.node)

    # mark node output
    # Backward pass: a node is used when one of its outputs is live; its
    # inputs (including hidden subgraph inputs) then become live too.
    for node in reversed(nodes):
        used = False
        for o in node.output:
            if o and o in marked:
                for i in get_all_node_inputs(node):
                    marked[o].add(i)
                used = True
        if used:
            for i in get_all_node_inputs(node):
                marked[i] = set()

    # removed nodes
    # A node is dropped when none of its non-empty outputs is live.
    removed = set()
    marked_set = set(marked)
    for ind, node in enumerate(nodes):
        if not ({o for o in node.output if o} & marked_set):
            removed.add(ind)

    if not is_function:
        # Keep only the initializers still referenced by a live name.
        initializers = [i for i in graph.initializer if i.name in marked]
        sparse_initializers = [i for i in graph.sparse_initializer if i.name in marked]
    new_nodes = [node for i, node in enumerate(nodes) if i not in removed]

    # Finally create the new graph.
    if is_function:
        return oh.make_function(
            graph.domain,
            graph.name,
            graph.input,
            graph.output,
            new_nodes,
            opset_imports=graph.opset_import,
            attributes=graph.attribute,
            doc_string=graph.doc_string,
        )
    new_graph = oh.make_graph(
        new_nodes,
        graph.name,
        graph.input,
        graph.output,
        initializers,
        sparse_initializers,
    )
    # value_info entries are carried over even for removed results.
    new_graph.value_info.extend(graph.value_info)
    return new_graph
1535
+
1536
+
1537
def select_model_inputs_outputs(
    model: ModelProto,
    outputs: Optional[List[str]] = None,
    inputs: Optional[List[str]] = None,
    infer_shapes: bool = True,
    overwrite: Optional[Dict[str, Any]] = None,
    remove_unused: bool = True,
    verbose: int = 0,
):
    """
    Takes a model and changes its outputs.

    :param model: :epkg:`ONNX` model
    :param inputs: new inputs, same ones if None
    :param outputs: new outputs, same ones if None
    :param infer_shapes: infer inputs and outputs shapes
    :param overwrite: overwrite type and shapes for
        inputs or outputs, *overwrite* is a
        dictionary `{'name': (numpy dtype, shape)}`
    :param remove_unused: remove unused nodes from the graph
    :param verbose: display information while converting
        (NOTE(review): `verbose` is not read by the body below -- confirm)
    :return: modified model

    The function removes unneeded nodes.

    The following example shows how to change the inputs of model
    to bypass the first nodes. Shape inferences fails to determine
    the new inputs type. They need to be overwritten.
    `verbose=1` shows the number of deleted nodes.

    ::

        import onnx
        from onnx_extended.tools.onnx_nodes import select_model_inputs_outputs

        onx = onnx.load(path)
        onx2 = select_model_inputs_outputs(
            onx, inputs=["a", "b"],
            infer_shapes=True, verbose=1,
            overwrite={'a': (numpy.int32, None), 'b': (numpy.int64, None)})
        onnx.save(onx2, path2)
    """
    # --- normalize arguments: accept a single name or a list, default to
    # the model's current inputs/outputs.
    if not isinstance(model, ModelProto):
        raise TypeError(f"Unexpected type {type(model)} for model.")
    if inputs is not None and not isinstance(inputs, list):
        inputs = [inputs]
    if outputs is not None and not isinstance(outputs, list):
        outputs = [outputs]
    if inputs is None:
        inputs = [i.name for i in model.graph.input]
    if outputs is None:
        outputs = [o.name for o in model.graph.output]

    # mark_var[name] == 1 means the result is needed to compute the outputs.
    mark_var = {}
    for out in _enumerate_model_node_outputs(model):
        mark_var[out] = 0
    for inp in inputs:
        mark_var[inp] = 0
    for out in outputs:
        assert out in mark_var, f"Output {out!r} not found in model."
        mark_var[out] = 1

    # Nodes are keyed by id() because NodeProto is not hashable.
    nodes = list(model.graph.node[::-1])
    mark_op = {}
    for node in list(nodes):
        mark_op[id(node)] = 0

    # We mark all the nodes we need to keep.
    # Fixed point: propagate the "needed" flag from marked results to the
    # nodes producing them, then to those nodes' inputs, until stable.
    nb = 1
    while nb > 0:
        nb = 0
        for node in nodes:
            if mark_op[id(node)] == 1:
                continue
            mod = False
            for out in node.output:
                if mark_var[out] == 1:
                    mark_op[id(node)] = 1
                    mod = True
                    break
            if not mod:
                continue

            # Includes hidden inputs of control-flow nodes.
            node_inputs = get_all_node_inputs(node)

            nb += 1
            for inp in node_inputs:
                if inp in inputs:
                    # New graph inputs cut the propagation.
                    continue
                if mark_var.get(inp, 0) == 1:
                    continue
                mark_var[inp] = 1
                nb += 1

    # All nodes verifies mark_op[node.name] == 1
    keep_nodes = [node for node in nodes[::-1] if mark_op[id(node)] == 1]

    # --- collect the known types/shapes used to build the new value infos.
    known_shapes = {}
    if infer_shapes:
        shapes = onnx.shape_inference.infer_shapes(model)
        for shape in shapes.graph.value_info:
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.input:
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.output:
            known_shapes[shape.name] = shape.type
    else:
        for shape in model.graph.input:
            known_shapes[shape.name] = shape.type
        for shape in model.graph.output:
            known_shapes[shape.name] = shape.type

    # --- build the new inputs: explicit overwrite wins, then inferred
    # shapes, then the existing declaration, then an untyped placeholder.
    var_in = []
    existing = {i.name: i for i in model.graph.input}
    for name in inputs:
        if overwrite is not None and name in overwrite:
            dtype, shape = overwrite[name]
            proto_dtype = np_dtype_to_tensor_dtype(dtype)
            value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
        elif name in known_shapes:
            info = known_shapes[name].tensor_type
            proto_dtype = info.elem_type
            if proto_dtype == 0:
                # elem_type 0 means undefined: keep an untyped value info.
                value_info = ValueInfoProto()
                value_info.name = name
            else:
                shape = get_tensor_shape(known_shapes[name])
                value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
        elif name in existing:
            value_info = existing[name]
        else:
            value_info = ValueInfoProto()
            value_info.name = name
        var_in.append(value_info)

    # --- build the new outputs with the same resolution order as inputs.
    var_out = []
    existing = {i.name: i for i in model.graph.output}
    for name in outputs:
        if overwrite is not None and name in overwrite:
            dtype, shape = overwrite[name]
            proto_dtype = np_dtype_to_tensor_dtype(dtype)
            value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
        elif name in known_shapes:
            info = known_shapes[name].tensor_type
            proto_dtype = info.elem_type
            if proto_dtype == 0:
                value_info = ValueInfoProto()
                value_info.name = name
            else:
                shape = get_tensor_shape(known_shapes[name])
                value_info = oh.make_tensor_value_info(name, proto_dtype, shape)
        elif name in existing:
            value_info = existing[name]
        else:
            value_info = ValueInfoProto()
            value_info.name = name
        var_out.append(value_info)

    # --- assemble the new model, copying metadata and opsets from the
    # original one.
    graph = oh.make_graph(
        keep_nodes,
        model.graph.name,
        var_in,
        var_out,
        model.graph.initializer,
        sparse_initializer=model.graph.sparse_initializer,
    )
    if remove_unused:
        graph = onnx_remove_node_unused(graph, recursive=False)
    onnx_model = oh.make_model(graph, functions=model.functions)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if model.metadata_props:
        values = {p.key: p.value for p in model.metadata_props}
        oh.set_model_props(onnx_model, values)

    # Opset imports are rebuilt from scratch to match the source model.
    del onnx_model.opset_import[:]
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    return onnx_model