onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,396 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
+ import onnx
5
+ import torch
6
+ from ..helpers import max_diff
7
+ from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
8
+ from ..reference import OnnxruntimeEvaluator
9
+
10
+ TUPLE_TENSORS = Tuple[torch.Tensor, ...]
11
+
12
+
13
+ def is_exporting() -> bool:
14
+ """
15
+ Returns :func:`torch.compiler.is_exporting` or
16
+ :func:`torch.compiler.is_compiling`.
17
+ Changes ``_TEST_EXPORT`` to make it trigger.
18
+ """
19
+ return torch.compiler.is_exporting() or torch.compiler.is_compiling()
20
+
21
+
22
+ @dataclass
23
+ class VerifyResult:
24
+ """
25
+ Outputs of method :meth:`verify
26
+ <onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx.verify>`.
27
+ """
28
+
29
+ eager_outputs: TUPLE_TENSORS
30
+ onnx_outputs: TUPLE_TENSORS
31
+ diffs: Tuple[Dict[str, float], ...]
32
+
33
+
34
+ class EagerDirectReplacementWithOnnx:
35
+ """
36
+ Replaces a piece of code by another one written in ONNX
37
+ at export time. The function inserts a custom operator
38
+ and links it to the eager_fn
39
+
40
+ :param eager_fn: the code it replaces, it must be given in order to be able
41
+ to execute the torch.fx.Graph the exporter produces
42
+ :param shape_fn: the function produces dummy outputs with the shapes
43
+ the exporter can use for the next operators in the graph
44
+ :param function_proto: instances of ``onnx.FunctionProto``,
45
+ its domain must be ``onnx_plug``
46
+ :param n_inputs: number of inputs of the function, if not given,
47
+ the class will infer it from eager_fn signature,
48
+ only tensors must be counted
49
+ :param n_outputs: same for the number of outputs,
50
+ only tensors must be counted
51
+ :param name: the name of the custom op, the function name if not specified
52
+ :param kwargs: constants parameters with their default values
53
+ :param verbose: verbose level
54
+
55
+ Here is an example:
56
+
57
+ .. runpython::
58
+ :showcode:
59
+
60
+ import onnx.helper as oh
61
+ import torch
62
+ from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
63
+ from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
64
+ from onnx_diagnostic.export.api import to_onnx
65
+ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
66
+
67
+
68
+ def demo_customsub(x, y):
69
+ return x - y
70
+
71
+
72
+ def demo_customsub_shape(x, y):
73
+ return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
74
+
75
+
76
+ def make_function_proto():
77
+ return oh.make_function(
78
+ "onnx_plug",
79
+ "demo_customsub",
80
+ ["x", "y"],
81
+ ["z"],
82
+ [oh.make_node("Sub", ["x", "y"], ["z"])],
83
+ opset_imports=[oh.make_opsetid("", 22)],
84
+ )
85
+
86
+
87
+ class Model(torch.nn.Module):
88
+ def forward(self, x):
89
+ y = x.sum(axis=1, keepdim=True)
90
+ d = torch.ops.onnx_plug.demo_customsub(x, y)
91
+ return torch.abs(d)
92
+
93
+
94
+ replacements = [
95
+ EagerDirectReplacementWithOnnx(
96
+ demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
97
+ )
98
+ ]
99
+
100
+ x = torch.randn((3, 4), dtype=torch.float32)
101
+ model = Model()
102
+ ds = ({0: "d1", 1: "d2"},)
103
+
104
+ # The exported program shows a custom op.
105
+ ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
106
+ print("ep")
107
+
108
+ # As the exporter knows how to replace this custom op.
109
+ # Let's export.
110
+
111
+ onx = to_onnx(
112
+ model,
113
+ (x,),
114
+ dynamic_shapes=ds,
115
+ exporter="custom",
116
+ onnx_plugs=replacements,
117
+ target_opset=22,
118
+ inline=False,
119
+ ).model_proto
120
+
121
+ print(pretty_onnx(onx))
122
+
123
+ # And with :func:`torch.onnx.export`:
124
+
125
+ onx = to_onnx(
126
+ model,
127
+ (x,),
128
+ dynamic_shapes=ds,
129
+ exporter="onnx-dynamo",
130
+ onnx_plugs=replacements,
131
+ target_opset=22,
132
+ inline=False,
133
+ ).model_proto
134
+
135
+ print(pretty_onnx(onx))
136
+ """
137
+
138
+ def __init__(
139
+ self,
140
+ eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
141
+ shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
142
+ function_proto: onnx.FunctionProto,
143
+ n_inputs: Optional[int] = None,
144
+ n_outputs: Optional[int] = None,
145
+ name: Optional[str] = None,
146
+ kwargs: Optional[Dict[str, Union[int, float]]] = None,
147
+ verbose: int = 0,
148
+ ):
149
+ assert isinstance(
150
+ function_proto, onnx.FunctionProto
151
+ ), f"Unexpected type {type(function_proto)} for function_proto"
152
+ assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}"
153
+ assert isinstance(n_outputs, int), f"not implemented yet when n_inputs={n_outputs}"
154
+ self.eager_fn = eager_fn
155
+ self.shape_fn = shape_fn
156
+ self.function_proto = function_proto
157
+ self.n_inputs = n_inputs
158
+ self.n_outputs = n_outputs
159
+ self.name = name or (
160
+ eager_fn.__name__
161
+ if "<" not in eager_fn.__name__
162
+ else eager_fn.__qualname__.replace("<locals>", "L")
163
+ .replace("<lambda>", "l")
164
+ .replace(".", "_")
165
+ )
166
+ self.kwargs = kwargs or {}
167
+ assert all(isinstance(v, (int, float)) for v in self.kwargs.values()), (
168
+ f"Only int or floats are allowed for kwargs={kwargs}, one of them "
169
+ f"does not respect that constraint."
170
+ )
171
+ sig = inspect.signature(self.eager_fn)
172
+ params = list(sig.parameters)
173
+ assert (
174
+ len(params) >= n_inputs
175
+ ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
176
+ assert n_inputs == len(function_proto.input), (
177
+ f"Input mismatch n_inputs={n_inputs} but "
178
+ f"function_proto.input={function_proto.input}"
179
+ )
180
+ assert n_outputs == len(function_proto.output), (
181
+ f"Output mismatch n_outputs={n_outputs} but "
182
+ f"function_proto.output={function_proto.output}"
183
+ )
184
+ assert (
185
+ function_proto.domain == self.domain
186
+ ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
187
+ self.args_name = [p for p in params if p not in self.kwargs]
188
+ self.kwargs_name = [p for p in params if p in self.kwargs]
189
+ self.verbose = verbose
190
+ self.custom_op = self._register()
191
+
192
+ @property
193
+ def domain(self) -> str:
194
+ "Returns the onnx domain."
195
+ return "onnx_plug"
196
+
197
+ @property
198
+ def target_name(self) -> str:
199
+ "Returns the target name (see in the exported program)."
200
+ return f"{self.domain}::{self.name}"
201
+
202
+ @property
203
+ def torch_op(self) -> Callable:
204
+ "Returns ``torch.ops.onnx_plug.<name>``."
205
+ return getattr(getattr(torch.ops, self.domain), self.name).default
206
+
207
+ def __call__(self, *args, **kwargs):
208
+ """Calls eager_fn or shape_fn if the model is being exported."""
209
+ if is_exporting():
210
+ return self.torch_op(*args)
211
+ return self.eager_fn(*args, **kwargs)
212
+
213
+ def _register(self):
214
+ """Registers the custom op."""
215
+ input_args = [f"Tensor {p}" for p in self.args_name]
216
+ for p in self.kwargs_name:
217
+ val = self.kwargs[p]
218
+ if isinstance(val, int):
219
+ input_args.append(f"int {p}={val}")
220
+ elif isinstance(val, float):
221
+ input_args.append(f"float {p}={val}")
222
+ else:
223
+ raise NotImplementedError(
224
+ f"kwargs {p!r} has a default value of unsupported type {type(val)}"
225
+ )
226
+
227
+ inputs = ", ".join(input_args)
228
+ schema = f"({inputs}) -> Tensor"
229
+ if self.n_outputs > 1:
230
+ schema += "[]"
231
+ if self.verbose:
232
+ print(
233
+ f"[EagerDirectReplacementWithOnnx._register] "
234
+ f"'torch.ops.{self.domain}.{self.name}"
235
+ )
236
+ print(f"[EagerDirectReplacementWithOnnx._register] schema={schema}")
237
+ custom_def = torch.library.CustomOpDef(self.domain, self.name, schema, self.eager_fn)
238
+ custom_def.register_kernel(None)(self.eager_fn)
239
+ custom_def._abstract_fn = self.shape_fn
240
+
241
+ def verify(
242
+ self,
243
+ *args,
244
+ engine: Optional[Callable] = None,
245
+ dump_onnx_model: Optional[str] = None,
246
+ **kwargs,
247
+ ) -> VerifyResult:
248
+ """
249
+ Verifies that the eager mode is equivalent to the onnx function given
250
+ as a replacement. This function evaluates `eager_fn`, checks that the shapes
251
+ are equivalent to the ones given by `shape_fn`, and finally evaluates the
252
+ onnx translation if the previous did not fail.
253
+
254
+ :param args: function inputs
255
+ :param kwargs: arguments for eager_fn
256
+ :param engine: by default an instance of
257
+ :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
258
+ :param dump_onnx_model: to dump the onnx model used to verify
259
+ eager and onnx produce the same results
260
+ :param kwargs: additional arguments to the function
261
+ :return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
262
+ """
263
+ expected = self.eager_fn(*args, **kwargs)
264
+ shapes = self.shape_fn(*args, **kwargs)
265
+ if isinstance(expected, torch.Tensor):
266
+ expected = (expected,)
267
+ assert isinstance(shapes, torch.Tensor), (
268
+ f"eager_fn={self.eager_fn} returns a Tensor but shape_fn={self.shape_fn} "
269
+ f"returns a {type(shapes)}"
270
+ )
271
+ shapes = (shapes,)
272
+ assert isinstance(expected, tuple) and isinstance(shapes, tuple), (
273
+ f"eager_fn={self.eager_fn} returns a {type(expected)} "
274
+ f"and shape_fn={self.shape_fn} returns a {type(shapes)}"
275
+ )
276
+ assert len(expected) and len(shapes), (
277
+ f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn} "
278
+ f"do not return the same number of tensors."
279
+ )
280
+ for i, (e, s) in enumerate(zip(expected, shapes)):
281
+ assert e.dtype == s.dtype, (
282
+ f"Type mismatch {e.dtype} != {s.dtype} for output {i}, "
283
+ f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
284
+ )
285
+ assert e.shape == s.shape, (
286
+ f"Type mismatch {e.shape} != {s.shape} for output {i}, "
287
+ f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
288
+ )
289
+
290
+ # Now the ONNX execution.
291
+ assert engine is None, f"Not implemented yet with engine={engine!r}"
292
+ ags, kws = self._make_args_kwargs(*args, **kwargs)
293
+ sess = OnnxruntimeEvaluator(
294
+ self.function_proto,
295
+ whole=True,
296
+ dump_onnx_model=dump_onnx_model,
297
+ function_kwargs=kws,
298
+ )
299
+ feeds = dict(zip(sess.input_names, ags))
300
+ got = sess.run(None, feeds)
301
+ diffs = tuple(max_diff(e, g, hist=[0.1, 0.01]) for e, g in zip(expected, got))
302
+ return VerifyResult(eager_outputs=expected, onnx_outputs=tuple(got), diffs=diffs) # type: ignore[arg-type]
303
+
304
+ def _make_args_kwargs(self, *args, **kwargs):
305
+ ags = args[: len(self.args_name)]
306
+ kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
307
+ kws.update(kwargs)
308
+ return ags, kws
309
+
310
+ def custom_converter(
311
+ self,
312
+ ) -> Callable:
313
+ """
314
+ Returns a function which
315
+ converts custom ops found in the fx graph into ONNX
316
+ following the API of the custom exporter.
317
+ The converter adds a custom op and registers the local function.
318
+ """
319
+
320
+ def converter(
321
+ g: Any, # GraphBuilder
322
+ sts: Optional[Dict[str, Any]],
323
+ outputs: List[str],
324
+ *args,
325
+ **kwargs,
326
+ ) -> Any:
327
+ if not g.has_local_function(
328
+ self.function_proto.name, domain=self.function_proto.domain
329
+ ):
330
+ g.add_function(self.function_proto)
331
+ ags, kws = self._make_args_kwargs(*args, **kwargs)
332
+ res = g.make_node(
333
+ self.function_proto.name,
334
+ ags,
335
+ outputs,
336
+ domain=self.function_proto.domain,
337
+ name=self.target_name,
338
+ **kws,
339
+ )
340
+ if not sts:
341
+ new_shapes = self.shape_fn(*args)
342
+ if not isinstance(new_shapes, tuple):
343
+ new_shapes = (new_shapes,)
344
+ for sh, o in zip(new_shapes, outputs):
345
+ g.set_type(o, torch_dtype_to_onnx_dtype(sh.dtype))
346
+ g.set_shape(o, sh.shape)
347
+ return res
348
+
349
+ return converter
350
+
351
+ def onnx_dynamo_converter(self) -> Callable:
352
+ """
353
+ Returns a function which
354
+ converts custom ops found in the fx graph into ONNX
355
+ following the API of :func:`torch.onnx.export`.
356
+ """
357
+ import onnxscript
358
+
359
+ onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1)
360
+ schema = onnx_plug_op[self.function_proto.name]
361
+ if schema is None:
362
+ all_types = [
363
+ "tensor(float)",
364
+ "tensor(float16)",
365
+ "tensor(bfloat16)",
366
+ "tensor(double)",
367
+ "tensor(int64)",
368
+ "tensor(int32)",
369
+ ]
370
+ type_constraints = []
371
+ for i in range(self.n_inputs):
372
+ type_constraints.append((f"T{i}", all_types, ""))
373
+ for i in range(self.n_outputs):
374
+ type_constraints.append((f"U{i}", all_types, ""))
375
+ schema = onnx.defs.OpSchema(
376
+ self.function_proto.name,
377
+ self.function_proto.domain,
378
+ 1,
379
+ inputs=[
380
+ onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
381
+ for i in range(self.n_inputs)
382
+ ],
383
+ outputs=[
384
+ onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
385
+ for i in range(self.n_outputs)
386
+ ],
387
+ type_constraints=type_constraints,
388
+ )
389
+ onnx.defs.register_schema(schema)
390
+ op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)
391
+
392
+ def converter(*cargs, **ckwargs):
393
+ ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
394
+ return op(*ags, n_outputs=self.n_outputs, **kws)
395
+
396
+ return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
@@ -9,6 +9,7 @@ import itertools
9
9
  import logging
