onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +47 -10
  3. onnx_diagnostic/export/api.py +81 -50
  4. onnx_diagnostic/export/control_flow_research.py +10 -5
  5. onnx_diagnostic/export/onnx_plug.py +250 -61
  6. onnx_diagnostic/ext_test_case.py +99 -53
  7. onnx_diagnostic/helpers/dot_helper.py +37 -25
  8. onnx_diagnostic/helpers/helper.py +44 -38
  9. onnx_diagnostic/helpers/onnx_helper.py +441 -18
  10. onnx_diagnostic/helpers/ort_session.py +8 -8
  11. onnx_diagnostic/helpers/torch_helper.py +28 -2
  12. onnx_diagnostic/reference/ort_evaluator.py +6 -29
  13. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
  14. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
  15. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
  16. onnx_diagnostic/torch_models/code_sample.py +2 -1
  17. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  18. onnx_diagnostic/torch_models/validate.py +14 -1
  19. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  20. onnx_diagnostic/torch_onnx/sbs.py +11 -5
  21. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
  22. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/METADATA +1 -1
  23. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/RECORD +26 -26
  24. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/WHEEL +0 -0
  25. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/licenses/LICENSE.txt +0 -0
  26. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/top_level.txt +0 -0
@@ -3,8 +3,12 @@ from dataclasses import dataclass
3
3
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
4
  import onnx
5
5
  import torch
6
- from ..helpers import max_diff
7
- from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
6
+ from ..helpers import max_diff, string_type
7
+ from ..helpers.torch_helper import (
8
+ torch_dtype_to_onnx_dtype,
9
+ onnx_dtype_to_torch_dtype,
10
+ int_device_to_torch_device,
11
+ )
8
12
  from ..reference import OnnxruntimeEvaluator
9
13
 
10
14
  TUPLE_TENSORS = Tuple[torch.Tensor, ...]
@@ -50,6 +54,10 @@ class EagerDirectReplacementWithOnnx:
50
54
  only tensors must be counted
51
55
  :param name: the name of the custom op, the function name if not specified
52
56
  :param kwargs: constants parameters with their default values
57
+ :param version_selector: selects the version based on the arguments,
58
+ see below for an example; this allows the user to define different
59
+ onnx versions depending on the inputs
60
+ :param default_opset: opset to use by default
53
61
  :param verbose: verbose level
54
62
 
55
63
  Here is an example:
@@ -120,7 +128,61 @@ class EagerDirectReplacementWithOnnx:
120
128
 
121
129
  print(pretty_onnx(onx))
122
130
 
123
- # And with :func:`torch.onnx.export`:
131
+ We do the same with :func:`torch.onnx.export`:
132
+
133
+ .. runpython::
134
+ :showcode:
135
+
136
+ import onnx.helper as oh
137
+ import torch
138
+ from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
139
+ from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
140
+ from onnx_diagnostic.export.api import to_onnx
141
+ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
142
+
143
+
144
+ def demo_customsub(x, y):
145
+ return x - y
146
+
147
+
148
+ def demo_customsub_shape(x, y):
149
+ return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
150
+
151
+
152
+ def make_function_proto():
153
+ return oh.make_function(
154
+ "onnx_plug",
155
+ "demo_customsub",
156
+ ["x", "y"],
157
+ ["z"],
158
+ [oh.make_node("Sub", ["x", "y"], ["z"])],
159
+ opset_imports=[oh.make_opsetid("", 22)],
160
+ )
161
+
162
+
163
+ class Model(torch.nn.Module):
164
+ def forward(self, x):
165
+ y = x.sum(axis=1, keepdim=True)
166
+ d = torch.ops.onnx_plug.demo_customsub(x, y)
167
+ return torch.abs(d)
168
+
169
+
170
+ replacements = [
171
+ EagerDirectReplacementWithOnnx(
172
+ demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
173
+ )
174
+ ]
175
+
176
+ x = torch.randn((3, 4), dtype=torch.float32)
177
+ model = Model()
178
+ ds = ({0: "d1", 1: "d2"},)
179
+
180
+ # The exported program shows a custom op.
181
+ ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
182
+ print("ep")
183
+
184
+ # The exporter knows how to replace this custom op.
185
+ # Let's export.
124
186
 
