onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +91 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +3 -3
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +92 -23
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +90 -26
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +103 -1
  15. onnx_diagnostic/helpers/ort_session.py +37 -11
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +103 -6
  18. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/validate.py +50 -1
  37. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  38. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  39. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  40. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +43 -24
  41. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/onnx_plug.py (new file)
@@ -0,0 +1,396 @@
+ import inspect
+ from dataclasses import dataclass
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ import onnx
+ import torch
+ from ..helpers import max_diff
+ from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
+ from ..reference import OnnxruntimeEvaluator
+
+ TUPLE_TENSORS = Tuple[torch.Tensor, ...]
+
+
+ def is_exporting() -> bool:
+     """
+     Returns :func:`torch.compiler.is_exporting` or
+     :func:`torch.compiler.is_compiling`.
+     Changes ``_TEST_EXPORT`` to make it trigger.
+     """
+     return torch.compiler.is_exporting() or torch.compiler.is_compiling()
+
+
+ @dataclass
+ class VerifyResult:
+     """
+     Outputs of method :meth:`verify
+     <onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx.verify>`.
+     """
+
+     eager_outputs: TUPLE_TENSORS
+     onnx_outputs: TUPLE_TENSORS
+     diffs: Tuple[Dict[str, float], ...]
+
+
+ class EagerDirectReplacementWithOnnx:
+     """
+     Replaces a piece of code with another one written in ONNX
+     at export time. The function inserts a custom operator
+     and links it to ``eager_fn``.
+
+     :param eager_fn: the code being replaced; it must be given to be able
+         to execute the :class:`torch.fx.Graph` the exporter produces
+     :param shape_fn: the function producing dummy outputs with the shapes
+         the exporter can use for the next operators in the graph
+     :param function_proto: an instance of ``onnx.FunctionProto``;
+         its domain must be ``onnx_plug``
+     :param n_inputs: number of inputs of the function; if not given,
+         the class infers it from the signature of ``eager_fn``,
+         only tensors must be counted
+     :param n_outputs: same for the number of outputs,
+         only tensors must be counted
+     :param name: the name of the custom op, the function name if not specified
+     :param kwargs: constant parameters with their default values
+     :param verbose: verbosity level
+
+     Here is an example:
+
+     .. runpython::
+         :showcode:
+
+         import onnx.helper as oh
+         import torch
+         from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
+         from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
+         from onnx_diagnostic.export.api import to_onnx
+         from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+         def demo_customsub(x, y):
+             return x - y
+
+
+         def demo_customsub_shape(x, y):
+             return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
+
+
+         def make_function_proto():
+             return oh.make_function(
+                 "onnx_plug",
+                 "demo_customsub",
+                 ["x", "y"],
+                 ["z"],
+                 [oh.make_node("Sub", ["x", "y"], ["z"])],
+                 opset_imports=[oh.make_opsetid("", 22)],
+             )
+
+
+         class Model(torch.nn.Module):
+             def forward(self, x):
+                 y = x.sum(axis=1, keepdim=True)
+                 d = torch.ops.onnx_plug.demo_customsub(x, y)
+                 return torch.abs(d)
+
+
+         replacements = [
+             EagerDirectReplacementWithOnnx(
+                 demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
+             )
+         ]
+
+         x = torch.randn((3, 4), dtype=torch.float32)
+         model = Model()
+         ds = ({0: "d1", 1: "d2"},)
+
+         # The exported program shows a custom op.
+         ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
+         print(ep)
+
+         # The exporter knows how to replace this custom op.
+         # Let's export.
+
+         onx = to_onnx(
+             model,
+             (x,),
+             dynamic_shapes=ds,
+             exporter="custom",
+             onnx_plugs=replacements,
+             target_opset=22,
+             inline=False,
+         ).model_proto
+
+         print(pretty_onnx(onx))
+
+         # And with :func:`torch.onnx.export`:
+
+         onx = to_onnx(
+             model,
+             (x,),
+             dynamic_shapes=ds,
+             exporter="onnx-dynamo",
+             onnx_plugs=replacements,
+             target_opset=22,
+             inline=False,
+         ).model_proto
+
+         print(pretty_onnx(onx))
+     """
+
+     def __init__(
+         self,
+         eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
+         shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
+         function_proto: onnx.FunctionProto,
+         n_inputs: Optional[int] = None,
+         n_outputs: Optional[int] = None,
+         name: Optional[str] = None,
+         kwargs: Optional[Dict[str, Union[int, float]]] = None,
+         verbose: int = 0,
+     ):
+         assert isinstance(
+             function_proto, onnx.FunctionProto
+         ), f"Unexpected type {type(function_proto)} for function_proto"
+         assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}"
+         assert isinstance(n_outputs, int), f"not implemented yet when n_outputs={n_outputs}"
+         self.eager_fn = eager_fn
+         self.shape_fn = shape_fn
+         self.function_proto = function_proto
+         self.n_inputs = n_inputs
+         self.n_outputs = n_outputs
+         self.name = name or (
+             eager_fn.__name__
+             if "<" not in eager_fn.__name__
+             else eager_fn.__qualname__.replace("<locals>", "L")
+             .replace("<lambda>", "l")
+             .replace(".", "_")
+         )
+         self.kwargs = kwargs or {}
+         assert all(isinstance(v, (int, float)) for v in self.kwargs.values()), (
+             f"Only ints or floats are allowed for kwargs={kwargs}, one of them "
+             f"does not respect that constraint."
+         )
+         sig = inspect.signature(self.eager_fn)
+         params = list(sig.parameters)
+         assert (
+             len(params) >= n_inputs
+         ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
+         assert n_inputs == len(function_proto.input), (
+             f"Input mismatch n_inputs={n_inputs} but "
+             f"function_proto.input={function_proto.input}"
+         )
+         assert n_outputs == len(function_proto.output), (
+             f"Output mismatch n_outputs={n_outputs} but "
+             f"function_proto.output={function_proto.output}"
+         )
+         assert (
+             function_proto.domain == self.domain
+         ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
+         self.args_name = [p for p in params if p not in self.kwargs]
+         self.kwargs_name = [p for p in params if p in self.kwargs]
+         self.verbose = verbose
+         self.custom_op = self._register()
+
+     @property
+     def domain(self) -> str:
+         "Returns the onnx domain."
+         return "onnx_plug"
+
+     @property
+     def target_name(self) -> str:
+         "Returns the target name (seen in the exported program)."
+         return f"{self.domain}::{self.name}"
+
+     @property
+     def torch_op(self) -> Callable:
+         "Returns ``torch.ops.onnx_plug.<name>``."
+         return getattr(getattr(torch.ops, self.domain), self.name).default
+
+     def __call__(self, *args, **kwargs):
+         """Calls ``eager_fn``, or the registered custom op if the model is being exported."""
+         if is_exporting():
+             return self.torch_op(*args)
+         return self.eager_fn(*args, **kwargs)
+
+     def _register(self):
+         """Registers the custom op."""
+         input_args = [f"Tensor {p}" for p in self.args_name]
+         for p in self.kwargs_name:
+             val = self.kwargs[p]
+             if isinstance(val, int):
+                 input_args.append(f"int {p}={val}")
+             elif isinstance(val, float):
+                 input_args.append(f"float {p}={val}")
+             else:
+                 raise NotImplementedError(
+                     f"kwargs {p!r} has a default value of unsupported type {type(val)}"
+                 )
+
+         inputs = ", ".join(input_args)
+         schema = f"({inputs}) -> Tensor"
+         if self.n_outputs > 1:
+             schema += "[]"
+         if self.verbose:
+             print(
+                 f"[EagerDirectReplacementWithOnnx._register] "
+                 f"torch.ops.{self.domain}.{self.name}"
+             )
+             print(f"[EagerDirectReplacementWithOnnx._register] schema={schema}")
+         custom_def = torch.library.CustomOpDef(self.domain, self.name, schema, self.eager_fn)
+         custom_def.register_kernel(None)(self.eager_fn)
+         custom_def._abstract_fn = self.shape_fn
+
+     def verify(
+         self,
+         *args,
+         engine: Optional[Callable] = None,
+         dump_onnx_model: Optional[str] = None,
+         **kwargs,
+     ) -> VerifyResult:
+         """
+         Verifies that the eager mode is equivalent to the onnx function given
+         as a replacement. This function evaluates `eager_fn`, checks that the shapes
+         are equivalent to the ones given by `shape_fn`, and finally evaluates the
+         onnx translation if the previous steps did not fail.
+
+         :param args: function inputs
+         :param engine: by default an instance of
+             :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
+         :param dump_onnx_model: to dump the onnx model used to verify
+             eager and onnx produce the same results
+         :param kwargs: additional arguments for eager_fn
+         :return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
+         """
+         expected = self.eager_fn(*args, **kwargs)
+         shapes = self.shape_fn(*args, **kwargs)
+         if isinstance(expected, torch.Tensor):
+             expected = (expected,)
+             assert isinstance(shapes, torch.Tensor), (
+                 f"eager_fn={self.eager_fn} returns a Tensor but shape_fn={self.shape_fn} "
+                 f"returns a {type(shapes)}"
+             )
+             shapes = (shapes,)
+         assert isinstance(expected, tuple) and isinstance(shapes, tuple), (
+             f"eager_fn={self.eager_fn} returns a {type(expected)} "
+             f"and shape_fn={self.shape_fn} returns a {type(shapes)}"
+         )
+         assert len(expected) and len(shapes), (
+             f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn} "
+             f"do not return the same number of tensors."
+         )
+         for i, (e, s) in enumerate(zip(expected, shapes)):
+             assert e.dtype == s.dtype, (
+                 f"Type mismatch {e.dtype} != {s.dtype} for output {i}, "
+                 f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
+             )
+             assert e.shape == s.shape, (
+                 f"Shape mismatch {e.shape} != {s.shape} for output {i}, "
+                 f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
+             )
+
+         # Now the ONNX execution.
+         assert engine is None, f"Not implemented yet with engine={engine!r}"
+         ags, kws = self._make_args_kwargs(*args, **kwargs)
+         sess = OnnxruntimeEvaluator(
+             self.function_proto,
+             whole=True,
+             dump_onnx_model=dump_onnx_model,
+             function_kwargs=kws,
+         )
+         feeds = dict(zip(sess.input_names, ags))
+         got = sess.run(None, feeds)
+         diffs = tuple(max_diff(e, g, hist=[0.1, 0.01]) for e, g in zip(expected, got))
+         return VerifyResult(eager_outputs=expected, onnx_outputs=tuple(got), diffs=diffs)  # type: ignore[arg-type]
+
+     def _make_args_kwargs(self, *args, **kwargs):
+         ags = args[: len(self.args_name)]
+         kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
+         kws.update(kwargs)
+         return ags, kws
+
+     def custom_converter(self) -> Callable:
+         """
+         Returns a function which converts a custom op found in the fx graph
+         into ONNX following the API of the custom exporter.
+         The converter adds a custom op and registers the local function.
+         """
+
+         def converter(
+             g: Any,  # GraphBuilder
+             sts: Optional[Dict[str, Any]],
+             outputs: List[str],
+             *args,
+             **kwargs,
+         ) -> Any:
+             if not g.has_local_function(
+                 self.function_proto.name, domain=self.function_proto.domain
+             ):
+                 g.add_function(self.function_proto)
+             ags, kws = self._make_args_kwargs(*args, **kwargs)
+             res = g.make_node(
+                 self.function_proto.name,
+                 ags,
+                 outputs,
+                 domain=self.function_proto.domain,
+                 name=self.target_name,
+                 **kws,
+             )
+             if not sts:
+                 new_shapes = self.shape_fn(*args)
+                 if not isinstance(new_shapes, tuple):
+                     new_shapes = (new_shapes,)
+                 for sh, o in zip(new_shapes, outputs):
+                     g.set_type(o, torch_dtype_to_onnx_dtype(sh.dtype))
+                     g.set_shape(o, sh.shape)
+             return res
+
+         return converter
+
+     def onnx_dynamo_converter(self) -> Callable:
+         """
+         Returns a function which converts a custom op found in the fx graph
+         into ONNX following the API of :func:`torch.onnx.export`.
+         """
+         import onnxscript
+
+         onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1)
+         schema = onnx_plug_op[self.function_proto.name]
+         if schema is None:
+             all_types = [
+                 "tensor(float)",
+                 "tensor(float16)",
+                 "tensor(bfloat16)",
+                 "tensor(double)",
+                 "tensor(int64)",
+                 "tensor(int32)",
+             ]
+             type_constraints = []
+             for i in range(self.n_inputs):
+                 type_constraints.append((f"T{i}", all_types, ""))
+             for i in range(self.n_outputs):
+                 type_constraints.append((f"U{i}", all_types, ""))
+             schema = onnx.defs.OpSchema(
+                 self.function_proto.name,
+                 self.function_proto.domain,
+                 1,
+                 inputs=[
+                     onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
+                     for i in range(self.n_inputs)
+                 ],
+                 outputs=[
+                     onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
+                     for i in range(self.n_outputs)
+                 ],
+                 type_constraints=type_constraints,
+             )
+             onnx.defs.register_schema(schema)
+         op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)
+
+         def converter(*cargs, **ckwargs):
+             ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
+             return op(*ags, n_outputs=self.n_outputs, **kws)
+
+         return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
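A minimal usage sketch for the new `verify` method (assuming the `replacements` list from the class docstring example above is in scope; the input shapes below are arbitrary):

    import torch

    # "replacements" is the list built in the class docstring example above;
    # reusing it avoids registering torch.ops.onnx_plug.demo_customsub twice.
    rep = replacements[0]
    # verify evaluates eager_fn, checks shape_fn against it, then runs the
    # FunctionProto with OnnxruntimeEvaluator and measures the discrepancies.
    res = rep.verify(torch.randn(3, 4), torch.randn(3, 1))
    print(res.diffs)  # one max_diff dictionary ("abs", "rel", ...) per output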
onnx_diagnostic/ext_test_case.py
@@ -9,6 +9,7 @@ import itertools
  import logging
  import os
  import re
+ import shutil
  import sys
  import unittest
  import warnings
@@ -63,7 +64,7 @@ def skipif_ci_apple(msg) -> Callable:
      return lambda x: x


- def unit_test_going():
+ def unit_test_going() -> bool:
      """
      Enables a flag telling the script is running while testing it.
      Avoids unit tests being very long.
@@ -147,7 +148,7 @@ def hide_stdout(f: Optional[Callable] = None) -> Callable:

      def wrapper(fct):
          def call_f(self):
-             if os.environ.get("UNHIDE", ""):
+             if os.environ.get("UNHIDE", "") in (1, "1", "True", "true"):
                  fct(self)
                  return
              st = StringIO()
@@ -742,8 +743,15 @@ class ExtTestCase(unittest.TestCase):
      _warns: List[Tuple[str, int, Warning]] = []
      _todos: List[Tuple[Callable, str]] = []

+     def unit_test_going(self) -> bool:
+         """
+         Enables a flag telling the script is running while testing it.
+         Avoids unit tests being very long.
+         """
+         return unit_test_going()
+
      @property
-     def verbose(self):
+     def verbose(self) -> int:
          "Returns the value of environment variable ``VERBOSE``."
          return int(os.environ.get("VERBOSE", "0"))

@@ -768,13 +776,13 @@ class ExtTestCase(unittest.TestCase):
          cls._todos.append((f, msg))

      @classmethod
-     def ort(cls):
+     def ort(cls) -> unittest.__class__:
          import onnxruntime

          return onnxruntime

      @classmethod
-     def to_onnx(self, *args, **kwargs):
+     def to_onnx(self, *args, **kwargs) -> "ModelProto":  # noqa: F821
          from experimental_experiment.torch_interpreter import to_onnx

          return to_onnx(*args, **kwargs)
@@ -806,12 +814,16 @@ class ExtTestCase(unittest.TestCase):
              os.makedirs(folder)
          return folder

-     def dump_onnx(
-         self,
-         name: str,
-         proto: Any,
-         folder: Optional[str] = None,
-     ) -> str:
+     def clean_dump(self, folder: str = "dump_test"):
+         """Cleans this folder."""
+         for item in os.listdir(folder):
+             item_path = os.path.join(folder, item)
+             if os.path.isfile(item_path) or os.path.islink(item_path):
+                 os.remove(item_path)
+             elif os.path.isdir(item_path):
+                 shutil.rmtree(item_path)
+
+     def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
          """Dumps an onnx file."""
          fullname = self.get_dump_file(name, folder=folder)
          with open(fullname, "wb") as f:
@@ -1094,10 +1106,15 @@ class ExtTestCase(unittest.TestCase):
              value = numpy.array(value).astype(expected.dtype)
          self.assertEqualArray(expected, value, atol=atol, rtol=rtol)

-     def check_ort(self, onx: "onnx.ModelProto") -> bool:  # noqa: F821
+     def check_ort(
+         self, onx: "onnx.ModelProto"  # noqa: F821
+     ) -> "onnxruntime.InferenceSession":  # noqa: F821
          from onnxruntime import InferenceSession

-         return InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
+         return InferenceSession(
+             onx if isinstance(onx, str) else onx.SerializeToString(),
+             providers=["CPUExecutionProvider"],
+         )

      def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
          """In the name"""
@@ -1137,7 +1154,7 @@ class ExtTestCase(unittest.TestCase):
          if not full.endswith(suffix):
              raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")

-     def capture(self, fct: Callable):
+     def capture(self, fct: Callable) -> Tuple[Any, str, str]:
          """
          Runs a function and captures standard output and error.

@@ -1189,6 +1206,7 @@
          expected: Optional[Any] = None,
          use_ort: bool = False,
          ort_optimized_graph: bool = False,
+         ep: Optional[Union["torch.export.ExportedProgram", str]] = None,  # noqa: F821
          **kwargs,
      ):
          """
@@ -1208,6 +1226,7 @@
          :param copy_inputs: to copy the inputs
          :param use_ort: use :class:`onnxruntime.InferenceSession`
          :param ort_optimized_graph: dumps the optimized onnxruntime graph
+         :param ep: exported program (or saved exported program)
          :param kwargs: arguments sent to
              :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
          """
@@ -1226,15 +1245,16 @@
          if isinstance(proto, str):
              name = proto
              proto = onnx.load(name)
-         else:
+         elif not self.unit_test_going():
              assert isinstance(
                  proto, onnx.ModelProto
              ), f"Unexpected type {type(proto)} for proto"
              name = self.dump_onnx(name, proto)
-         if verbose:
+         if verbose and not self.unit_test_going():
              print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
          if verbose:
              print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
+
          if use_ort:
              assert isinstance(
                  proto, onnx.ModelProto
@@ -1242,15 +1262,14 @@
              feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
              import onnxruntime

-             if verbose:
-                 print(f"[{vname}] create onnxruntime.InferenceSession")
              options = onnxruntime.SessionOptions()
              if ort_optimized_graph:
                  options.optimized_model_filepath = f"{name}.optort.onnx"
+             providers = kwargs.get("providers", ["CPUExecutionProvider"])
+             if verbose:
+                 print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
              sess = onnxruntime.InferenceSession(
-                 proto.SerializeToString(),
-                 options,
-                 providers=kwargs.get("providers", ["CPUExecutionProvider"]),
+                 proto.SerializeToString(), options, providers=providers
              )
              if verbose:
                  print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
@@ -1265,6 +1284,7 @@
              got = sess.run(None, feeds)
          if verbose:
              print(f"[{vname}] compute expected values")
+
          if expected is None:
              if copy_inputs:
                  expected = (
@@ -1274,10 +1294,46 @@
                  )
              else:
                  expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
+
          if verbose:
              print(f"[{vname}] expected {string_type(expected, **kws)}")
              print(f"[{vname}] obtained {string_type(got, **kws)}")
-         diff = max_diff(expected, got, flatten=True)
+
+         if ep:
+             if isinstance(ep, str):
+                 if verbose:
+                     print(f"[{vname}] load exported program {ep!r}")
+                 import torch
+
+                 ep = torch.export.load(ep)
+             ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs
+             ep_model = ep.module()  # type: ignore[union-attr]
+             ep_expected = (
+                 ep_model(*copy.deepcopy(ep_inputs))
+                 if isinstance(ep_inputs, tuple)
+                 else ep_model(**copy.deepcopy(ep_inputs))
+             )
+             if verbose:
+                 print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
+             ep_diff = max_diff(expected, ep_expected, hist=[0.1, 0.01])
+             if verbose:
+                 print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
+             assert (
+                 isinstance(ep_diff["abs"], float)
+                 and isinstance(ep_diff["rel"], float)
+                 and not numpy.isnan(ep_diff["abs"])
+                 and ep_diff["abs"] <= atol
+                 and not numpy.isnan(ep_diff["rel"])
+                 and ep_diff["rel"] <= rtol
+             ), (
+                 f"discrepancies in {test_name!r} between the exported program "
+                 f"and the exported model diff={string_diff(ep_diff)}"
+             )
+             ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
+             if verbose:
+                 print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
+
+         diff = max_diff(expected, got, flatten=True, hist=[0.1, 0.01])
          if verbose:
              print(f"[{vname}] diff {string_diff(diff)}")
          assert (
@@ -1287,7 +1343,10 @@
              and diff["abs"] <= atol
              and not numpy.isnan(diff["rel"])
              and diff["rel"] <= rtol
-         ), f"discrepancies in {test_name!r}, diff={string_diff(diff)}"
+         ), (
+             f"discrepancies in {test_name!r} between the model and "
+             f"the onnx model diff={string_diff(diff)}"
+         )

      def _debug(self):
          "Tells if DEBUG=1 is set up."
@@ -1298,6 +1357,16 @@

          return string_type(*args, **kwargs)

+     def max_diff(self, *args, **kwargs):
+         from .helpers import max_diff
+
+         return max_diff(*args, **kwargs)
+
+     def use_dyn_not_str(self, *args, **kwargs):
+         from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+         return use_dyn_not_str(*args, **kwargs)
+
      def subloop(self, *args, verbose: int = 0):
          "Loops over elements and calls :meth:`unittests.TestCase.subTest`."
          if len(args) == 1:
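A hedged sketch of how the helpers extended in these hunks fit together (the test class name and model path are illustrative, not from the package):

    from onnx_diagnostic.ext_test_case import ExtTestCase

    class TestSomething(ExtTestCase):  # illustrative test case
        def test_session(self):
            # check_ort now returns the InferenceSession instead of a bool and
            # also accepts a path to an onnx file (model.onnx assumed to exist).
            sess = self.check_ort("model.onnx")
            self.assertIsInstance(sess.get_inputs(), list)
            # clean_dump is a new helper emptying the dump folder between runs.
            self.clean_dump("dump_test")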
onnx_diagnostic/helpers/cache_helper.py
@@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes(
      start = 0
      end = 0
      subtrees = []
-     for subspec in spec.children_specs:
+     for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
          end += subspec.num_leaves
          value = subspec.unflatten(flat[start:end])
          value = flatten_unflatten_for_dynamic_shapes(
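The last hunk is a small forward-compatibility shim: recent torch versions expose a `TreeSpec.children()` method while older ones only provide the `children_specs` attribute. A standalone sketch of the same pattern (which torch version introduced `children()` is not stated in the diff):

    import torch.utils._pytree as pytree

    flat, spec = pytree.tree_flatten({"a": 1, "b": (2, 3)})
    # Same guard as in flatten_unflatten_for_dynamic_shapes: prefer the
    # children() method when available, fall back to children_specs.
    children = spec.children() if hasattr(spec, "children") else spec.children_specs
    print([c.num_leaves for c in children])  # -> [1, 2]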