onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/onnx_plug.py (new file)
@@ -0,0 +1,531 @@
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import onnx
import torch
from ..helpers import max_diff, string_type
from ..helpers.torch_helper import (
    torch_dtype_to_onnx_dtype,
    onnx_dtype_to_torch_dtype,
    int_device_to_torch_device,
)
from ..reference import OnnxruntimeEvaluator

TUPLE_TENSORS = Tuple[torch.Tensor, ...]


def is_exporting() -> bool:
    """
    Returns :func:`torch.compiler.is_exporting` or
    :func:`torch.compiler.is_compiling`.
    Changes ``_TEST_EXPORT`` to make it trigger.
    """
    return torch.compiler.is_exporting() or torch.compiler.is_compiling()
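
# is_exporting() is used by EagerDirectReplacementWithOnnx.__call__ below to dispatch
# to the registered custom op while exporting and to eager_fn in regular eager mode.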


@dataclass
class VerifyResult:
    """
    Outputs of method :meth:`verify
    <onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx.verify>`.
    """

    eager_outputs: TUPLE_TENSORS
    onnx_outputs: TUPLE_TENSORS
    diffs: Tuple[Dict[str, float], ...]


class EagerDirectReplacementWithOnnx:
    """
    Replaces a piece of code by another one written in ONNX
    at export time. The class inserts a custom operator
    and links it to ``eager_fn``.

    :param eager_fn: the code being replaced; it must be given so that
        the torch.fx.Graph produced by the exporter can still be executed
    :param shape_fn: a function producing dummy outputs with the shapes
        the exporter can use for the next operators in the graph
    :param function_proto: an instance of ``onnx.FunctionProto``, or a dictionary
        of them keyed by version; the domain must be ``onnx_plug``
    :param n_inputs: number of inputs of the function; if not given,
        the class infers it from the eager_fn signature,
        only tensors must be counted
    :param n_outputs: same for the number of outputs,
        only tensors must be counted
    :param name: the name of the custom op, the function name if not specified
    :param kwargs: constant parameters with their default values
    :param version_selector: selects the version based on the arguments,
        see below for an example; this lets the user define different
        onnx versions depending on the inputs
    :param default_opset: opset to use by default
    :param verbose: verbosity level

    Here is an example:

    .. runpython::
        :showcode:

        import onnx.helper as oh
        import torch
        from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
        from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
        from onnx_diagnostic.export.api import to_onnx
        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


        def demo_customsub(x, y):
            return x - y


        def demo_customsub_shape(x, y):
            return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)


        def make_function_proto():
            return oh.make_function(
                "onnx_plug",
                "demo_customsub",
                ["x", "y"],
                ["z"],
                [oh.make_node("Sub", ["x", "y"], ["z"])],
                opset_imports=[oh.make_opsetid("", 22)],
            )


        class Model(torch.nn.Module):
            def forward(self, x):
                y = x.sum(axis=1, keepdim=True)
                d = torch.ops.onnx_plug.demo_customsub(x, y)
                return torch.abs(d)


        replacements = [
            EagerDirectReplacementWithOnnx(
                demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
            )
        ]

        x = torch.randn((3, 4), dtype=torch.float32)
        model = Model()
        ds = ({0: "d1", 1: "d2"},)

        # The exported program shows a custom op.
        ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
        print(ep)

        # Since the exporter knows how to replace this custom op,
        # let's export.

        onx = to_onnx(
            model,
            (x,),
            dynamic_shapes=ds,
            exporter="custom",
            onnx_plugs=replacements,
            target_opset=22,
            inline=False,
        ).model_proto

        print(pretty_onnx(onx))

        # And with :func:`torch.onnx.export`:

        onx = to_onnx(
            model,
            (x,),
            dynamic_shapes=ds,
            exporter="onnx-dynamo",
            onnx_plugs=replacements,
            target_opset=22,
            inline=False,
        ).model_proto

        print(pretty_onnx(onx))

    The next snippet shows how to define multiple versions depending on the device,
    the type or the targeted onnx opset.

    .. code-block:: python

        def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
            first_tensor = next(a for a in args if a is not None)
            dtype = first_tensor.dtype
            itype = torch_dtype_to_onnx_dtype(dtype)
            if dtype == torch.float32:
                if opset >= 24:
                    return "LOOPA24", itype
                return "LOOPMHA", itype
            if dtype == torch.float16:
                if first_tensor.is_cuda:
                    return "PACKED", itype
                return "LOOPMHA", itype
            raise AssertionError(
                f"Unable to handle type {dtype} (itype={itype}) "
                f"on device {first_tensor.device} with opset={opset}"
            )

        qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
            qwen_sdpa_attention,
            lambda qs, *args, **kwargs: torch.empty(
                (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
                dtype=qs.dtype,
                device=qs.device,
            ),
            {
                ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
                    PackedAttention.to_function_proto()
                ),
                ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
                ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
                    onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
                ),
                ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
                    LoopMHAAttention.to_function_proto()
                ),
                ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
                    onnx.TensorProto.FLOAT16,
                    _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
                ),
            },
            n_inputs=4,
            n_outputs=1,
            kwargs=dict(scaling=0.11180339887498948, num_heads=16),
            name="qwen_sdpa_attention_versatile",
            version_selector=qwen_version_selector,
        )
    """

    def __init__(
        self,
        eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
        shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
        function_proto: Union[onnx.FunctionProto, Dict[Any, onnx.FunctionProto]],
        n_inputs: Optional[int] = None,
        n_outputs: Optional[int] = None,
        name: Optional[str] = None,
        kwargs: Optional[Dict[str, Union[int, float]]] = None,
        verbose: int = 0,
        version_selector: Optional[Callable[..., Tuple[Any, ...]]] = None,
        default_opset: int = 22,
    ):
        assert isinstance(function_proto, onnx.FunctionProto) or (
            isinstance(function_proto, dict)
            and all(isinstance(v, onnx.FunctionProto) for v in function_proto.values())
        ), f"Unexpected type {type(function_proto)} for function_proto"
        assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}"
        assert isinstance(n_outputs, int), f"not implemented yet when n_outputs={n_outputs}"
        self.eager_fn = eager_fn
        self.shape_fn = shape_fn
        self._function_proto = (
            function_proto if isinstance(function_proto, onnx.FunctionProto) else None
        )
        self._function_proto_versioned = (
            function_proto if isinstance(function_proto, dict) else {}
        )
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.name = name or (
            eager_fn.__name__
            if "<" not in eager_fn.__name__
            else eager_fn.__qualname__.replace("<locals>", "L")
            .replace("<lambda>", "l")
            .replace(".", "_")
        )
        self.kwargs = kwargs or {}
        assert all(isinstance(v, (int, float)) for v in self.kwargs.values()), (
            f"Only int or floats are allowed for kwargs={kwargs}, one of them "
            f"does not respect that constraint."
        )
        sig = inspect.signature(self.eager_fn)
        params = list(sig.parameters)
        self.args_name = [p for p in params if p not in self.kwargs]
        self.kwargs_name = [p for p in params if p in self.kwargs]
        self.verbose = verbose
        self.custom_op = self._register()
        self.version_selector = version_selector
        self.default_opset = default_opset
        self._check_protos(params)

    def _check_protos(self, params):
        assert (
            len(params) >= self.n_inputs
        ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={self.n_inputs}"

        # one proto
        assert self._function_proto is None or self.n_inputs == len(
            self._function_proto.input
        ), (
            f"Input mismatch n_inputs={self.n_inputs} but "
            f"function_proto.input={self._function_proto.input}"
        )
        assert self._function_proto is None or self.n_outputs == len(
            self._function_proto.output
        ), (
            f"Output mismatch n_outputs={self.n_outputs} but "
            f"function_proto.output={self._function_proto.output}"
        )
        assert self._function_proto is None or (
            self._function_proto.domain == self.domain
        ), f"Function domain must be {self.domain!r} but it is {self._function_proto.domain!r}"

        # multiple protos
        assert all(
            self.n_inputs == len(v.input) for v in self._function_proto_versioned.values()
        ), f"Input mismatch n_inputs={self.n_inputs} but one version is wrong"
        assert all(
            self.n_outputs == len(v.output) for v in self._function_proto_versioned.values()
        ), f"Output mismatch n_outputs={self.n_outputs} but one version is wrong"
        assert all(
            v.domain == self.domain for v in self._function_proto_versioned.values()
        ), f"Function domain must be {self.domain!r} but it is different in one version"
        assert (
            not self._function_proto_versioned or self.version_selector
        ), "version_selector is needed when multiple protos are given."

    def get_function_proto(self, opset: int, *args) -> onnx.FunctionProto:
        """Returns the correct version based on the inputs."""
        if self._function_proto:
            return self._function_proto
        assert isinstance(
            opset, int
        ), f"The first argument must be an integer for the onnx opset but it is {type(opset)}"
        assert any(
            a is not None for a in args
        ), f"Unexpected args={string_type(args, with_shape=True)}"
        try:
            key = self.version_selector(opset, *args)  # type: ignore[misc]
        except (ValueError, AttributeError) as e:
            raise AssertionError(
                f"Unable to select a version, fails to get a key, available="
                f"{set(self._function_proto_versioned)}, "
                f"args={string_type(args, with_shape=True)}"
            ) from e
        assert key in self._function_proto_versioned, (
            f"Unable to select a version, key={key}, available="
            f"{set(self._function_proto_versioned)}, args={string_type(args, with_shape=True)}"
        )
        return self._function_proto_versioned[key]

    @property
    def domain(self) -> str:
        "Returns the onnx domain."
        return "onnx_plug"

    @property
    def target_name(self) -> str:
        "Returns the target name (as seen in the exported program)."
        return f"{self.domain}::{self.name}"

    @property
    def torch_op(self) -> Callable:
        "Returns ``torch.ops.onnx_plug.<name>``."
        return getattr(getattr(torch.ops, self.domain), self.name).default

    def __call__(self, *args, **kwargs):
        """Calls ``eager_fn``, or the registered custom op when the model is being exported."""
        if is_exporting():
            return self.torch_op(*args)
        return self.eager_fn(*args, **kwargs)

    def _register(self):
        """Registers the custom op."""
        input_args = [f"Tensor {p}" for p in self.args_name]
        for p in self.kwargs_name:
            val = self.kwargs[p]
            if isinstance(val, int):
                input_args.append(f"int {p}={val}")
            elif isinstance(val, float):
                input_args.append(f"float {p}={val}")
            elif isinstance(val, str):
                input_args.append(f"str {p}={val}")
            else:
                raise NotImplementedError(
                    f"kwargs {p!r} has a default value of unsupported type {type(val)}"
                )

        inputs = ", ".join(input_args)
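        # For the demo_customsub example above (two tensor inputs, one output, no extra
        # kwargs), the schema string built below is "(Tensor x, Tensor y) -> Tensor".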
        schema = f"({inputs}) -> Tensor"
        if self.n_outputs > 1:
            schema += "[]"
        if self.verbose:
            print(
                f"[EagerDirectReplacementWithOnnx._register] "
                f"'torch.ops.{self.domain}.{self.name}'"
            )
            print(f"[EagerDirectReplacementWithOnnx._register] schema={schema}")
        custom_def = torch.library.CustomOpDef(self.domain, self.name, schema, self.eager_fn)
        custom_def.register_kernel(None)(self.eager_fn)
        custom_def._abstract_fn = self.shape_fn
        return custom_def

    def verify(
        self,
        *args,
        engine: Optional[Callable] = None,
        dump_onnx_model: Optional[str] = None,
        opset: int = 22,
        **kwargs,
    ) -> VerifyResult:
        """
        Verifies that the eager mode is equivalent to the onnx function given
        as a replacement. This function evaluates `eager_fn`, checks that the shapes
        are equivalent to the ones given by `shape_fn`, and finally evaluates the
        onnx translation if the previous checks did not fail.

        :param args: function inputs
        :param engine: by default an instance of
            :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
        :param dump_onnx_model: to dump the onnx model used to verify
            eager and onnx produce the same results
        :param opset: onnx opset to use
        :param kwargs: additional arguments passed to eager_fn
        :return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
        """
        expected = self.eager_fn(*args, **kwargs)
        shapes = self.shape_fn(*args, **kwargs)
        if isinstance(expected, torch.Tensor):
            expected = (expected,)
            assert isinstance(shapes, torch.Tensor), (
                f"eager_fn={self.eager_fn} returns a Tensor but shape_fn={self.shape_fn} "
                f"returns a {type(shapes)}"
            )
            shapes = (shapes,)
        assert isinstance(expected, tuple) and isinstance(shapes, tuple), (
            f"eager_fn={self.eager_fn} returns a {type(expected)} "
            f"and shape_fn={self.shape_fn} returns a {type(shapes)}"
        )
        assert len(expected) == len(shapes), (
            f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn} "
            f"do not return the same number of tensors."
        )
        for i, (e, s) in enumerate(zip(expected, shapes)):
            assert e.dtype == s.dtype, (
                f"Type mismatch {e.dtype} != {s.dtype} for output {i}, "
                f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
            )
            assert e.shape == s.shape, (
                f"Shape mismatch {e.shape} != {s.shape} for output {i}, "
                f"eager_fn={self.eager_fn} and shape_fn={self.shape_fn}"
            )

        # Now the ONNX execution.
        assert engine is None, f"Not implemented yet with engine={engine!r}"
        ags, kws = self._make_args_kwargs(*args, **kwargs)
        sess = OnnxruntimeEvaluator(
            self.get_function_proto(opset, *args),
            whole=True,
            dump_onnx_model=dump_onnx_model,
            function_kwargs=kws,
        )
        feeds = dict(zip(sess.input_names, ags))
        got = sess.run(None, feeds)
        diffs = tuple(max_diff(e, g, hist=[0.1, 0.01]) for e, g in zip(expected, got))
        return VerifyResult(eager_outputs=expected, onnx_outputs=tuple(got), diffs=diffs)  # type: ignore[arg-type]

    def _make_args_kwargs(self, *args, **kwargs):
        ags = args[: len(self.args_name)]
        kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
        kws.update(kwargs)
        return ags, kws

    def custom_converter(
        self,
    ) -> Callable:
        """
        Returns a function which converts a custom op found in the fx graph
        into ONNX, following the API of the custom exporter.
        The converter adds a custom op and registers the local function.
        """

        def converter(
            g: Any,  # GraphBuilder
            sts: Optional[Dict[str, Any]],
            outputs: List[str],
            *args,
            **kwargs,
        ) -> Any:
            has_devices = [a for a in args if isinstance(a, str) and g.has_device(a)]
            assert (
                has_devices
            ), f"Missing device for any of the inputs {args}{g.get_debug_msg()}"
            arg_device = has_devices[0]
            fake_tensor = torch.empty(
                tuple([(_ if isinstance(_, int) else 2) for _ in g.get_shape(args[0])]),
                dtype=onnx_dtype_to_torch_dtype(g.get_type(args[0])),
                device=int_device_to_torch_device(g.get_device(arg_device)),
            )
            function_proto = self.get_function_proto(g.main_opset, fake_tensor)
            if not g.has_local_function(function_proto.name, domain=function_proto.domain):
                g.add_function(function_proto)
            ags, kws = self._make_args_kwargs(*args, **kwargs)
            res = g.make_node(
                function_proto.name,
                ags,
                outputs,
                domain=function_proto.domain,
                name=self.target_name,
                **kws,
            )
            if not sts:
                new_shapes = self.shape_fn(*args)
                if not isinstance(new_shapes, tuple):
                    new_shapes = (new_shapes,)
                for sh, o in zip(new_shapes, outputs):
                    g.set_type(o, torch_dtype_to_onnx_dtype(sh.dtype))
                    g.set_shape(o, sh.shape)
            return res

        return converter

    def onnx_dynamo_converter(self) -> Callable:
        """
        Returns a function which converts a custom op found in the fx graph
        into ONNX, following the API of :func:`torch.onnx.export`.
        """
        import onnxscript

        onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)

        def get_proto(*args):
            function_proto = self.get_function_proto(self.default_opset, *args)
            schema = onnx_plug_op[function_proto.name]
            if schema is None:
                all_types = [
                    "tensor(float)",
                    "tensor(float16)",
                    "tensor(bfloat16)",
                    "tensor(double)",
                    "tensor(int64)",
                    "tensor(int32)",
                ]
                type_constraints = []
                for i in range(self.n_inputs):
                    type_constraints.append((f"T{i}", all_types, ""))
                for i in range(self.n_outputs):
                    type_constraints.append((f"U{i}", all_types, ""))
                schema = onnx.defs.OpSchema(
                    function_proto.name,
                    function_proto.domain,
                    1,
                    inputs=[
                        onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
                        for i in range(self.n_inputs)
                    ],
                    outputs=[
                        onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
                        for i in range(self.n_outputs)
                    ],
                    type_constraints=type_constraints,
                )
                onnx.defs.register_schema(schema)
            op = onnxscript.values.Op(onnx_plug_op, function_proto.name, schema)
            return op

        def converter(*cargs, **ckwargs):
            ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
            op = get_proto(*cargs)
            return op(*ags, n_outputs=self.n_outputs, **kws)

        return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
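
Usage sketch (illustrative, not part of the packaged file): the snippet below shows
how EagerDirectReplacementWithOnnx.verify might be used, reusing the demo_customsub,
demo_customsub_shape and make_function_proto helpers defined in the class docstring
above. The constructor, the verify signature and the VerifyResult fields come from
this file; the rest is an assumption for the sake of the example.

    import torch
    from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx

    # Assumes demo_customsub, demo_customsub_shape and make_function_proto are
    # defined as in the class docstring above.
    plug = EagerDirectReplacementWithOnnx(
        demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
    )
    x = torch.randn((3, 4), dtype=torch.float32)
    y = x.sum(axis=1, keepdim=True)

    # verify() runs eager_fn, checks its outputs against shape_fn, then executes the
    # FunctionProto with OnnxruntimeEvaluator and returns per-output differences.
    result = plug.verify(x, y, opset=22)
    print(result.diffs)          # one max_diff dictionary per output
    print(result.onnx_outputs)   # outputs produced by the ONNX function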