125
187
  onx = to_onnx(
126
188
  model,
@@ -133,27 +195,87 @@ class EagerDirectReplacementWithOnnx:
133
195
  ).model_proto
134
196
 
135
197
  print(pretty_onnx(onx))
198
+
199
+ This shows how to define multiple versions depending on the device,
200
+ the type or the targeted onnx opset.
201
+
202
+ .. code-block:: python
203
+
204
+ def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
205
+ first_tensor = next(a for a in args if a is not None)
206
+ dtype = first_tensor.dtype
207
+ itype = torch_dtype_to_onnx_dtype(dtype)
208
+ if dtype == torch.float32:
209
+ if opset >= 23:
210
+ return "LOOPA23", itype
211
+ return "LOOPMHA", itype
212
+ if dtype == torch.float16:
213
+ if first_tensor.is_cuda:
214
+ return "PACKED", itype
215
+ return "LOOPMHA", itype
216
+ raise AssertionError(
217
+ f"Unable to handle type {torch.dtype} (itype={itype}) "
218
+ f"on device {torch.device} with opset={opset}"
219
+ )
220
+
221
+ qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
222
+ qwen_sdpa_attention,
223
+ lambda qs, *args, **kwargs: torch.empty(
224
+ (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
225
+ dtype=qs.dtype,
226
+ device=qs.device,
227
+ ),
228
+ {
229
+ ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
230
+ PackedAttention.to_function_proto()
231
+ ),
232
+ ("LOOPA23", onnx.TensorProto.FLOAT): LoopAttention23.to_function_proto(),
233
+ ("LOOPA23", onnx.TensorProto.FLOAT16): _update_sequence_type(
234
+ onnx.TensorProto.FLOAT16, LoopAttention23.to_function_proto()
235
+ ),
236
+ ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
237
+ LoopMHAAttention.to_function_proto()
238
+ ),
239
+ ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
240
+ onnx.TensorProto.FLOAT16,
241
+ _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
242
+ ),
243
+ },
244
+ n_inputs=4,
245
+ n_outputs=1,
246
+ kwargs=dict(scaling=0.11180339887498948, num_heads=16),
247
+ name="qwen_sdpa_attention_versatile",
248
+ version_selector=qwen_version_selector,
249
+ )
136
250
  """
137
251
 
138
252
  def __init__(
139
253
  self,
140
254
  eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
141
255
  shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
142
- function_proto: onnx.FunctionProto,
256
+ function_proto: Union[onnx.FunctionProto, Dict[Any, onnx.FunctionProto]],
143
257
  n_inputs: Optional[int] = None,
144
258
  n_outputs: Optional[int] = None,
145
259
  name: Optional[str] = None,
146
260
  kwargs: Optional[Dict[str, Union[int, float]]] = None,
147
261
  verbose: int = 0,
262
+ version_selector: Optional[Callable[..., Tuple[Any, ...]]] = None,
263
+ default_opset: int = 22,
148
264
  ):
149
- assert isinstance(
150
- function_proto, onnx.FunctionProto
265
+ assert isinstance(function_proto, onnx.FunctionProto) or (
266
+ isinstance(function_proto, dict)
267
+ or all(isinstance(v, onnx.FunctionProto) for v in function_proto.values())
151
268
  ), f"Unexpected type {type(function_proto)} for function_proto"
152
269
  assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}"
153
- assert isinstance(n_outputs, int), f"not implemented yet when n_inputs={n_outputs}"
270
+ assert isinstance(n_outputs, int), f"not implemented yet when n_outputs={n_outputs}"
154
271
  self.eager_fn = eager_fn
155
272
  self.shape_fn = shape_fn
156
- self.function_proto = function_proto
273
+ self._function_proto = (
274
+ function_proto if isinstance(function_proto, onnx.FunctionProto) else None
275
+ )
276
+ self._function_proto_versioned = (
277
+ function_proto if isinstance(function_proto, dict) else {}
278
+ )
157
279
  self.n_inputs = n_inputs
158
280
  self.n_outputs = n_outputs
159
281
  self.name = name or (
@@ -170,24 +292,73 @@ class EagerDirectReplacementWithOnnx:
170
292
  )
171
293
  sig = inspect.signature(self.eager_fn)
172
294
  params = list(sig.parameters)
173
- assert (
174
- len(params) >= n_inputs
175
- ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
176
- assert n_inputs == len(function_proto.input), (
177
- f"Input mismatch n_inputs={n_inputs} but "
178
- f"function_proto.input={function_proto.input}"
179
- )
180
- assert n_outputs == len(function_proto.output), (
181
- f"Output mismatch n_outputs={n_outputs} but "
182
- f"function_proto.output={function_proto.output}"
183
- )
184
- assert (
185
- function_proto.domain == self.domain
186
- ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
187
295
  self.args_name = [p for p in params if p not in self.kwargs]
188
296
  self.kwargs_name = [p for p in params if p in self.kwargs]
189
297
  self.verbose = verbose
190
298
  self.custom_op = self._register()
299
+ self.version_selector = version_selector
300
+ self.default_opset = default_opset
301
+ self._check_protos(params)
302
+
303
+ def _check_protos(self, params):
304
+ assert (
305
+ len(params) >= self.n_inputs
306
+ ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={self.n_inputs}"
307
+
308
+ # one proto
309
+ assert self._function_proto is None or self.n_inputs == len(
310
+ self._function_proto.input
311
+ ), (
312
+ f"Input mismatch n_inputs={self.n_inputs} but "
313
+ f"function_proto.input={self._function_proto.input}"
314
+ )
315
+ assert self._function_proto is None or self.n_outputs == len(
316
+ self._function_proto.output
317
+ ), (
318
+ f"Output mismatch n_outputs={self.n_outputs} but "
319
+ f"function_proto.output={self._function_proto.output}"
320
+ )
321
+ assert self._function_proto is None or (
322
+ self._function_proto.domain == self.domain
323
+ ), f"Function domain must be {self.domain!r} but it is {self._function_proto.domain!r}"
324
+
325
+ # multiple protos
326
+ assert all(
327
+ self.n_inputs == len(v.input) for v in self._function_proto_versioned.values()
328
+ ), f"Output mismatch n_inputs={self.n_inputs} but one version is wrong"
329
+ assert all(
330
+ self.n_outputs == len(v.output) for v in self._function_proto_versioned.values()
331
+ ), f"Output mismatch n_outputs={self.n_outputs} but one version is wrong"
332
+ assert all(
333
+ v.domain == self.domain for v in self._function_proto_versioned.values()
334
+ ), f"Function domain must be {self.domain!r} but it is different in one version"
335
+ assert (
336
+ not self._function_proto_versioned or self.version_selector
337
+ ), "version_selector is needed when multiple protos are given."
338
+
339
+ def get_function_proto(self, opset: int, *args) -> onnx.FunctionProto:
340
+ """Returns the correct version based on the inputs."""
341
+ if self._function_proto:
342
+ return self._function_proto
343
+ assert isinstance(
344
+ opset, int
345
+ ), f"The first argument must be an integer for the onnx opset but it is {type(opset)}"
346
+ assert any(
347
+ a is not None for a in args
348
+ ), f"Unexpected args={string_type(args, with_shape=True)}"
349
+ try:
350
+ key = self.version_selector(opset, *args) # type: ignore[misc]
351
+ except (ValueError, AttributeError) as e:
352
+ raise AssertionError(
353
+ f"Unable to select a version, fails to get a key, available="
354
+ f"{set(self._function_proto_versioned)}, "
355
+ f"args={string_type(args,with_shape=True)}"
356
+ ) from e
357
+ assert key in self._function_proto_versioned, (
358
+ f"Unable to select a version, key={key}, available="
359
+ f"{set(self._function_proto_versioned)}, args={string_type(args,with_shape=True)}"
360
+ )
361
+ return self._function_proto_versioned[key]
191
362
 
192
363
  @property
193
364
  def domain(self) -> str:
@@ -219,6 +390,8 @@ class EagerDirectReplacementWithOnnx:
219
390
  input_args.append(f"int {p}={val}")
220
391
  elif isinstance(val, float):
221
392
  input_args.append(f"float {p}={val}")
393
+ elif isinstance(val, str):
394
+ input_args.append(f"str {p}={val}")
222
395
  else:
223
396
  raise NotImplementedError(
224
397
  f"kwargs {p!r} has a default value of unsupported type {type(val)}"
@@ -243,6 +416,7 @@ class EagerDirectReplacementWithOnnx:
243
416
  *args,
244
417
  engine: Optional[Callable] = None,
245
418
  dump_onnx_model: Optional[str] = None,
419
+ opset: int = 22,
246
420
  **kwargs,
247
421
  ) -> VerifyResult:
248
422
  """