10
10
  import os
11
11
  import re
12
+ import shutil
12
13
  import sys
13
14
  import unittest
14
15
  import warnings
@@ -63,7 +64,7 @@ def skipif_ci_apple(msg) -> Callable:
63
64
  return lambda x: x
64
65
 
65
66
 
66
- def unit_test_going():
67
+ def unit_test_going() -> bool:
67
68
  """
68
69
  Enables a flag telling the script is running while testing it.
69
70
  Avoids unit tests being very long.
@@ -147,7 +148,7 @@ def hide_stdout(f: Optional[Callable] = None) -> Callable:
147
148
 
148
149
  def wrapper(fct):
149
150
  def call_f(self):
150
- if os.environ.get("UNHIDE", ""):
151
+ if os.environ.get("UNHIDE", "") in (1, "1", "True", "true"):
151
152
  fct(self)
152
153
  return
153
154
  st = StringIO()
@@ -742,8 +743,15 @@ class ExtTestCase(unittest.TestCase):
742
743
  _warns: List[Tuple[str, int, Warning]] = []
743
744
  _todos: List[Tuple[Callable, str]] = []
744
745
 
746
+ def unit_test_going(self) -> bool:
747
+ """
748
+ Enables a flag telling the script is running while testing it.
749
+ Avoids unit tests being very long.
750
+ """
751
+ return unit_test_going()
752
+
745
753
  @property
746
- def verbose(self):
754
+ def verbose(self) -> int:
747
755
  "Returns the the value of environment variable ``VERBOSE``."
748
756
  return int(os.environ.get("VERBOSE", "0"))
749
757
 
@@ -768,13 +776,13 @@ class ExtTestCase(unittest.TestCase):
768
776
  cls._todos.append((f, msg))
769
777
 
770
778
  @classmethod
771
- def ort(cls):
779
+ def ort(cls) -> unittest.__class__:
772
780
  import onnxruntime
773
781
 
774
782
  return onnxruntime
775
783
 
776
784
  @classmethod
777
- def to_onnx(self, *args, **kwargs):
785
+ def to_onnx(self, *args, **kwargs) -> "ModelProto": # noqa: F821
778
786
  from experimental_experiment.torch_interpreter import to_onnx
779
787
 
780
788
  return to_onnx(*args, **kwargs)
@@ -806,12 +814,16 @@ class ExtTestCase(unittest.TestCase):
806
814
  os.makedirs(folder)
807
815
  return folder
808
816
 
809
- def dump_onnx(
810
- self,
811
- name: str,
812
- proto: Any,
813
- folder: Optional[str] = None,
814
- ) -> str:
817
+ def clean_dump(self, folder: str = "dump_test"):
818
+ """Cleans this folder."""
819
+ for item in os.listdir(folder):
820
+ item_path = os.path.join(folder, item)
821
+ if os.path.isfile(item_path) or os.path.islink(item_path):
822
+ os.remove(item_path)
823
+ elif os.path.isdir(item_path):
824
+ shutil.rmtree(item_path)
825
+
826
+ def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
815
827
  """Dumps an onnx file."""
816
828
  fullname = self.get_dump_file(name, folder=folder)
817
829
  with open(fullname, "wb") as f:
@@ -1094,10 +1106,15 @@ class ExtTestCase(unittest.TestCase):
1094
1106
  value = numpy.array(value).astype(expected.dtype)
1095
1107
  self.assertEqualArray(expected, value, atol=atol, rtol=rtol)
1096
1108
 
1097
- def check_ort(self, onx: "onnx.ModelProto") -> bool: # noqa: F821
1109
+ def check_ort(
1110
+ self, onx: "onnx.ModelProto" # noqa: F821
1111
+ ) -> "onnxruntime.InferenceSession": # noqa: F821
1098
1112
  from onnxruntime import InferenceSession
1099
1113
 
1100
- return InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
1114
+ return InferenceSession(
1115
+ onx if isinstance(onx, str) else onx.SerializeToString(),
1116
+ providers=["CPUExecutionProvider"],
1117
+ )
1101
1118
 
1102
1119
  def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
1103
1120
  """In the name"""
@@ -1137,7 +1154,7 @@ class ExtTestCase(unittest.TestCase):
1137
1154
  if not full.endswith(suffix):
1138
1155
  raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")
1139
1156
 
1140
- def capture(self, fct: Callable):
1157
+ def capture(self, fct: Callable) -> Tuple[Any, str, str]:
1141
1158
  """
