onnx-diagnostic 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +18 -0
  3. onnx_diagnostic/api.py +15 -0
  4. onnx_diagnostic/ext_test_case.py +3 -1
  5. onnx_diagnostic/helpers/args_helper.py +1 -1
  6. onnx_diagnostic/helpers/doc_helper.py +143 -0
  7. onnx_diagnostic/helpers/helper.py +6 -5
  8. onnx_diagnostic/helpers/model_builder_helper.py +24 -8
  9. onnx_diagnostic/helpers/rt_helper.py +5 -1
  10. onnx_diagnostic/helpers/torch_helper.py +2 -0
  11. onnx_diagnostic/reference/__init__.py +1 -0
  12. onnx_diagnostic/reference/torch_evaluator.py +648 -0
  13. onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
  14. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  15. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  16. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  17. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  18. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  19. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  20. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  21. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  22. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  23. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  24. onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
  25. onnx_diagnostic/tasks/__init__.py +22 -1
  26. onnx_diagnostic/tasks/image_classification.py +2 -2
  27. onnx_diagnostic/tasks/text_generation.py +3 -3
  28. onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
  29. onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
  30. onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
  31. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
  32. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  33. onnx_diagnostic/torch_models/test_helper.py +133 -16
  34. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  35. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/METADATA +1 -1
  36. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/RECORD +39 -23
  37. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/WHEEL +1 -1
  38. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/licenses/LICENSE.txt +0 -0
  39. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,7 @@ import os
4
4
  import sys
5
5
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
6
  import time
7
+ import numpy as np
7
8
  import onnx
8
9
  import onnxscript
9
10
  import onnxscript.rewriter.ort_fusions as ort_fusions
@@ -17,6 +18,7 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
17
18
  from ..tasks import random_input_kwargs
18
19
  from ..torch_export_patches import torch_export_patches
19
20
  from ..torch_export_patches.patch_inputs import use_dyn_not_str
21
+ from ..reference import TorchOnnxEvaluator
20
22
  from .hghub import get_untrained_model_with_inputs
21
23
 
22
24
 
@@ -192,11 +194,16 @@ def _quiet_or_not_quiet(
192
194
  summary: Dict[str, Any],
193
195
  data: Optional[Dict[str, Any]],
194
196
  fct: Callable,
197
+ repeat: int = 1,
198
+ warmup: int = 0,
195
199
  ) -> Any:
196
200
  begin = time.perf_counter()
197
201
  if quiet:
198
202
  try:
199
- return fct()
203
+ res = fct()
204
+ summary[f"time_{suffix}"] = time.perf_counter() - begin
205
+ if warmup + repeat == 1:
206
+ return res
200
207
  except Exception as e:
201
208
  summary[f"ERR_{suffix}"] = str(e)
202
209
  summary[f"time_{suffix}"] = time.perf_counter() - begin