@@ -257,6 +431,7 @@ class EagerDirectReplacementWithOnnx:
257
431
  :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
258
432
  :param dump_onnx_model: to dump the onnx model used to verify
259
433
  eager and onnx produce the same results
434
+ :param opset: onnx opset to use
260
435
  :param kwargs: additional arguments to the function
261
436
  :return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
262
437
  """
@@ -291,7 +466,7 @@ class EagerDirectReplacementWithOnnx:
291
466
  assert engine is None, f"Not implemented yet with engine={engine!r}"
292
467
  ags, kws = self._make_args_kwargs(*args, **kwargs)
293
468
  sess = OnnxruntimeEvaluator(
294
- self.function_proto,
469
+ self.get_function_proto(opset, *args),
295
470
  whole=True,
296
471
  dump_onnx_model=dump_onnx_model,
297
472
  function_kwargs=kws,
@@ -324,16 +499,25 @@ class EagerDirectReplacementWithOnnx:
324
499
  *args,
325
500
  **kwargs,
326
501
  ) -> Any:
327
- if not g.has_local_function(
328
- self.function_proto.name, domain=self.function_proto.domain
329
- ):
330
- g.add_function(self.function_proto)
502
+ has_devices = [a for a in args if isinstance(a, str) and g.has_device(a)]
503
+ assert (
504
+ has_devices
505
+ ), f"Missing device for any of the inputs {args}{g.get_debug_msg()}"
506
+ arg_device = has_devices[0]
507
+ fake_tensor = torch.empty(
508
+ tuple([(_ if isinstance(_, int) else 2) for _ in g.get_shape(args[0])]),
509
+ dtype=onnx_dtype_to_torch_dtype(g.get_type(args[0])),
510
+ device=int_device_to_torch_device(g.get_device(arg_device)),
511
+ )
512
+ function_proto = self.get_function_proto(g.main_opset, fake_tensor)
513
+ if not g.has_local_function(function_proto.name, domain=function_proto.domain):
514
+ g.add_function(function_proto)
331
515
  ags, kws = self._make_args_kwargs(*args, **kwargs)
332
516
  res = g.make_node(
333
- self.function_proto.name,
517
+ function_proto.name,
334
518
  ags,
335
519
  outputs,
336
- domain=self.function_proto.domain,
520
+ domain=function_proto.domain,
337
521
  name=self.target_name,
338
522
  **kws,
339
523
  )
@@ -356,41 +540,46 @@ class EagerDirectReplacementWithOnnx:
356
540
  """