1142
1159
  Runs a function and capture standard output and error.
1143
1160
 
@@ -1188,6 +1205,8 @@ class ExtTestCase(unittest.TestCase):
1188
1205
  copy_inputs: bool = True,
1189
1206
  expected: Optional[Any] = None,
1190
1207
  use_ort: bool = False,
1208
+ ort_optimized_graph: bool = False,
1209
+ ep: Optional[Union["torch.export.ExportedProgram", str]] = None, # noqa: F821
1191
1210
  **kwargs,
1192
1211
  ):
1193
1212
  """
@@ -1206,6 +1225,8 @@ class ExtTestCase(unittest.TestCase):
1206
1225
  :param expected: expected values
1207
1226
  :param copy_inputs: to copy the inputs
1208
1227
  :param use_ort: use :class:`onnxruntime.InferenceSession`
1228
+ :param ort_optimized_graph: dumps the optimized onnxruntime graph
1229
+ :param ep: exported program (or saved exported program)
1209
1230
  :param kwargs: arguments sent to
1210
1231
  :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
1211
1232
  """
@@ -1214,33 +1235,56 @@ class ExtTestCase(unittest.TestCase):
1214
1235
  from .helpers.ort_session import InferenceSessionForTorch
1215
1236
 
1216
1237
  kws = dict(with_shape=True, with_min_max=verbose > 1)