@@ -204,11 +211,45 @@ def _quiet_or_not_quiet(
204
211
  return {f"ERR_{suffix}": e}
205
212
  data[f"ERR_{suffix}"] = e
206
213
  return None
207
- res = fct()
214
+ else:
215
+ res = fct()
208
216
  summary[f"time_{suffix}"] = time.perf_counter() - begin
217
+ if warmup + repeat > 1:
218
+ if suffix == "run":
219
+ res = torch_deepcopy(res)
220
+ summary[f"{suffix}_output"] = string_type(res, with_shape=True, with_min_max=True)
221
+ summary[f"{suffix}_warmup"] = warmup
222
+ summary[f"{suffix}_repeat"] = repeat
223
+ for _w in range(max(0, warmup - 1)):
224
+ t = fct()
225
+ summary[f"io_{suffix}_{_w+1}"] = string_type(t, with_shape=True, with_min_max=True)
226
+ summary[f"time_{suffix}_warmup"] = time.perf_counter() - begin
227
+ times = []
228
+ for _r in range(repeat):
229
+ begin = time.perf_counter()
230
+ t = fct()
231
+ times.append(time.perf_counter() - begin)
232
+ a = np.array(times)
233
+ summary[f"time_{suffix}_latency"] = a.mean()
234
+ summary[f"time_{suffix}_latency_std"] = a.std()
235
+ summary[f"time_{suffix}_latency_min"] = a.min()
236
+ summary[f"time_{suffix}_latency_min"] = a.max()
209
237
  return res
210
238
 
211
239
 
240
def shrink_config(cfg: Dict[str, Any], max_length: int = 50) -> Dict[str, Any]:
    """
    Shrinks the configuration before it gets added to the information to log.

    Every list, tuple, set or dict holding at least *max_length* elements
    is replaced by a short placeholder: ``["..."]`` for a list, ``("...",)``
    for a tuple (the container class is preserved), ``"..."`` for a set or
    a dict. Any other value is kept as it is. Only the first level of *cfg*
    is processed.

    :param cfg: configuration, such as the output of ``config.to_dict()``
    :param max_length: containers of that length or longer are shrunk
    :return: a new dictionary, *cfg* is not modified
    """

    def _shrink(v: Any) -> Any:
        if not isinstance(v, (list, tuple, set, dict)) or len(v) < max_length:
            return v
        if isinstance(v, (list, tuple)):
            # ``list("...")`` would iterate over the string and produce
            # ``['.', '.', '.']``; wrap the marker so the container holds
            # one single placeholder.
            return v.__class__(["..."])
        return "..."

    return {k: _shrink(v) for k, v in cfg.items()}
251
+
252
+
212
253
  def validate_model(
213
254
  model_id: str,
214
255
  task: Optional[str] = None,
@@ -231,9 +272,14 @@ def validate_model(
231
272
  model_options: Optional[Dict[str, Any]] = None,
232
273
  subfolder: Optional[str] = None,
233
274
  opset: Optional[int] = None,
275
+ runtime: str = "onnxruntime",
276
+ repeat: int = 1,
277
+ warmup: int = 0,
234
278
  ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
235
279
  """
236
280
  Validates a model.
281
+ The function can also be called through the command line
282
+ :ref:`l-cmd-validate`.
237
283
 
238
284
  :param model_id: model id to validate
239
285
  :param task: task used to generate the necessary inputs,
@@ -241,7 +287,8 @@ def validate_model(
241
287
  if it can be determined
242
288
  :param do_run: checks the model works with the defined inputs
243
289
  :param exporter: exporter the model using this exporter,
244
- available list: ``export-strict``, ``export-nostrict``, ``onnx``
290
+ available list: ``export-strict``, ``export-nostrict``, ...
291
+ see below
245
292
  :param do_same: checks the discrepancies of the exported model
246
293
  :param verbose: verbosity level
247
294
  :param dtype: uses this dtype to check the model
@@ -267,6 +314,10 @@ def validate_model(
267
314
  ``num_hidden_layers`` or ``attn_implementation``
268
315
  :param subfolder: version or subfolders to uses when retrieving a model id
269
316
  :param opset: onnx opset to use for the conversion
317
+ :param runtime: onnx runtime to use to check about discrepancies,
318
+ only if `do_run` is true
319
+ :param repeat: number of time to measure the model
320
+ :param warmup: warmup the model first
270
321
  :return: two dictionaries, one with some metrics,
271
322
  another one with whatever the function produces
272
323
 
@@ -274,6 +325,20 @@ def validate_model(
274
325
  information:
275
326
 
276
327
  * ``PRINT_CONFIG``: prints the model configuration
328
+
329
+ The following exporters are available:
330
+
331
+ * ``export-nostrict``: run :func:`torch.export.export` (..., strict=False)
332
+ * ``onnx-dynamo``: run :func:`torch.onnx.export` (..., dynamo=True),
333
+ models can be optimized with ``optimization`` in ``("ir", "os_ort")``
334
+ * ``modelbuilder``: use :epkg:`ModelBuilder` to builds the onnx model
335
+ * ``custom``: custom exporter (see :epkg:`experimental-experiment`),
336
+ models can be optimized with ``optimization`` in
337
+ ``("default", "default+onnxruntime", "default+os_ort", "default+onnxruntime+os_ort")``
338
+
339
+ The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
340
+ exported model returns the same outputs as the original one, otherwise,
341
+ :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
277
342
  """
278
343
  assert (
279
344
  not rewrite or patch
@@ -295,6 +360,7 @@ def validate_model(
295
360
  version_ortfusiontype=ortfusiontype or "",
296
361
  version_stop_if_static=str(stop_if_static),
297
362
  version_exporter=exporter or "",
363
+ version_runtime=runtime,
298
364
  )
299
365
  )
300
366
  if opset:
@@ -436,7 +502,9 @@ def validate_model(
436
502
  if summary["model_module"] in sys.modules:
437
503
  summary["model_file"] = str(sys.modules[summary["model_module"]].__file__) # type: ignore[index]
438
504
  summary["model_config_class"] = data["configuration"].__class__.__name__
439
- summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "")
505
+ summary["model_config"] = str(shrink_config(data["configuration"].to_dict())).replace(
506
+ " ", ""
507
+ )
440
508
  summary["model_id"] = model_id
441
509
 
442
510
  if verbose:
@@ -460,7 +528,13 @@ def validate_model(
460
528
  model = data["model"]
461
529
 
462
530
  expected = _quiet_or_not_quiet(
463
- quiet, "run", summary, data, (lambda m=model, inp=inputs: m(**inp))
531
+ quiet,
532
+ "run",
533
+ summary,
534
+ data,
535
+ (lambda m=model, inp=inputs: m(**torch_deepcopy(inp))),
536
+ repeat=repeat,
537
+ warmup=warmup,
464
538
  )
465
539
  if "ERR_run" in summary:
466
540
  return summary, data
@@ -522,7 +596,7 @@ def validate_model(
522
596
 
523
597
  disc = max_diff(data["expected"], expected)
524
598
  for k, v in disc.items():
525
- summary[f"disc_patched_{k}"] = v
599
+ summary[f"disc_patched_{k}"] = str(v)
526
600
  if verbose:
527
601
  print("[validate_model] done (patched run)")
528
602
  print(f"[validate_model] patched discrepancies={string_diff(disc)}")
@@ -618,7 +692,14 @@ def validate_model(
618
692
  return summary, data
619
693
 
620
694
  if do_run:
621
- summary_valid, data = validate_onnx_model(data=data, quiet=quiet, verbose=verbose)
695
+ summary_valid, data = validate_onnx_model(
696
+ data=data,
697
+ quiet=quiet,
698
+ verbose=verbose,
699
+ runtime=runtime,
700
+ repeat=repeat,
701
+ warmup=warmup,
702
+ )
622
703
  summary.update(summary_valid)
623
704
 
624
705
  if ortfusiontype and "onnx_filename" in data:
@@ -671,7 +752,13 @@ def validate_model(
671
752
 
672
753
  if do_run:
673
754
  summary_valid, data = validate_onnx_model(
674
- data=data, quiet=quiet, verbose=verbose, flavour=flavour
755
+ data=data,
756
+ quiet=quiet,
757
+ verbose=verbose,
758
+ flavour=flavour,
759
+ runtime=runtime,
760
+ repeat=repeat,
761
+ warmup=warmup,
675
762
  )
676
763
  summary.update(summary_valid)
677
764
 
@@ -883,6 +970,9 @@ def validate_onnx_model(
883
970
  quiet: bool = False,
884
971
  verbose: int = 0,
885
972
  flavour: Optional[str] = None,
973
+ runtime: str = "onnxruntime",
974
+ repeat: int = 1,
975
+ warmup: int = 0,
886
976
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
887
977
  """
888
978
  Verifies that an onnx model produces the same
@@ -895,6 +985,9 @@ def validate_onnx_model(
895
985
  :param quiet: catch exception or not
896
986
  :param verbose: verbosity
897
987
  :param flavour: use a different version of the inputs
988
+ :param runtime: onnx runtime to use, onnxruntime or torch
989
+ :param repeat: run that number of times the model
990
+ :param warmup: warmup the model
898
991
  :return: two dictionaries, one with some metrics,
899
992
  another one with whatever the function produces
900
993
  """
@@ -936,18 +1029,28 @@ def validate_onnx_model(
936
1029
  f"{providers}..., flavour={flavour!r}"
937
1030
  )
938
1031
 
1032
+ cls_runtime = (
1033
+ (
1034
+ lambda model, providers: onnxruntime.InferenceSession(
1035
+ (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
1036
+ providers=providers,
1037
+ )
1038
+ )
1039
+ if runtime == "onnxruntime"
1040
+ else (
1041
+ lambda model, providers: TorchOnnxEvaluator(
1042
+ model, providers=providers, verbose=max(verbose - 1, 0)
1043
+ )
1044
+ )
1045
+ )
939
1046
  sess = _quiet_or_not_quiet(
940
1047
  quiet,
941
- _mk("time_onnx_ort_create"),
1048
+ _mk("onnx_ort_create"),
942
1049
  summary,
943
1050
  data,
944
- (
945
- lambda source=source, providers=providers: onnxruntime.InferenceSession(
946
- source, providers=providers
947
- )
948
- ),
1051
+ (lambda source=source, providers=providers: cls_runtime(source, providers)),
949
1052
  )
950
- if f"ERR_{_mk('time_onnx_ort_create')}" in summary:
1053
+ if f"ERR_{_mk('onnx_ort_create')}" in summary:
951
1054
  return summary, data
952
1055
 
953
1056
  data[_mk("onnx_ort_sess")] = sess
@@ -975,6 +1078,8 @@ def validate_onnx_model(
975
1078
  summary,
976
1079
  data,
977
1080
  (lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
1081
+ repeat=repeat,
1082
+ warmup=warmup,
978
1083
  )
979
1084
  if f"ERR_{_mk('time_onnx_ort_run')}" in summary:
980
1085
  return summary, data
@@ -1051,7 +1156,7 @@ def call_torch_export_onnx(
1051
1156
  dynamo=False,
1052
1157
  dynamic_axes={
1053
1158
  k: v
1054
- for k, v in CoupleInputsDynamicShapes(args, kwargs, ds)
1159
+ for k, v in CoupleInputsDynamicShapes(args, kwargs, ds) # type: ignore[arg-type]
1055
1160
  .replace_by_string()
1056
1161
  .items()
1057
1162
  if isinstance(v, dict)
@@ -1229,6 +1334,13 @@ def call_torch_export_custom(
1229
1334
  "custom-nostrict",
1230
1335
  "custom-nostrict-default",
1231
1336
  "custom-nostrict-all",
1337
+ "custom-inline",
1338
+ "custom-strict-inline",
1339
+ "custom-strict-default-inline",
1340
+ "custom-strict-all-inline",
1341
+ "custom-nostrict-inline",
1342
+ "custom-nostrict-default-inline",
1343
+ "custom-nostrict-all-inline",
1232
1344
  }
1233
1345
  assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
1234
1346
  assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -1269,6 +1381,10 @@ def call_torch_export_custom(
1269
1381
  ),
1270
1382
  save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
1271
1383
  )
1384
+ inline = "-inline" in exporter
1385
+ if inline:
1386
+ export_options.aten_as_function = set()
1387
+
1272
1388
  options = OptimizationOptions(patterns=optimization) if optimization else None
1273
1389
  model = data["model"]
1274
1390
  kws = dict(
@@ -1279,6 +1395,7 @@ def call_torch_export_custom(
1279
1395
  large_model=True,
1280
1396
  return_optimize_report=True,
1281
1397
  verbose=max(verbose - 2, 0),
1398
+ inline=inline,
1282
1399
  )
1283
1400
  if opset:
1284
1401
  kws["target_opset"] = opset
@@ -0,0 +1,289 @@
1
+ import enum
2
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
3
+ import onnx
4
+ import torch
5
+ from ..api import TensorLike
6
+ from ..helpers import string_type
7
+
8
+
9
class RuntimeValueKind(enum.IntEnum):
    "Kind of result."

    RESULT = 1
    INITIALIZER = 3
    INPUT = 5
    OUTPUT = 9

    def to_str(self) -> str:
        """
        Returns the name of the member (``"RESULT"``, ``"OUTPUT"``, ...).

        The member name is read directly instead of searching the class
        dictionary for an equal value: that dictionary also holds
        non-member attributes (booleans, ``__firstlineno__`` since Python 3.13)
        which may compare equal to a member value and be returned instead.
        """
        return self.name
22
+
23
+
24
class RuntimeDevice(enum.IntEnum):
    "Device definition"

    UNKNOWN = 0
    NEW = 1
    CPU = 2
    CUDA = 4

    def to_str(self) -> str:
        """
        Returns the name of the member (``"CPU"``, ``"CUDA"``, ...).

        The member name is read directly instead of searching the class
        dictionary for an equal value: that dictionary also holds
        non-member attributes (booleans such as ``True == 1``,
        ``__firstlineno__`` since Python 3.13) which may compare equal
        to a member value and be returned instead.
        """
        return self.name
37
+
38
+
39
class RuntimeValue:
    """
    Describes a value used during the execution of a model.

    :param name: name of the result
    :param dtype: element type, set by :meth:`set_value` if not given
    :param shape: shape, updated by :meth:`set_value`
    :param value: the value itself, None if not available
    :param first_used: index of the first node consuming this value
    :param last_used: index of the last node consuming this value
    :param created: index of the node creating this value
    :param is_shape: tells if this value holds a shape
    :param kind: see :class:`RuntimeValueKind`
    :param device: see :class:`RuntimeDevice`
    """

    def __init__(
        self,
        name: str,
        dtype: Optional[Any] = None,
        shape: Optional[Tuple[Union[str, int], ...]] = None,
        value: Optional[Any] = None,
        first_used: Optional[int] = None,
        last_used: Optional[int] = None,
        created: Optional[int] = None,
        is_shape: Optional[bool] = None,
        kind: Optional[RuntimeValueKind] = None,
        device: Optional[RuntimeDevice] = None,
    ):
        self.name = name
        self.dtype = dtype
        self.shape = shape
        self.value = value
        self.first_used = first_used
        self.last_used = last_used
        self.created = created
        self.is_shape = is_shape
        self.kind = kind
        self.device = device

    def _value_as_string(self) -> str:
        # A value implementing ``string_type`` describes itself, anything
        # else (a torch.Tensor for example) goes through the generic helper.
        if hasattr(self.value, "string_type"):
            return self.value.string_type()
        return string_type(self.value, with_shape=True)

    def __repr__(self) -> str:
        "Returns a representation listing every attribute which is not None."
        ad = {}
        for att in [
            "name",
            "dtype",
            "shape",
            "first_used",
            "last_used",
            "is_shape",
            "kind",
            "created",
            "device",
        ]:
            v = getattr(self, att)
            if v is not None:
                ad[att] = v
        if self.value is not None:
            ad["value"] = self._value_as_string()
        msg = ", ".join(
            f"{name}={t.to_str()}" if hasattr(t, "to_str") else f"{name}={t}"
            for name, t in ad.items()
        )
        return f"{self.__class__.__name__}({msg})"

    @property
    def has_value(self) -> bool:
        "Tells if value is specified."
        return self.value is not None

    def string_type(self) -> str:
        """
        Returns a string describing the value.

        Works for any kind of value: values without a ``string_type``
        method (torch.Tensor) are described with the generic helper
        instead of raising an AttributeError.
        """
        rows = []
        if self.shape is not None:
            rows.append(f"shape={self.shape}")
        if self.is_shape is not None:
            rows.append(f"is_shape={self.is_shape}")
        if self.device is not None:
            rows.append(f"device={self.device}")
        text = f", {', '.join(rows)}" if rows else ""
        if self.value is None:
            return (
                f"RuntimeValue(name={self.name!r}{text}"
                f", dtype={self.dtype}, kind={self.kind})"
            )
        return (
            f"RuntimeValue(name={self.name!r}, "
            f"kind={self.kind}{text}, value={self._value_as_string()})"
        )

    def set_value(self, value: Union[torch.Tensor, TensorLike]):
        """
        Sets the value and records its dtype and shape.

        The first value defines ``dtype``; any later value must have the same
        one. ``shape`` is updated on every call and is None for a sequence.
        The object is only modified once the checks succeed.

        :param value: new value, must not be None (use :meth:`clean_value`)
        """
        assert value is not None, "Use clean_value to set a value to None"
        is_sequence = hasattr(value, "is_sequence") and value.is_sequence()
        # Explicit None test: some dtype objects may not be truthy
        # (NumPy dtypes define ``__len__``), which would skip the check.
        if self.dtype is not None:
            assert self.dtype == value.dtype, (
                f"Unexpected dtype={value.dtype}, previous dtype was {self.dtype}, "
                f"is_sequence={is_sequence}"
            )
        else:
            self.dtype = value.dtype
        self.value = value
        self.shape = None if is_sequence else tuple(map(int, value.shape))

    def clean_value(self):
        """Sets value to None."""
        self.value = None

    @property
    def is_output(self) -> bool:
        "Tells if it is an output."
        return self.kind == RuntimeValueKind.OUTPUT

    @property
    def is_input(self) -> bool:
        "Tells if it is an input."
        return self.kind == RuntimeValueKind.INPUT

    @property
    def is_initializer(self) -> bool:
        "Tells if it is an initializer."
        return self.kind == RuntimeValueKind.INITIALIZER
152
+
153
+
154
def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
    """
    Returns the hidden inputs (inputs coming from an upper context)
    used by a subgraph.

    A name is hidden if a node of *graph*, or of any subgraph nested in one
    of its nodes, consumes it while it is neither an input, an initializer,
    a sparse initializer, nor an output of a node placed before in *graph*.
    Empty names (omitted optional inputs) are never reported.

    :param graph: subgraph to inspect
    :return: set of names expected to come from an outer scope
    """
    hidden: Set[str] = set()
    # Names defined inside this graph at the current position.
    # SparseTensorProto has no ``name`` field, its name is ``values.name``.
    known = (
        {i.name for i in graph.initializer}
        | {i.values.name for i in graph.sparse_initializer}
        | {i.name for i in graph.input}
    )
    for node in graph.node:
        hidden.update(i for i in node.input if i and i not in known)
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH and att.g:
                subgraphs = [att.g]
            elif att.type == onnx.AttributeProto.GRAPHS:
                subgraphs = list(att.graphs)
            else:
                continue
            for sub in subgraphs:
                hidden.update(h for h in get_hidden_inputs(sub) if h not in known)
        known.update(node.output)
    return hidden
176
+
177
+
178
def set_is_shape(
    node: onnx.NodeProto, values: Dict[str, RuntimeValue], drop: Optional[Set[str]] = None
) -> List[str]:
    """
    Sets attribute ``is_shape`` for outputs of a node.

    Rules applied in order:

    * a node without inputs (Constant) is left untouched,
    * the first output of ``Shape`` or ``Size`` (default domain) is a shape,
    * if at least one considered input is flagged as a shape: the output
      of a single-output node whose first input is a shape is a shape,
      otherwise every output is flagged as not being a shape,
    * otherwise nothing is modified.

    :param node: node to process
    :param values: stored results, values in this dictionary are updated
    :param drop: variables not to consider because they come from the graph
        holding this subgraph
    :return: list of modified results
    """
    if not node.input:
        # Constant
        return []
    drop = drop or set()
    if node.op_type in ("Shape", "Size") and node.domain == "":
        values[node.output[0]].is_shape = True
        return [node.output[0]]
    # One entry per input position so that index 0 is really the first input.
    # Omitted optional inputs (empty names) and names from an outer scope
    # carry no information and count as None.
    is_shapes = [
        None if (not name or name in drop) else values[name].is_shape
        for name in node.input
    ]
    if not any(is_shapes):
        return []
    if is_shapes[0] and len(node.output) == 1:
        values[node.output[0]].is_shape = True
        return [node.output[0]]
    # Empty output names denote omitted optional outputs, they are skipped.
    modified = [o for o in node.output if o]
    for o in modified:
        values[o].is_shape = False
    return modified
207
+
208
+
209
def first_used_last_used(
    proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
    constant_as_initializer: bool = False,
) -> Dict[str, RuntimeValue]:
    """
    Builds first used, last used information for every result
    in the model.

    Nodes are identified by their position in the node list. Inputs and
    initializers are created at position -1, a node output at the position
    of the node producing it. A result consumed by a node, directly or as a
    hidden input of one of its subgraphs, is used at that node's position.
    ``last_used`` of every output is set to the number of nodes.

    :param proto: model, graph or function; only a GraphProto (a subgraph)
        may consume names which are not defined inside it
    :param constant_as_initializer: outputs of node Constant is tagged as INITIALIZER
    :return: dictionary of RuntimeValue
    """
    values: Dict[str, RuntimeValue] = {}
    if isinstance(proto, onnx.ModelProto):
        initializer = proto.graph.initializer
        sparse_initializer = proto.graph.sparse_initializer
        _input = proto.graph.input
        output = proto.graph.output
        _node = proto.graph.node
        allow_unknown = False
    elif isinstance(proto, onnx.GraphProto):
        initializer = proto.initializer
        sparse_initializer = proto.sparse_initializer
        _input = proto.input
        output = proto.output
        _node = proto.node
        allow_unknown = True
    else:
        initializer = []
        sparse_initializer = []
        _input = proto.input
        output = proto.output
        _node = proto.node
        allow_unknown = False

    for init in initializer:
        values[init.name] = RuntimeValue(
            init.name, kind=RuntimeValueKind.INITIALIZER, created=-1
        )
    for init in sparse_initializer:
        # SparseTensorProto has no ``name`` field, the name is held by ``values``.
        sparse_name = init.values.name
        values[sparse_name] = RuntimeValue(
            sparse_name, created=-1, kind=RuntimeValueKind.INITIALIZER
        )
    for inp in _input:
        n = inp if isinstance(inp, str) else inp.name
        values[n] = RuntimeValue(n, created=-1, kind=RuntimeValueKind.INPUT)

    # Names which are not results of this graph: the empty name denotes an
    # omitted optional input, names coming from an outer scope (GraphProto
    # only) are added when first met. set_is_shape skips them as well.
    drop: Set[str] = {""}

    def _mark_used(name: str, position: int) -> None:
        # Records that the node at *position* consumes *name*.
        if name in drop:
            return
        if name not in values:
            assert allow_unknown, f"Input {name!r} is unknown."
            # This input comes from a context and the model is a GraphProto
            drop.add(name)
            return
        value = values[name]
        if value.first_used is None:
            value.first_used = position
        value.last_used = position

    for it, node in enumerate(_node):
        for i in node.input:
            _mark_used(i, it)
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                subgraphs = [att.g]
            elif att.type == onnx.AttributeProto.GRAPHS:
                subgraphs = list(att.graphs)
            else:
                continue
            # Subgraphs may consume results of this graph without listing
            # them as node inputs.
            for sub in subgraphs:
                for n in get_hidden_inputs(sub):
                    _mark_used(n, it)
        is_constant = node.op_type == "Constant" and node.domain == ""
        kind = (
            RuntimeValueKind.INITIALIZER
            if is_constant and constant_as_initializer
            else RuntimeValueKind.RESULT
        )
        for o in node.output:
            values[o] = RuntimeValue(o, created=it, kind=kind)
        set_is_shape(node, values, drop=drop)

    for out in output:
        n = out if isinstance(out, str) else out.name
        values[n].kind = RuntimeValueKind.OUTPUT
        values[n].last_used = len(_node)
    return values
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-diagnostic
3
- Version: 0.6.0
3
+ Version: 0.6.2
4
4
  Summary: Investigate ONNX models
5
5
  Home-page: https://github.com/sdpython/onnx-diagnostic
6
6
  Author: Xavier Dupré