357
541
  import onnxscript
358
542
 
359
- onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1)
360
- schema = onnx_plug_op[self.function_proto.name]
361
- if schema is None:
362
- all_types = [
363
- "tensor(float)",
364
- "tensor(float16)",
365
- "tensor(bfloat16)",
366
- "tensor(double)",
367
- "tensor(int64)",
368
- "tensor(int32)",
369
- ]
370
- type_constraints = []
371
- for i in range(self.n_inputs):
372
- type_constraints.append((f"T{i}", all_types, ""))
373
- for i in range(self.n_outputs):
374
- type_constraints.append((f"U{i}", all_types, ""))
375
- schema = onnx.defs.OpSchema(
376
- self.function_proto.name,
377
- self.function_proto.domain,
378
- 1,
379
- inputs=[
380
- onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
381
- for i in range(self.n_inputs)
382
- ],
383
- outputs=[
384
- onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
385
- for i in range(self.n_outputs)
386
- ],
387
- type_constraints=type_constraints,
388
- )
389
- onnx.defs.register_schema(schema)
390
- op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)
543
+ onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
544
+
545
+ def get_proto(*args):
546
+ function_proto = self.get_function_proto(self.default_opset, *args)
547
+ schema = onnx_plug_op[function_proto.name]
548
+ if schema is None:
549
+ all_types = [
550
+ "tensor(float)",
551
+ "tensor(float16)",
552
+ "tensor(bfloat16)",
553
+ "tensor(double)",
554
+ "tensor(int64)",
555
+ "tensor(int32)",
556
+ ]
557
+ type_constraints = []
558
+ for i in range(self.n_inputs):
559
+ type_constraints.append((f"T{i}", all_types, ""))
560
+ for i in range(self.n_outputs):
561
+ type_constraints.append((f"U{i}", all_types, ""))
562
+ schema = onnx.defs.OpSchema(
563
+ function_proto.name,
564
+ function_proto.domain,
565
+ 1,
566
+ inputs=[
567
+ onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
568
+ for i in range(self.n_inputs)
569
+ ],
570
+ outputs=[
571
+ onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
572
+ for i in range(self.n_outputs)
573
+ ],
574
+ type_constraints=type_constraints,
575
+ )
576
+ onnx.defs.register_schema(schema)
577
+ op = onnxscript.values.Op(onnx_plug_op, function_proto.name, schema)
578
+ return op
391
579
 
392
580
  def converter(*cargs, **ckwargs):
393
581
  ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
582
+ op = get_proto(*cargs)
394
583
  return op(*ags, n_outputs=self.n_outputs, **kws)
395
584
 
396
585
  return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
@@ -610,6 +610,21 @@ def requires_onnxruntime(version: str, msg: str = "") -> Callable:
610
610
  return lambda x: x
611
611
 
612
612
 
613
+ def has_onnxruntime(version: str, msg: str = "") -> Callable:
614
+ """Tells if :epkg:`onnxruntime` is at least *version* (a development build without ``__version__`` counts as recent enough)."""
615
+ import packaging.version as pv
616
+ import onnxruntime
617
+
618
+ if not hasattr(onnxruntime, "__version__"):
619
+ # development version
620
+ return True
621
+
622
+ if pv.Version(onnxruntime.__version__) < pv.Version(version):
623
+ msg = f"onnxruntime version {onnxruntime.__version__} < {version}: {msg}"
624
+ return False
625
+ return True
626
+
627
+
613
628
  def has_onnxruntime_training(push_back_batch: bool = False):
614
629
  """Tells if onnxruntime_training is installed."""
615
630
  try:
@@ -830,6 +845,13 @@ class ExtTestCase(unittest.TestCase):
830
845
  f.write(proto.SerializeToString())
831
846
  return fullname
832
847
 
848
+ def dump_text(self, name: str, text: str, folder: Optional[str] = None) -> str:
849
+ """Dumps text in a file."""
850
+ fullname = self.get_dump_file(name, folder=folder)
851
+ with open(fullname, "w") as f:
852
+ f.write(text)
853
+ return fullname
854
+
833
855
  def assertExists(self, name):
834
856
  """Checks the existing of a file."""
835
857
  if not os.path.exists(name):