1217
- if verbose:
1218
- vname = test_name or "assert_onnx_disc"
1238
+ vname = test_name or "assert_onnx_disc"
1219
1239
  if test_name:
1240
+ import onnx
1241
+
1220
1242
  name = f"{test_name}.onnx"
1221
- print(f"[{vname}] save the onnx model into {name!r}")
1222
- name = self.dump_onnx(name, proto)
1223
- print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
1243
+ if verbose:
1244
+ print(f"[{vname}] save the onnx model into {name!r}")
1245
+ if isinstance(proto, str):
1246
+ name = proto
1247
+ proto = onnx.load(name)
1248
+ elif not self.unit_test_going():
1249
+ assert isinstance(
1250
+ proto, onnx.ModelProto
1251
+ ), f"Unexpected type {type(proto)} for proto"
1252
+ name = self.dump_onnx(name, proto)
1253
+ if verbose and not self.unit_test_going():
1254
+ print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
1224
1255
  if verbose:
1225
1256
  print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
1257
+
1226
1258
  if use_ort:
1259
+ assert isinstance(
1260
+ proto, onnx.ModelProto
1261
+ ), f"Unexpected type {type(proto)} for proto"
1227
1262
  feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
1228
- if verbose:
1229
- print(f"[{vname}] feeds {string_type(feeds, **kws)}")
1230
1263
  import onnxruntime
