onnx-diagnostic 0.8.10__py3-none-any.whl → 0.8.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +136 -140
  3. onnx_diagnostic/ci_models/export_phi4_mm.py +2 -4
  4. onnx_diagnostic/export/api.py +2 -4
  5. onnx_diagnostic/export/validate.py +2 -0
  6. onnx_diagnostic/ext_test_case.py +32 -15
  7. onnx_diagnostic/helpers/args_helper.py +1 -0
  8. onnx_diagnostic/helpers/bench_run.py +0 -1
  9. onnx_diagnostic/helpers/cache_helper.py +6 -6
  10. onnx_diagnostic/helpers/doc_helper.py +7 -4
  11. onnx_diagnostic/helpers/graph_helper.py +6 -6
  12. onnx_diagnostic/helpers/log_helper.py +37 -14
  13. onnx_diagnostic/helpers/memory_peak.py +5 -1
  14. onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
  15. onnx_diagnostic/helpers/model_builder_helper.py +1 -1
  16. onnx_diagnostic/helpers/onnx_helper.py +283 -110
  17. onnx_diagnostic/helpers/ort_session.py +0 -1
  18. onnx_diagnostic/helpers/torch_helper.py +8 -9
  19. onnx_diagnostic/investigate/__init__.py +0 -0
  20. onnx_diagnostic/investigate/input_observer.py +329 -0
  21. onnx_diagnostic/reference/evaluator.py +0 -1
  22. onnx_diagnostic/reference/ort_evaluator.py +0 -1
  23. onnx_diagnostic/reference/report_results_comparison.py +9 -3
  24. onnx_diagnostic/reference/torch_evaluator.py +5 -1
  25. onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
  26. onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
  27. onnx_diagnostic/tasks/feature_extraction.py +0 -1
  28. onnx_diagnostic/torch_export_patches/__init__.py +0 -1
  29. onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +44 -23
  32. onnx_diagnostic/torch_models/code_sample.py +5 -10
  33. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
  34. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
  35. onnx_diagnostic/torch_models/validate.py +1 -1
  36. onnx_diagnostic/torch_onnx/compare.py +0 -1
  37. onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
  38. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  39. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
  40. onnx_diagnostic/typing.py +15 -0
  41. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/METADATA +1 -1
  42. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/RECORD +45 -43
  43. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/WHEEL +1 -1
  44. onnx_diagnostic/api.py +0 -15
  45. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/top_level.txt +0 -0
File: onnx_diagnostic/helpers/onnx_helper.py

@@ -1,6 +1,7 @@
 import functools
 import json
 import os
+import re
 import sys
 import warnings
 from typing import (
@@ -32,11 +33,10 @@ from onnx import (
     ValueInfoProto,
     load as onnx_load,
 )
+from ..typing import InferenceSessionLike, TensorLike
 
-TensorLike = Union[np.ndarray, "torch.Tensor"]  # noqa: F821
 
-
-def _make_stat(init: TensorProto) -> Dict[str, float]:
+def _make_stat(init: TensorProto) -> Dict[str, Any]:
     """
     Produces statistics.
 
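Note: TensorLike (and the new InferenceSessionLike) now come from onnx_diagnostic/typing.py, added in this release (+15 lines). The file itself is not shown in this diff; the following is a hypothetical sketch of what it must export for the annotations below to work, assuming onnxruntime stays an optional dependency:

    # Hypothetical reconstruction of onnx_diagnostic/typing.py; only the two
    # imported names are confirmed by this diff.
    from typing import TYPE_CHECKING, Any, Union

    import numpy as np

    if TYPE_CHECKING:
        import torch

    # Same alias the removed line above used to define locally.
    TensorLike = Union[np.ndarray, "torch.Tensor"]

    # Loose stand-in for onnxruntime.InferenceSession so importing this module
    # does not require onnxruntime; the real file may use a Protocol instead.
    InferenceSessionLike = Any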
@@ -160,11 +160,11 @@ def _validate_graph(
     verbose: int = 0,
     watch: Optional[Set[str]] = None,
     path: Optional[Sequence[str]] = None,
-):
-    found = []
+) -> List[Union[NodeProto, TensorProto, ValueInfoProto]]:
+    found: List[Union[NodeProto, TensorProto, ValueInfoProto]] = []
     path = path or ["root"]
-    set_init = set(i.name for i in g.initializer)
-    set_input = set(i.name for i in g.input)
+    set_init = {i.name for i in g.initializer}
+    set_input = {i.name for i in g.input}
     existing |= set_init | set_input
     if watch and set_init & watch:
         if verbose:
@@ -215,18 +215,15 @@ def _validate_graph(
                 f"in {'/'.join(path)}/{node.op_type}[{node.name}]"
             )
         found.append(node)
-    out = set(o.name for o in g.output)
+    out = {o.name for o in g.output}
     ins = out & existing
-    if ins != out:
-        raise AssertionError(
-            f"One output is missing, out={node.input}, existing={ins}, path={path}"
-        )
+    assert ins == out, f"One output is missing, out={node.input}, existing={ins}, path={path}"
     return found
 
 
 def _validate_function(g: FunctionProto, verbose: int = 0, watch: Optional[Set[str]] = None):
-    existing = set(g.input)
-    found = []
+    existing: Set[str] = set(g.input)
+    found: List[Union[NodeProto, TensorProto, ValueInfoProto]] = []
     for node in g.node:
         ins = set(node.input) & existing
         if ins != set(node.input):
@@ -240,7 +237,7 @@ def _validate_function(g: FunctionProto, verbose: int = 0, watch: Optional[Set[s
         for att in node.attribute:
             if att.type == AttributeProto.GRAPH:
                 found.extend(
-                    _validate_graph(g, existing.copy(), path=[g.name], verbose=verbose)
+                    _validate_graph(att.g, existing.copy(), path=[g.name], verbose=verbose)
                 )
         existing |= set(node.output)
         if watch and set(node.output) & watch:
@@ -285,7 +282,7 @@ def check_model_ort(
     onx: ModelProto,
     providers: Optional[Union[str, List[Any]]] = None,
     dump_file: Optional[str] = None,
-) -> "onnxruntime.InferenceSession":  # noqa: F821
+) -> InferenceSessionLike:
     """
     Loads a model with onnxruntime.
 
@@ -308,10 +305,9 @@ def check_model_ort(
 
     if isinstance(onx, str):
         try:
+            # pyrefly: ignore[bad-return]
             return InferenceSession(onx, providers=providers)
         except Exception as e:
-            import onnx
-
             if dump_file:
                 onnx.save(onx, dump_file)
 
@@ -319,8 +315,8 @@ def check_model_ort(
                 f"onnxruntime cannot load the model "
                 f"due to {e}\n{pretty_onnx(onnx.load(onx))}"
             )
-            return
     try:
+        # pyrefly: ignore[bad-return]
         return InferenceSession(onx.SerializeToString(), providers=providers)
     except Exception as e:
         if dump_file:
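Note: both pyrefly comments silence the checker because the except branches can fall through without returning. check_model_ort's behavior is unchanged: it loads the model with onnxruntime and returns the InferenceSession. A minimal usage sketch (model, names and shapes are illustrative):

    import numpy as np
    import onnx
    import onnx.helper as oh
    from onnx_diagnostic.helpers.onnx_helper import check_model_ort

    # One-node model: Y = Identity(X).
    model = oh.make_model(
        oh.make_graph(
            [oh.make_node("Identity", ["X"], ["Y"])],
            "demo",
            [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [2])],
            [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [2])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
    )
    sess = check_model_ort(model)  # raises if onnxruntime cannot load it
    print(sess.run(None, {"X": np.array([1.0, 2.0], dtype=np.float32)}))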
@@ -358,7 +354,17 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
 
 
 def pretty_onnx(
-    onx: Union[FunctionProto, GraphProto, ModelProto, ValueInfoProto, str],
+    onx: Union[
+        AttributeProto,
+        FunctionProto,
+        GraphProto,
+        ModelProto,
+        NodeProto,
+        onnx.SparseTensorProto,
+        TensorProto,
+        ValueInfoProto,
+        str,
+    ],
     with_attributes: bool = False,
     highlight: Optional[Set[str]] = None,
     shape_inference: bool = False,
@@ -377,6 +383,9 @@
     assert onx is not None, "onx cannot be None"
 
     if shape_inference:
+        assert isinstance(
+            onx, ModelProto
+        ), f"shape inference only works for ModelProto, not {type(onx)}"
         onx = onnx.shape_inference.infer_shapes(onx)
 
     if isinstance(onx, ValueInfoProto):
@@ -447,6 +456,8 @@
         shape = "x".join(map(str, onx.dims))
         return f"TensorProto:{onx.data_type}:{shape}:{onx.name}"
 
+    assert not isinstance(onx, onnx.SparseTensorProto), "SparseTensorProto is not handled yet."
+
     try:
         from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
 
@@ -538,12 +549,6 @@ def from_array_ml_dtypes(arr: TensorLike, name: Optional[str] = None) -> TensorP
     return tensor
 
 
-_STORAGE_TYPE = {
-    TensorProto.FLOAT16: np.int16,
-    TensorProto.BFLOAT16: np.int16,
-}
-
-
 def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> TensorProto:
     """
     Converts an array into a :class:`onnx.TensorProto`.
@@ -561,54 +566,9 @@ def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> Tenso
         ), f"Unable to convert type {type(tensor)} into TensorProto."
         return proto_from_tensor(tensor, name=name)
 
-    try:
-        from onnx.reference.ops.op_cast import (
-            bfloat16,
-            float8e4m3fn,
-            float8e4m3fnuz,
-            float8e5m2,
-            float8e5m2fnuz,
-        )
-    except ImportError:
-        bfloat16 = None
-
-    if bfloat16 is None:
-        return onh.from_array(tensor, name)
-
-    dt = tensor.dtype
-    if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
-        to = TensorProto.FLOAT8E4M3FN
-        dt_to = np.uint8
-    elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz":
-        to = TensorProto.FLOAT8E4M3FNUZ
-        dt_to = np.uint8
-    elif dt == float8e5m2 and dt.descr[0][0] == "e5m2":
-        to = TensorProto.FLOAT8E5M2
-        dt_to = np.uint8
-    elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz":
-        to = TensorProto.FLOAT8E5M2FNUZ
-        dt_to = np.uint8
-    elif dt == bfloat16 and dt.descr[0][0] == "bfloat16":
-        to = TensorProto.BFLOAT16
-        dt_to = np.uint16
-    else:
-        try:
-            import ml_dtypes
-        except ImportError:
-            ml_dtypes = None
-        if ml_dtypes is not None and (
-            tensor.dtype == ml_dtypes.bfloat16
-            or tensor.dtype == ml_dtypes.float8_e4m3fn
-            or tensor.dtype == ml_dtypes.float8_e4m3fnuz
-            or tensor.dtype == ml_dtypes.float8_e5m2
-            or tensor.dtype == ml_dtypes.float8_e5m2fnuz
-        ):
-            return from_array_ml_dtypes(tensor, name)
-        return onh.from_array(tensor, name)
-
-    t = onh.from_array(tensor.astype(dt_to), name)
-    t.data_type = to
-    return t
+    assert isinstance(tensor, np.ndarray)  # type checking
+    # pyrefly: ignore[bad-argument-type]
+    return onh.from_array(tensor, name)
 
 
 def to_array_extended(proto: TensorProto) -> TensorLike:
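Note: the deleted branch re-encoded bfloat16/float8 numpy arrays by hand before handing them to onnx. The simplified version routes torch tensors through proto_from_tensor and trusts onnx.numpy_helper.from_array for numpy arrays, which in recent onnx releases understands ml_dtypes dtypes directly. A quick sketch of the behavior the new code appears to rely on (assuming an onnx build with ml_dtypes support):

    import ml_dtypes
    import numpy as np
    import onnx
    import onnx.numpy_helper as onh

    arr = np.array([1.0, 2.0], dtype=ml_dtypes.bfloat16)
    proto = onh.from_array(arr, "w")  # no manual re-encoding needed
    print(onnx.helper.tensor_dtype_to_string(proto.data_type))  # TensorProto.BFLOAT16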
@@ -666,6 +626,7 @@ def onnx_dtype_to_np_dtype(itype: int) -> Any:
     )
 
 
+# pyrefly: ignore[unknown-name]
 def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int:  # noqa: F821
     """
     Converts a torch dtype or numpy dtype into a onnx element type.
@@ -679,6 +640,7 @@ def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int:  # noqa: F
         pass
     from .torch_helper import torch_dtype_to_onnx_dtype
 
+    # pyrefly: ignore[bad-argument-type]
     return torch_dtype_to_onnx_dtype(dt)
 
 
@@ -779,6 +741,7 @@ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
             f"ml_dtypes can be used."
         ) from e
 
+    # pyrefly: ignore[bad-assignment]
     mapping: Dict[int, np.dtype] = {
         TensorProto.BFLOAT16: ml_dtypes.bfloat16,
         TensorProto.FLOAT8E4M3FN: ml_dtypes.float8_e4m3fn,
@@ -798,7 +761,7 @@ def iterator_initializer_constant(
     model: Union[FunctionProto, GraphProto, ModelProto],
     use_numpy: bool = True,
     prefix: str = "",
-) -> Iterator[Tuple[str, Union["torch.Tensor", np.ndarray]]]:  # noqa: F821
+) -> Iterator[Tuple[str, TensorLike]]:  # noqa: F821
     """
     Iterates on iniatialiers and constant in an onnx model.
 
@@ -814,9 +777,12 @@
     if prefix:
         prefix += "."
     for init in graph.initializer:
-        yield f"{prefix}{init.name}", (
-            to_array_extended(init) if use_numpy else to_tensor(init)
-        )
+        s = f"{prefix}{init.name}"
+        if use_numpy:
+            yield s, to_array_extended(init)
+        else:
+            # pyrefly: ignore[unbound-name]
+            yield s, to_tensor(init)
     nodes = graph.node
     name = graph.name
     if isinstance(model, ModelProto):
@@ -831,13 +797,15 @@
         if node.op_type == "Constant" and node.domain == "":
             from ..reference import ExtendedReferenceEvaluator as Inference
 
-            if not use_numpy:
-                import torch
             sess = Inference(node)
             value = sess.run(None, {})[0]
-            yield f"{prefix}{node.output[0]}", (
-                value if use_numpy else torch.from_numpy(value)
-            )
+
+            if not use_numpy:
+                import torch
+
+                yield f"{prefix}{node.output[0]}", (torch.from_numpy(value))
+            else:
+                yield f"{prefix}{node.output[0]}", (value)
 
         if node.op_type in {"Loop", "Body", "Scan"}:
             for att in node.attribute:
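Note: splitting the conditional yields gives each branch a single, statically checkable value type. A usage sketch of the iterator on a toy model (names are illustrative):

    import numpy as np
    import onnx
    import onnx.helper as oh
    import onnx.numpy_helper as onh
    from onnx_diagnostic.helpers.onnx_helper import iterator_initializer_constant

    model = oh.make_model(
        oh.make_graph(
            [
                oh.make_node(
                    "Constant", [], ["c"],
                    value=onh.from_array(np.array([3.0], dtype=np.float32), "c"),
                )
            ],
            "demo",
            [],
            [oh.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [1])],
            [onh.from_array(np.array([1.0], dtype=np.float32), "w")],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
    )
    for name, value in iterator_initializer_constant(model):
        print(name, type(value))  # numpy arrays here; torch tensors with use_numpy=False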
@@ -870,7 +838,9 @@ def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union
     from .helper import size_type
 
     if isinstance(tensor, TensorProto):
+        # pyrefly: ignore[bad-assignment]
         tensor = to_array_extended(tensor)
+    assert isinstance(tensor, np.ndarray)  # type checking
     itype = np_dtype_to_tensor_dtype(tensor.dtype)
     stat = dict(
         mean=float(tensor.mean()),
@@ -948,7 +918,7 @@ class NodeCoordinates:
 
     def __init__(
         self,
-        node: Union[onnx.TensorProto, NodeProto, str],
+        node: Union[TensorProto, NodeProto, onnx.SparseTensorProto, ValueInfoProto, str],
        path: Tuple[Tuple[int, str, str], ...],
    ):
        assert isinstance(path, tuple), f"Unexpected type {type(path)} for path"
@@ -968,9 +938,7 @@
 
 
 class ResultFound:
-    """
-    Class returned by :func:`enumerate_results`.
-    """
+    """Class returned by :func:`enumerate_results`."""
 
     __slots__ = ("consumer", "name", "producer")
 
@@ -1060,9 +1028,9 @@ def enumerate_results(
             print(f"[enumerate_results] {indent}-- {r}")
         yield r
     for i in proto.sparse_initializer:
-        if i.name in name:
+        if i.values.name in name:
             r = ResultFound(
-                i.name,
+                i.values.name,
                 NodeCoordinates(i, tuple([*coordinates, (-1, "INIT", "")])),  # noqa: C409
                 None,
             )
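Note: this change (repeated below in shadowing_names, get_hidden_inputs and onnx_remove_node_unused) matches the protobuf layout: SparseTensorProto has no name field of its own, the name is carried by its dense values tensor. A small check:

    import numpy as np
    import onnx.helper as oh
    import onnx.numpy_helper as onh

    values = onh.from_array(np.array([1.0, 2.0], dtype=np.float32), "w")
    indices = onh.from_array(np.array([0, 3], dtype=np.int64), "w_idx")
    sparse = oh.make_sparse_tensor(values, indices, dims=[8])
    # SparseTensorProto only has values/indices/dims; `sparse.name` would raise.
    print(sparse.values.name)  # 'w'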
@@ -1165,9 +1133,9 @@ def shadowing_names(
     return shadowing_names(
         proto.node,
         verbose=verbose,
-        existing=set(i.name for i in proto.initializer)
-        | set(i.name for i in proto.sparse_initializer)
-        | set(i.name for i in proto.input if i.name),
+        existing={i.name for i in proto.initializer}
+        | {i.values.name for i in proto.sparse_initializer}
+        | {i.name for i in proto.input if i.name},
         shadow_context=set(),
         post_shadow_context=set(),
     )
@@ -1201,9 +1169,9 @@
     for att in node.attribute:
         if att.type == AttributeProto.GRAPH:
             g = att.g
-            shadow |= set(i.name for i in g.input) & shadow_context
-            shadow |= set(i.name for i in g.initializer) & shadow_context
-            shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
+            shadow |= {i.name for i in g.input} & shadow_context
+            shadow |= {i.name for i in g.initializer} & shadow_context
+            shadow |= {i.values.name for i in g.sparse_initializer} & shadow_context
             s, _ps, c = shadowing_names(
                 g.node, verbose=verbose, existing=existing, shadow_context=existing
             )
@@ -1225,9 +1193,9 @@ def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
     """
     hidden = set()
     memo = (
-        set(i.name for i in graph.initializer)
-        | set(i.name for i in graph.sparse_initializer)
-        | set(i.name for i in graph.input)
+        {i.name for i in graph.initializer}
+        | {i.values.name for i in graph.sparse_initializer}
+        | {i.name for i in graph.input}
     )
     for node in graph.node:
         for i in node.input:
@@ -1353,7 +1321,6 @@ def make_submodel(
     Creates a model with the given list of nodes.
     It computes the minimum list of inputs needed for this model.
     The function assumes the nodes are sorted.
-    It does not handle yet subgraphs.
 
     :param nodes: list of nodes
     :param ir_version: ir version
@@ -1376,25 +1343,61 @@
             if att.type == onnx.AttributeProto.GRAPH:
                 not_known |= get_hidden_inputs(att.g)
 
-    model = oh.make_model(
+    return oh.make_model(
         oh.make_graph(
             nodes,
             "submodel",
             [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known) if n],
-            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names) if n],
+            [_mkv_(n, *type_rank_fn(n)) for n in output_names if n],
         ),
         ir_version=ir_version,
         opset_imports=opset_imports,
     )
-    return model
+
+
+def make_subfunction(
+    name: str,
+    nodes: List[NodeProto],
+    opset_imports: Sequence[OperatorSetIdProto],
+    output_names: List[str],
+    domain: str = "local_function",
+) -> FunctionProto:
+    """
+    Creates a function with the given list of nodes.
+    It computes the minimum list of inputs needed for this model.
+    The function assumes the nodes are sorted.
+
+    :param name: function name
+    :param nodes: list of nodes
+    :param opset_imports: opset import
+    :param output_names: desired outputs
+    :param domain: function domain
+    :return: model proto
+    """
+    not_known: Set[str] = set()
+    for node in nodes[::-1]:
+        not_known -= {o for o in node.output if o}
+        not_known |= {i for i in node.input if i}
+        if node.op_type in {"Scan", "If", "Loop"}:
+            # there are hidden inputs
+            for att in node.attribute:
+                if att.type == onnx.AttributeProto.GRAPH:
+                    not_known |= get_hidden_inputs(att.g)
+
+    return oh.make_function(
+        domain,
+        name,
+        nodes=nodes,
+        inputs=sorted(not_known),
+        outputs=output_names,
+        opset_imports=opset_imports,
+    )
 
 
 def get_tensor_shape(
     obj: Union[onnx.ValueInfoProto, onnx.TypeProto, onnx.TensorProto],
 ) -> Optional[List[Optional[Union[int, str]]]]:
-    """
-    Returns the shape if that makes sense for this object.
-    """
+    """Returns the shape if that makes sense for this object."""
     if isinstance(obj, ValueInfoProto):
         return get_tensor_shape(obj.type)
     elif not isinstance(obj, onnx.TypeProto):
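Note: make_subfunction mirrors make_submodel: it walks the nodes in reverse to compute the minimal input set, including hidden subgraph inputs of Scan/If/Loop nodes. A usage sketch wrapping two nodes into a FunctionProto:

    import onnx.helper as oh
    from onnx_diagnostic.helpers.onnx_helper import make_subfunction

    nodes = [
        oh.make_node("Add", ["x", "y"], ["t"]),
        oh.make_node("Mul", ["t", "z"], ["out"]),
    ]
    fct = make_subfunction("add_mul", nodes, [oh.make_opsetid("", 18)], ["out"])
    print(list(fct.input))   # ['x', 'y', 'z'] -- inferred; 't' stays internal
    print(list(fct.output))  # ['out']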
@@ -1512,9 +1515,6 @@ def onnx_remove_node_unused(
         if not ({o for o in node.output if o} & marked_set):
             removed.add(ind)
 
-    if not is_function:
-        initializers = [i for i in graph.initializer if i.name in marked]
-        sparse_initializers = [i for i in graph.sparse_initializer if i.name in marked]
     new_nodes = [node for i, node in enumerate(nodes) if i not in removed]
 
     # Finally create the new graph.
@@ -1529,13 +1529,16 @@
         attributes=graph.attribute,
         doc_string=graph.doc_string,
     )
+
+    initializers = [i for i in graph.initializer if i.name in marked]
+    sparse_initializers = [i for i in graph.sparse_initializer if i.values.name in marked]
     new_graph = oh.make_graph(
         new_nodes,
         graph.name,
         graph.input,
         graph.output,
         initializers,
-        sparse_initializers,
+        sparse_initializer=sparse_initializers,
     )
     new_graph.value_info.extend(graph.value_info)
     return new_graph
@@ -1549,7 +1552,7 @@ def select_model_inputs_outputs(
     overwrite: Optional[Dict[str, Any]] = None,
     remove_unused: bool = True,
     verbose: int = 0,
-):
+) -> ModelProto:
     """
     Takes a model and changes its outputs.
 
@@ -1709,6 +1712,7 @@
     )
     if remove_unused:
         graph = onnx_remove_node_unused(graph, recursive=False)
+    assert isinstance(graph, GraphProto)  # type checking
     onnx_model = oh.make_model(graph, functions=model.functions)
     onnx_model.ir_version = model.ir_version
     onnx_model.producer_name = model.producer_name
@@ -1727,3 +1731,172 @@
         op_set.version = oimp.version
 
     return onnx_model
+
+
+def _find_used_names(node_list, node_indices):
+    # find all the outputs the subset of nodes produces
+    possible_outputs = set()
+    for i_node in node_indices:
+        if not node_list[i_node]:
+            continue
+        possible_outputs |= {o for o in node_list[i_node].output if o}
+    # find all requires input from the other nodes
+    set_indices = set(node_indices)
+    not_known: Set[str] = set()
+    ranges = list(range(len(node_list)))
+    for i_node in ranges[::-1]:
+        if i_node in set_indices:
+            continue
+        node = node_list[i_node]
+        if not node:
+            continue
+        not_known -= {o for o in node.output if o}
+        not_known |= {i for i in node.input if i}
+        if node.op_type in {"Scan", "If", "Loop"}:
+            # there are hidden inputs
+            for att in node.attribute:
+                if att.type == onnx.AttributeProto.GRAPH:
+                    not_known |= get_hidden_inputs(att.g)
+    # output
+    selection = possible_outputs & not_known
+    assert selection, (
+        f"No output is needed, possible_outputs={sorted(possible_outputs)}, "
+        f"not_known={sorted(not_known)}"
+    )
+    return sorted(selection)
+
+
+def check_for_non_recursivity(
+    node_list: List[Optional[NodeProto]], inputs: Sequence[str], outputs: Sequence[str]
+):
+    """
+    We finally need to check that any of this output is not required
+    by one input from the function itself, that would mean one node
+    needs an output of the function and is also required by the function:
+    it is probably missing from the initial set.
+
+    :param node_list: list of nodes
+    :param inputs: input names to consider
+    :param outputs: output names which cannot be involved in input names
+    """
+    set_inputs = set(inputs)
+    set_outputs = set(outputs)
+    for node in node_list[::-1]:
+        if not node:
+            continue
+        si = set(node.output)
+        if si & set_inputs:
+            set_inputs |= set(node.input)
+            if node.op_type in {"Scan", "If", "Loop"}:
+                # there are hidden inputs
+                for att in node.attribute:
+                    if att.type == onnx.AttributeProto.GRAPH:
+                        set_inputs |= get_hidden_inputs(att.g)
+    if set_outputs & set_inputs:
+        raise ValueError(
+            f"Results {set_outputs & set_inputs} are needed for inputs {inputs} "
+            f"but also requires {outputs} which is not allowed."
+        )
+
+
+def make_model_with_local_functions(
+    model: ModelProto,
+    regex: str = ".*[.]layers[.][0-9]+[.]forward$",
+    domain: str = "local_function",
+    metadata_key_prefix: Union[str, Tuple[str, ...]] = ("namespace", "source["),
+    verbose: int = 0,
+) -> ModelProto:
+    """
+    Selects nodes based on a regular expression, using metadata
+    ``'namespace'``. It is going to look into every value
+    matching the regular expression and partition the nodes based
+    on the unique values the regular expression finds.
+    Every set of nodes it replaced by a call to a local function.
+
+    :param model: model proto
+    :param regex: regular expression
+    :param domain: function domain
+    :param metadata_keys: list of metadata keys to consider,
+        every value is split into multiple ones.
+    :param verbose: verbosity
+    :return: model proto
+    """
+    prefix = (
+        metadata_key_prefix
+        if isinstance(metadata_key_prefix, tuple)
+        else (metadata_key_prefix,)
+    )
+    reg = re.compile(regex)
+    unique_values = set()
+    unique: Dict[str, List[int]] = {}
+    for i, node in enumerate(model.graph.node):
+        selected = False
+        for data in node.metadata_props:
+            if data.key.startswith(prefix):
+                values = re.split("[,:]", data.value)
+                for v in values:
+                    if not v:
+                        continue
+                    if reg.match(v):
+                        if v not in unique:
+                            unique[v] = []
+                        unique[v].append(i)
+                        selected = True
+                        break
+                    unique_values.add(v)
+            if selected:
+                break
+    # sets of nodes.
+    if not unique:
+        if verbose:
+            print(f"[make_model_with_local_functions] no match in {sorted(unique_values)}")
+        return model
+
+    if verbose:
+        print(f"[make_model_with_local_functions] matched {len(unique)} partitions")
+    functions = []
+    new_nodes: List[Optional[NodeProto]] = list(model.graph.node)
+    for key, node_indices in unique.items():
+        function_name = key.strip().replace(".", "_")
+        if verbose:
+            print(
+                f"[make_model_with_local_functions] move {len(node_indices)} "
+                f"nodes in partition {function_name!r}"
+            )
+        outputs = _find_used_names(new_nodes, node_indices)
+        function_nodes = [new_nodes[i] for i in node_indices]
+        lf = make_subfunction(
+            function_name,
+            [n for n in function_nodes if n],
+            model.opset_import,
+            outputs,
+            domain=domain,
+        )
+        check_for_non_recursivity(new_nodes, lf.input, lf.output)
+        functions.append(lf)
+        maxi = max(node_indices)
+        for i in node_indices:
+            new_nodes[i] = None
+        new_nodes[maxi] = oh.make_node(lf.name, lf.input, lf.output, domain=lf.domain)
+
+    return oh.make_model(
+        oh.make_graph(
+            [n for n in new_nodes if n],
+            model.graph.name,
+            model.graph.input,
+            model.graph.output,
+            model.graph.initializer,
+            doc_string=model.graph.doc_string,
+            value_info=model.graph.value_info,
+            sparse_initializer=model.graph.sparse_initializer,
+        ),
+        ir_version=model.ir_version,
+        opset_imports=(
+            model.opset_import
+            if domain in {d.domain for d in model.opset_import}
+            else [*model.opset_import, oh.make_opsetid(domain, 1)]
+        ),
+        functions=[*model.functions, *functions],
+    )
File: onnx_diagnostic/helpers/ort_session.py

@@ -14,7 +14,6 @@ from .onnx_helper import (
 )
 from .torch_helper import torch_dtype_to_onnx_dtype
 
-
 DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)}
 TensorLike = Union[np.ndarray, torch.Tensor]
 
File: onnx_diagnostic/helpers/torch_helper.py

@@ -19,12 +19,7 @@ from .cache_helper import (
     CacheKeyValue,
 )
 from .mini_onnx_builder import create_onnx_model_from_input_tensors
-from .onnx_helper import (
-    to_array_extended,
-    tensor_dtype_to_np_dtype,
-    _STORAGE_TYPE,
-    onnx_dtype_name,
-)
+from .onnx_helper import to_array_extended, tensor_dtype_to_np_dtype, onnx_dtype_name
 
 
 def proto_from_tensor(
@@ -84,13 +79,17 @@ def proto_from_tensor(
         byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
         tensor.raw_data = bytes(byte_data)
         if sys.byteorder == "big":
-            np_dtype = _STORAGE_TYPE[tensor.data_type]  # type: ignore
-            np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)  # type: ignore
+            storage_type = {
+                onnx.TensorProto.FLOAT16: np.int16,
+                onnx.TensorProto.BFLOAT16: np.int16,
+            }
+            np_dtype = storage_type[tensor.data_type]  # type: ignore
+            np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)  # type: ignore
     else:
         tensor.raw_data = np_arr.tobytes()
         if sys.byteorder == "big":
             np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
-            np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)
+            np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)
     return tensor
 
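Note: the substantive fix here is the byteswap call. NumPy has no module-level np.byteswap; byteswap is an ndarray method, so the old lines would have raised AttributeError on big-endian machines. The corrected idiom, checked on a writable buffer (np.frombuffer over immutable bytes returns a read-only view):

    import numpy as np

    buf = bytearray(np.array([1, 256], dtype=np.int16).tobytes())
    view = np.frombuffer(buf, dtype=np.int16)
    view.byteswap(inplace=True)  # swaps the underlying bytes in place
    print(view)                  # [256   1] on a little-endian machine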