@@ -1196,9 +1218,9 @@ class ExtTestCase(unittest.TestCase):
1196
1218
  def assert_onnx_disc(
1197
1219
  self,
1198
1220
  test_name: str,
1199
- proto: "onnx.ModelProto", # noqa: F821
1221
+ proto: Union[str, "onnx.ModelProto"], # noqa: F821
1200
1222
  model: "torch.nn.Module", # noqa: F821
1201
- inputs: Union[Tuple[Any], Dict[str, Any]],
1223
+ inputs: Union[Tuple[Any], Dict[str, Any], List[Any]],
1202
1224
  verbose: int = 0,
1203
1225
  atol: float = 1e-5,
1204
1226
  rtol: float = 1e-3,
@@ -1242,7 +1264,9 @@ class ExtTestCase(unittest.TestCase):
1242
1264
  name = f"{test_name}.onnx"
1243
1265
  if verbose:
1244
1266
  print(f"[{vname}] save the onnx model into {name!r}")
1267
+ model_file = None
1245
1268
  if isinstance(proto, str):
1269
+ model_file = proto
1246
1270
  name = proto
1247
1271
  proto = onnx.load(name)
1248
1272
  elif not self.unit_test_going():
@@ -1255,45 +1279,64 @@ class ExtTestCase(unittest.TestCase):
1255
1279
  if verbose:
1256
1280
  print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
1257
1281
 
1282
+ if not isinstance(inputs, list):
1283
+ inputs = [inputs]
1284
+ if expected is not None:
1285
+ expected = [expected]
1286
+
1287
+ gots = []
1258
1288
  if use_ort:
1259
1289
  assert isinstance(
1260
1290
  proto, onnx.ModelProto
1261
1291
  ), f"Unexpected type {type(proto)} for proto"
1262
- feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
1263
1292
  import onnxruntime
1264
1293
 
1265
1294
  options = onnxruntime.SessionOptions()
1266
1295
  if ort_optimized_graph:
1267
1296
  options.optimized_model_filepath = f"{name}.optort.onnx"
1297
+ if "log_severity_level" in kwargs:
1298
+ options.log_severity_level = kwargs["log_severity_level"]
1299
+ if "log_verbosity_level" in kwargs:
1300
+ options.log_verbosity_level = kwargs["log_verbosity_level"]
1268
1301
  providers = kwargs.get("providers", ["CPUExecutionProvider"])
1269
1302
  if verbose:
1270
1303
  print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
1271
1304
  sess = onnxruntime.InferenceSession(
1272
- proto.SerializeToString(), options, providers=providers
1305
+ model_file or proto.SerializeToString(), options, providers=providers
1273
1306
  )
1274
- if verbose:
1275
- print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
1276
- got = sess.run(None, feeds)
1307
+ for inp in inputs:
1308
+ feeds = make_feeds(proto, inp, use_numpy=True, copy=True)
1309
+ if verbose:
1310
+ print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
1311
+ got = sess.run(None, feeds)
1312
+ gots.append(got)
1277
1313
  else:
1278
- feeds = make_feeds(proto, inputs, copy=True)
1279
1314
  if verbose:
1280
1315
  print(f"[{vname}] create InferenceSessionForTorch")
1281
1316
  sess = InferenceSessionForTorch(proto, **kwargs)
1282
- if verbose:
1283
- print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
1284
- got = sess.run(None, feeds)
1317
+ for inp in inputs:
1318
+ feeds = make_feeds(proto, inp, copy=True)
1319
+ if verbose:
1320
+ print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
1321
+ got = sess.run(None, feeds)
1322
+ gots.append(got)
1285
1323
  if verbose:
1286
1324
  print(f"[{vname}] compute expected values")
1287
1325
 
1288
1326
  if expected is None:
1289
1327
  if copy_inputs:
1290
- expected = (
1291
- model(*copy.deepcopy(inputs))
1292
- if isinstance(inputs, tuple)
1293
- else model(**copy.deepcopy(inputs))
1294
- )
1328
+ expected = [
1329
+ (
1330
+ model(*copy.deepcopy(inp))
1331
+ if isinstance(inp, tuple)
1332
+ else model(**copy.deepcopy(inp))
1333
+ )
1334
+ for inp in inputs
1335
+ ]
1295
1336
  else:
1296
- expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
1337
+ expected = [
1338
+ model(*inp) if isinstance(inp, tuple) else model(**inp) for inp in inputs
1339
+ ]
1297
1340
 
1298
1341
  if verbose:
1299
1342
  print(f"[{vname}] expected {string_type(expected, **kws)}")
@@ -1306,47 +1349,50 @@ class ExtTestCase(unittest.TestCase):
1306
1349
  import torch
1307
1350
 
1308
1351
  ep = torch.export.load(ep)
1309
- ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs
1352
+
1310
1353
  ep_model = ep.module() # type: ignore[union-attr]
1311
- ep_expected = (
1312
- ep_model(*copy.deepcopy(ep_inputs))
1313
- if isinstance(ep_inputs, tuple)
1314
- else ep_model(**copy.deepcopy(ep_inputs))
1315
- )
1316
- if verbose:
1317
- print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
1318
- ep_diff = max_diff(expected, ep_expected, hist=[0.1, 0.01])
1354
+ for expe, inp, got in zip(expected, inputs, gots):
1355
+ ep_inputs = copy.deepcopy(inp) if copy_inputs else inp
1356
+ ep_expected = (
1357
+ ep_model(*copy.deepcopy(ep_inputs))
1358
+ if isinstance(ep_inputs, tuple)
1359
+ else ep_model(**copy.deepcopy(ep_inputs))
1360
+ )
1361
+ if verbose:
1362
+ print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
1363
+ ep_diff = max_diff(expe, ep_expected, hist=[0.1, 0.01])
1364
+ if verbose:
1365
+ print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
1366
+ assert (
1367
+ isinstance(ep_diff["abs"], float)
1368
+ and isinstance(ep_diff["rel"], float)
1369
+ and not numpy.isnan(ep_diff["abs"])
1370
+ and ep_diff["abs"] <= atol
1371
+ and not numpy.isnan(ep_diff["rel"])
1372
+ and ep_diff["rel"] <= rtol
1373
+ ), (
1374
+ f"discrepancies in {test_name!r} between the exported program "
1375
+ f"and the exported model diff={string_diff(ep_diff)}"
1376
+ )
1377
+ ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
1378
+ if verbose:
1379
+ print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
1380
+
1381
+ for expe, got in zip(expected, gots):
1382
+ diff = max_diff(expe, got, flatten=True, hist=[0.1, 0.01])
1319
1383
  if verbose:
1320
- print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
1384
+ print(f"[{vname}] diff {string_diff(diff)}")
1321
1385
  assert (
1322
- isinstance(ep_diff["abs"], float)
1323
- and isinstance(ep_diff["rel"], float)
1324
- and not numpy.isnan(ep_diff["abs"])
1325
- and ep_diff["abs"] <= atol
1326
- and not numpy.isnan(ep_diff["rel"])
1327
- and ep_diff["rel"] <= rtol
1386
+ isinstance(diff["abs"], float)
1387
+ and isinstance(diff["rel"], float)
1388
+ and not numpy.isnan(diff["abs"])
1389
+ and diff["abs"] <= atol
1390
+ and not numpy.isnan(diff["rel"])
1391
+ and diff["rel"] <= rtol
1328
1392
  ), (
1329
- f"discrepancies in {test_name!r} between the exported program "
1330
- f"and the exported model diff={string_diff(ep_diff)}"
1393
+ f"discrepancies in {test_name!r} between the model and "
1394
+ f"the onnx model diff={string_diff(diff)}"
1331
1395
  )
1332
- ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
1333
- if verbose:
1334
- print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
1335
-
1336
- diff = max_diff(expected, got, flatten=True, hist=[0.1, 0.01])
1337
- if verbose:
1338
- print(f"[{vname}] diff {string_diff(diff)}")
1339
- assert (
1340
- isinstance(diff["abs"], float)
1341
- and isinstance(diff["rel"], float)
1342
- and not numpy.isnan(diff["abs"])
1343
- and diff["abs"] <= atol
1344
- and not numpy.isnan(diff["rel"])
1345
- and diff["rel"] <= rtol
1346
- ), (
1347
- f"discrepancies in {test_name!r} between the model and "
1348
- f"the onnx model diff={string_diff(diff)}"
1349
- )
1350
1396
 
1351
1397
  def _debug(self):
1352
1398
  "Tells if DEBUG=1 is set up."