1231
1264
 
1265
+ options = onnxruntime.SessionOptions()
1266
+ if ort_optimized_graph:
1267
+ options.optimized_model_filepath = f"{name}.optort.onnx"
1268
+ providers = kwargs.get("providers", ["CPUExecutionProvider"])
1269
+ if verbose:
1270
+ print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
1232
1271
  sess = onnxruntime.InferenceSession(
1233
- proto.SerializeToString(), providers=["CPUExecutionProvider"]
1272
+ proto.SerializeToString(), options, providers=providers
1234
1273
  )
1274
+ if verbose:
1275
+ print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
1235
1276
  got = sess.run(None, feeds)
1236
1277
  else:
1237
1278
  feeds = make_feeds(proto, inputs, copy=True)
1238
1279
  if verbose:
1239
- print(f"[{vname}] feeds {string_type(feeds, **kws)}")
1280
+ print(f"[{vname}] create InferenceSessionForTorch")
1240
1281
  sess = InferenceSessionForTorch(proto, **kwargs)
1282
+ if verbose:
1283
+ print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
1241
1284
  got = sess.run(None, feeds)
1242
1285
  if verbose:
1243
1286
  print(f"[{vname}] compute expected values")
1287
+
1244
1288
  if expected is None:
1245
1289
  if copy_inputs:
1246
1290
  expected = (
@@ -1250,10 +1294,46 @@ class ExtTestCase(unittest.TestCase):
1250
1294
  )
1251
1295
  else:
1252
1296
  expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
1297
+
1253
1298
  if verbose:
1254
1299
  print(f"[{vname}] expected {string_type(expected, **kws)}")
1255
1300
  print(f"[{vname}] obtained {string_type(got, **kws)}")
1256
- diff = max_diff(expected, got, flatten=True)
1301
+
1302
+ if ep:
1303
+ if isinstance(ep, str):
1304
+ if verbose:
1305
+ print(f"[{vname}] load exported program {ep!r}")
1306
+ import torch
1307
+
1308
+ ep = torch.export.load(ep)
1309
+ ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs
1310
+ ep_model = ep.module() # type: ignore[union-attr]
1311
+ ep_expected = (
1312
+ ep_model(*copy.deepcopy(ep_inputs))
1313
+ if isinstance(ep_inputs, tuple)
1314
+ else ep_model(**copy.deepcopy(ep_inputs))
1315
+ )
1316
+ if verbose:
1317
+ print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
1318
+ ep_diff = max_diff(expected, ep_expected, hist=[0.1, 0.01])
1319
+ if verbose:
1320
+ print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
1321
+ assert (
1322
+ isinstance(ep_diff["abs"], float)
1323
+ and isinstance(ep_diff["rel"], float)
1324
+ and not numpy.isnan(ep_diff["abs"])
1325
+ and ep_diff["abs"] <= atol
1326
+ and not numpy.isnan(ep_diff["rel"])
1327
+ and ep_diff["rel"] <= rtol
1328
+ ), (
1329
+ f"discrepancies in {test_name!r} between the exported program "
1330
+ f"and the exported model diff={string_diff(ep_diff)}"
1331
+ )
1332
+ ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
1333
+ if verbose:
1334
+ print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
1335
+
1336
+ diff = max_diff(expected, got, flatten=True, hist=[0.1, 0.01])
1257
1337
  if verbose:
1258
1338
  print(f"[{vname}] diff {string_diff(diff)}")
1259
1339
  assert (
@@ -1263,7 +1343,10 @@ class ExtTestCase(unittest.TestCase):
1263
1343
  and diff["abs"] <= atol
1264
1344
  and not numpy.isnan(diff["rel"])
1265
1345
  and diff["rel"] <= rtol
1266
- ), f"discrepancies in {test_name!r}, diff={string_diff(diff)}"
1346
+ ), (
1347
+ f"discrepancies in {test_name!r} between the model and "
1348
+ f"the onnx model diff={string_diff(diff)}"
1349
+ )
1267
1350
 
1268
1351
  def _debug(self):
1269
1352
  "Tells if DEBUG=1 is set up."
@@ -1274,6 +1357,16 @@ class ExtTestCase(unittest.TestCase):
1274
1357
 
1275
1358
  return string_type(*args, **kwargs)
1276
1359
 
1360
+ def max_diff(self, *args, **kwargs):
1361
+ from .helpers import max_diff
1362
+
1363
+ return max_diff(*args, **kwargs)
1364
+
1365
+ def use_dyn_not_str(self, *args, **kwargs):
1366
+ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
1367
+
1368
+ return use_dyn_not_str(*args, *kwargs)
1369
+
1277
1370
  def subloop(self, *args, verbose: int = 0):
1278
1371
  "Loops over elements and calls :meth:`unittests.TestCase.subTest`."
1279
1372
  if len(args) == 1: