onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.4__py3-none-any.whl

@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
  Functions, classes to dig into a model when this one is right, slow, wrong...
  """
 
- __version__ = "0.8.3"
+ __version__ = "0.8.4"
  __author__ = "Xavier Dupré"
@@ -1286,7 +1286,13 @@ def get_parser_sbs() -> ArgumentParser:
          "--first",
          action=BooleanOptionalAction,
          default=False,
-         help="First runs the whole model.",
+         help="First runs the whole model (default is False).",
+     )
+     parser.add_argument(
+         "--sbs",
+         action=BooleanOptionalAction,
+         default=True,
+         help="Runs the side-by-side comparison (default is True).",
      )
      parser.add_argument(
          "-2",
@@ -1342,6 +1348,20 @@ def get_parser_sbs() -> ArgumentParser:
          default="replay",
          help="If the replay is triggered, this defines the folder where everything is dumped.",
      )
+     parser.add_argument(
+         "-p",
+         "--replay-prefix-model",
+         action=BooleanOptionalAction,
+         default=False,
+         help=textwrap.dedent(
+             """
+             There are two ways to recompute an intermediate output. The first one is to
+             produce the minimal model between torch and onnx.
+             The second one is to dump onnx models from the inputs
+             to the considered intermediate results. This option enables the second one.
+             """
+         ),
+     )
 
      return parser
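
The two new flags rely on `argparse.BooleanOptionalAction`. A minimal standalone sketch of how they behave (the real parser is the one built by `get_parser_sbs` above, which carries many more options):

```python
from argparse import ArgumentParser, BooleanOptionalAction

# Standalone sketch of the two flags added above.
parser = ArgumentParser("sbs")
parser.add_argument("--sbs", action=BooleanOptionalAction, default=True)
parser.add_argument("-p", "--replay-prefix-model", action=BooleanOptionalAction, default=False)

args = parser.parse_args(["--no-sbs", "--replay-prefix-model"])
assert args.sbs is False                 # BooleanOptionalAction also creates --no-sbs
assert args.replay_prefix_model is True  # dest derived from the long option name
```

With `--no-sbs`, `_cmd_sbs` (patched below) stops right after the preliminary run instead of loading the onnx model and starting the side-by-side.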
@@ -1417,6 +1437,10 @@ def _cmd_sbs(argv: List[Any]):
          print("-- done")
      del sess
 
+     if not args.sbs:
+         print("-- done")
+         return
+
      print(f"-- load onnx {args.onnx!r}")
      begin = time.perf_counter()
      onx = onnx.load(args.onnx)
@@ -1431,6 +1455,7 @@ def _cmd_sbs(argv: List[Any]):
              set(args.replay_op_types.split(",")) if args.replay_op_types else None
          ),
          dump_folder=args.replay_folder,
+         dump_prefix_model=args.replay_prefix_model,
      )
 
      print("-- starts side-by-side")
@@ -3,6 +3,52 @@ import torch
  from .onnx_plug import EagerDirectReplacementWithOnnx
 
 
+ def get_main_dispatcher(
+     use_control_flow_dispatcher: bool = False,
+     onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+ ) -> Any:  # Dispatcher
+     """Creates a custom dispatcher for the custom exporter."""
+     from experimental_experiment.torch_interpreter import Dispatcher
+
+     if use_control_flow_dispatcher:
+         from .control_flow_onnx import create_global_dispatcher
+
+         control_flow_dispatcher = create_global_dispatcher()
+     else:
+         control_flow_dispatcher = None
+
+     class MainDispatcher(Dispatcher):
+         def __init__(self, previous_dispatcher=None):
+             super().__init__({})
+             self.previous_dispatcher = previous_dispatcher
+
+         @property
+         def supported(self):
+             if self.previous_dispatcher:
+                 return set(self.registered_functions) | self.previous_dispatcher.supported
+             return set(self.registered_functions)
+
+         def find_function(self, name: Any):
+             if self.previous_dispatcher:
+                 find = self.previous_dispatcher.find_function(name)
+                 if find:
+                     return find
+             return Dispatcher.find_function(self, name)
+
+         def find_method(self, name: Any):
+             if self.previous_dispatcher:
+                 find = self.previous_dispatcher.find_method(name)
+                 if find:
+                     return find
+             return Dispatcher.find_method(self, name)
+
+     main_dispatcher = MainDispatcher(control_flow_dispatcher)
+     if onnx_plugs:
+         for plug in onnx_plugs:
+             main_dispatcher.registered_functions[plug.target_name] = plug.custom_converter()
+     return main_dispatcher
+
+
  def to_onnx(
      mod: Union["torch.nn.Module", "torch.fx.GraphModule"],  # noqa: F821
      args: Optional[Sequence["torch.Tensor"]] = None,  # noqa: F821
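
A hedged usage sketch of the new helper, mirroring how the refactored `to_onnx` below calls it. It requires the `experimental_experiment` dependency; `"some::target"` is a hypothetical target name:

```python
# Build the dispatcher the way to_onnx now does (see the refactor below).
dispatcher = get_main_dispatcher(
    use_control_flow_dispatcher=True,  # chains the control-flow dispatcher first
    onnx_plugs=None,                   # or a list of EagerDirectReplacementWithOnnx
)
# Lookups consult the chained (previous) dispatcher before the local registry:
fn = dispatcher.find_function("some::target")
print(len(dispatcher.supported))  # union of both registries
```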
@@ -82,51 +128,11 @@ def to_onnx(
      options = exporter_kwargs.pop("options", None)
      if options is None:
          options = OptimizationOptions(patterns="default+onnxruntime")
-     if onnx_plugs or use_control_flow_dispatcher:
-         from experimental_experiment.torch_interpreter import Dispatcher
-
-         if use_control_flow_dispatcher:
-             from .control_flow_onnx import create_global_dispatcher
-
-             control_flow_dispatcher = create_global_dispatcher()
-         else:
-             control_flow_dispatcher = None
-
-         class MainDispatcher(Dispatcher):
-             def __init__(self, previous_dispatcher=None):
-                 super().__init__({})
-                 self.previous_dispatcher = previous_dispatcher
-
-             @property
-             def supported(self):
-                 if self.previous_dispatcher:
-                     return (
-                         set(self.registered_functions) | self.previous_dispatcher.supported
-                     )
-                 return set(self.registered_functions)
-
-             def find_function(self, name: Any):
-                 if self.previous_dispatcher:
-                     find = self.previous_dispatcher.find_function(name)
-                     if find:
-                         return find
-                 return Dispatcher.find_function(self, name)
-
-             def find_method(self, name: Any):
-                 if self.previous_dispatcher:
-                     find = self.previous_dispatcher.find_method(name)
-                     if find:
-                         return find
-                 return Dispatcher.find_method(self, name)
-
-         main_dispatcher = MainDispatcher(control_flow_dispatcher)
-         if onnx_plugs:
-             for plug in onnx_plugs:
-                 main_dispatcher.registered_functions[plug.target_name] = (
-                     plug.custom_converter()
-                 )
-     else:
-         main_dispatcher = None
+     main_dispatcher = (
+         get_main_dispatcher(use_control_flow_dispatcher, onnx_plugs)
+         if onnx_plugs or use_control_flow_dispatcher
+         else None
+     )
 
      return _to_onnx(
          mod,
@@ -149,6 +155,7 @@ def to_onnx(
 
      if exporter in ("dynamo", "onnx-dynamo"):
          import os
+         from ..helpers import flatten_object
          import onnxscript.rewriter.ort_fusions as ort_fusions
 
          assert (
@@ -180,7 +187,20 @@ def to_onnx(
          import onnx_ir as ir
          import onnx_ir.passes.common as common_passes
 
-         irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs]
+         opset = (
+             18
+             if target_opset is None
+             else (target_opset if isinstance(target_opset, int) else target_opset[""])
+         )
+
+         irfunctions = [
+             ir.from_proto(
+                 plug.get_function_proto(
+                     opset, *flatten_object((args, kwargs), drop_keys=True)
+                 )
+             )
+             for plug in onnx_plugs
+         ]
          for func in irfunctions:
              epo.model.functions[func.identifier()] = func
          if inline:
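
The opset resolution above accepts `None`, a plain int, or a dict keyed by domain (the empty string being the default onnx domain). A small worked sketch of the same expression:

```python
# Mirrors the opset expression in the hunk above.
def resolve_opset(target_opset):
    return (
        18
        if target_opset is None
        else (target_opset if isinstance(target_opset, int) else target_opset[""])
    )

assert resolve_opset(None) == 18                          # fallback
assert resolve_opset(22) == 22                            # plain int
assert resolve_opset({"": 24, "com.microsoft": 1}) == 24  # dict keyed by domain
```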
@@ -92,10 +92,11 @@ def simple_loop_for(
 
      from torch._higher_order_ops.utils import setup_compilation_env
 
-     with setup_compilation_env() as backend:
-         return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
-             n_iter, body_fn, operands
-         )
+     with setup_compilation_env() as _backend:
+         return _loop_for_op_wrapper(n_iter, body_fn, *operands)
+         # return torch.compile(_loop_for_op_wrapper, backend=_backend, fullgraph=True)(
+         #     n_iter, body_fn, operands
+         # )
 
 
  def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands):
@@ -127,9 +128,13 @@ def loop_for_op_dense(n_iter, body_fn, operands):
      ), f"Dense implementation operands must be a list of tensors and ints {operands}"
      mode = _get_current_dispatch_mode()
      assert mode is None, "Mode should never be enabled for CPU/CUDA key"
-     return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)
+     return _loop_for_onnx_fn(body_fn, n_iter, None, operands)
 
 
  @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
  def inner(mode, n_iter, body_fn, operands):
      return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands)
+
+
+ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
+ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
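
The two `fallthrough` registrations make dispatch skip the Autograd keys, so tensors that require grad reach the dense CPU/CUDA implementation instead of raising a missing-kernel error. A hedged sketch of the same mechanism on a regular custom op (the namespace and op are hypothetical):

```python
import torch

# Hypothetical namespace/op, only to illustrate fallthrough semantics.
lib = torch.library.Library("sketch_ns", "DEF")
lib.define("twice(Tensor x) -> Tensor")
lib.impl("twice", lambda x: x * 2, "CPU")
# Let the AutogradCPU key fall through: no autograd formula is registered,
# dispatch continues down to the CPU kernel.
lib.impl("twice", torch.library.fallthrough_kernel, "AutogradCPU")

y = torch.ops.sketch_ns.twice(torch.ones(2, requires_grad=True))
print(y)  # tensor([2., 2.]); the Autograd key was skipped, so y has no grad_fn
```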
@@ -3,8 +3,12 @@ from dataclasses import dataclass
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  import onnx
  import torch
- from ..helpers import max_diff
- from ..helpers.torch_helper import torch_dtype_to_onnx_dtype
+ from ..helpers import max_diff, string_type
+ from ..helpers.torch_helper import (
+     torch_dtype_to_onnx_dtype,
+     onnx_dtype_to_torch_dtype,
+     int_device_to_torch_device,
+ )
  from ..reference import OnnxruntimeEvaluator
 
  TUPLE_TENSORS = Tuple[torch.Tensor, ...]
@@ -50,6 +54,10 @@ class EagerDirectReplacementWithOnnx:
          only tensors must be counted
      :param name: the name of the custom op, the function name if not specified
      :param kwargs: constant parameters with their default values
+     :param version_selector: selects the version based on the arguments,
+         see below for an example; this lets the user define different
+         onnx versions depending on the inputs
+     :param default_opset: opset to use by default
      :param verbose: verbose level
 
      Here is an example:
@@ -133,27 +141,87 @@ class EagerDirectReplacementWithOnnx:
          ).model_proto
 
          print(pretty_onnx(onx))
+
+     This shows how to define multiple versions depending on the device,
+     the type, or the targeted onnx opset.
+
+     .. code-block:: python
+
+         def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, int]:
+             first_tensor = next(a for a in args if a is not None)
+             dtype = first_tensor.dtype
+             itype = torch_dtype_to_onnx_dtype(dtype)
+             if dtype == torch.float32:
+                 if opset >= 24:
+                     return "LOOPA24", itype
+                 return "LOOPMHA", itype
+             if dtype == torch.float16:
+                 if first_tensor.is_cuda:
+                     return "PACKED", itype
+                 return "LOOPMHA", itype
+             raise AssertionError(
+                 f"Unable to handle type {dtype} (itype={itype}) "
+                 f"on device {first_tensor.device} with opset={opset}"
+             )
+
+         qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
+             qwen_sdpa_attention,
+             lambda qs, *args, **kwargs: torch.empty(
+                 (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
+                 dtype=qs.dtype,
+                 device=qs.device,
+             ),
+             {
+                 ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
+                     PackedAttention.to_function_proto()
+                 ),
+                 ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
+                 ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                     onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
+                 ),
+                 ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
+                     LoopMHAAttention.to_function_proto()
+                 ),
+                 ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                     onnx.TensorProto.FLOAT16,
+                     _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
+                 ),
+             },
+             n_inputs=4,
+             n_outputs=1,
+             kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+             name="qwen_sdpa_attention_versatile",
+             version_selector=qwen_version_selector,
+         )
  """
 
      def __init__(
          self,
          eager_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
          shape_fn: Callable[[TUPLE_TENSORS], TUPLE_TENSORS],
-         function_proto: onnx.FunctionProto,
+         function_proto: Union[onnx.FunctionProto, Dict[Any, onnx.FunctionProto]],
          n_inputs: Optional[int] = None,
          n_outputs: Optional[int] = None,
          name: Optional[str] = None,
          kwargs: Optional[Dict[str, Union[int, float]]] = None,
          verbose: int = 0,
+         version_selector: Optional[Callable[..., Tuple[Any, ...]]] = None,
+         default_opset: int = 22,
      ):
-         assert isinstance(
-             function_proto, onnx.FunctionProto
+         assert isinstance(function_proto, onnx.FunctionProto) or (
+             isinstance(function_proto, dict)
+             and all(isinstance(v, onnx.FunctionProto) for v in function_proto.values())
          ), f"Unexpected type {type(function_proto)} for function_proto"
          assert isinstance(n_inputs, int), f"not implemented yet when n_inputs={n_inputs}"
-         assert isinstance(n_outputs, int), f"not implemented yet when n_inputs={n_outputs}"
+         assert isinstance(n_outputs, int), f"not implemented yet when n_outputs={n_outputs}"
          self.eager_fn = eager_fn
          self.shape_fn = shape_fn
-         self.function_proto = function_proto
+         self._function_proto = (
+             function_proto if isinstance(function_proto, onnx.FunctionProto) else None
+         )
+         self._function_proto_versioned = (
+             function_proto if isinstance(function_proto, dict) else {}
+         )
          self.n_inputs = n_inputs
          self.n_outputs = n_outputs
          self.name = name or (
@@ -170,24 +238,73 @@ class EagerDirectReplacementWithOnnx:
          )
          sig = inspect.signature(self.eager_fn)
          params = list(sig.parameters)
-         assert (
-             len(params) >= n_inputs
-         ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={n_inputs}"
-         assert n_inputs == len(function_proto.input), (
-             f"Input mismatch n_inputs={n_inputs} but "
-             f"function_proto.input={function_proto.input}"
-         )
-         assert n_outputs == len(function_proto.output), (
-             f"Output mismatch n_outputs={n_outputs} but "
-             f"function_proto.output={function_proto.output}"
-         )
-         assert (
-             function_proto.domain == self.domain
-         ), f"Function domain must be {self.domain!r} but it is {function_proto.domain!r}"
          self.args_name = [p for p in params if p not in self.kwargs]
          self.kwargs_name = [p for p in params if p in self.kwargs]
          self.verbose = verbose
          self.custom_op = self._register()
+         self.version_selector = version_selector
+         self.default_opset = default_opset
+         self._check_protos(params)
+
+     def _check_protos(self, params):
+         assert (
+             len(params) >= self.n_inputs
+         ), f"{self.eager_fn} accepts {params} as parameters < n_inputs={self.n_inputs}"
+
+         # one proto
+         assert self._function_proto is None or self.n_inputs == len(
+             self._function_proto.input
+         ), (
+             f"Input mismatch n_inputs={self.n_inputs} but "
+             f"function_proto.input={self._function_proto.input}"
+         )
+         assert self._function_proto is None or self.n_outputs == len(
+             self._function_proto.output
+         ), (
+             f"Output mismatch n_outputs={self.n_outputs} but "
+             f"function_proto.output={self._function_proto.output}"
+         )
+         assert self._function_proto is None or (
+             self._function_proto.domain == self.domain
+         ), f"Function domain must be {self.domain!r} but it is {self._function_proto.domain!r}"
+
+         # multiple protos
+         assert all(
+             self.n_inputs == len(v.input) for v in self._function_proto_versioned.values()
+         ), f"Input mismatch n_inputs={self.n_inputs} but one version is wrong"
+         assert all(
+             self.n_outputs == len(v.output) for v in self._function_proto_versioned.values()
+         ), f"Output mismatch n_outputs={self.n_outputs} but one version is wrong"
+         assert all(
+             v.domain == self.domain for v in self._function_proto_versioned.values()
+         ), f"Function domain must be {self.domain!r} but it is different in one version"
+         assert (
+             not self._function_proto_versioned or self.version_selector
+         ), "version_selector is needed when multiple protos are given."
+
+     def get_function_proto(self, opset: int, *args) -> onnx.FunctionProto:
+         """Returns the correct version based on the inputs."""
+         if self._function_proto:
+             return self._function_proto
+         assert isinstance(
+             opset, int
+         ), f"The first argument must be an integer for the onnx opset but it is {type(opset)}"
+         assert any(
+             a is not None for a in args
+         ), f"Unexpected args={string_type(args, with_shape=True)}"
+         try:
+             key = self.version_selector(opset, *args)  # type: ignore[misc]
+         except (ValueError, AttributeError) as e:
+             raise AssertionError(
+                 f"Unable to select a version, fails to get a key, available="
+                 f"{set(self._function_proto_versioned)}, "
+                 f"args={string_type(args, with_shape=True)}"
+             ) from e
+         assert key in self._function_proto_versioned, (
+             f"Unable to select a version, key={key}, available="
+             f"{set(self._function_proto_versioned)}, args={string_type(args, with_shape=True)}"
+         )
+         return self._function_proto_versioned[key]
 
      @property
      def domain(self) -> str:
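
A hedged sketch of the versioned lookup: two `FunctionProto` versions keyed by a tag and the onnx dtype, plus a selector. `eager_fn`, `shape_fn`, `proto_f32` and `proto_f16` are assumed to exist and be consistent with the plug's domain:

```python
# Assumed available: eager_fn, shape_fn, and two valid onnx.FunctionProto
# objects proto_f32 / proto_f16 built elsewhere.
def selector(opset, *args):
    t = next(a for a in args if a is not None)
    return ("V1", torch_dtype_to_onnx_dtype(t.dtype))

plug = EagerDirectReplacementWithOnnx(
    eager_fn,
    shape_fn,
    {
        ("V1", onnx.TensorProto.FLOAT): proto_f32,
        ("V1", onnx.TensorProto.FLOAT16): proto_f16,
    },
    n_inputs=1,
    n_outputs=1,
    version_selector=selector,
)
proto = plug.get_function_proto(22, torch.ones(2, 2))  # float32 input -> proto_f32
```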
@@ -219,6 +336,8 @@ class EagerDirectReplacementWithOnnx:
              input_args.append(f"int {p}={val}")
          elif isinstance(val, float):
              input_args.append(f"float {p}={val}")
+         elif isinstance(val, str):
+             input_args.append(f"str {p}={val}")
          else:
              raise NotImplementedError(
                  f"kwargs {p!r} has a default value of unsupported type {type(val)}"
@@ -243,6 +362,7 @@ class EagerDirectReplacementWithOnnx:
          *args,
          engine: Optional[Callable] = None,
          dump_onnx_model: Optional[str] = None,
+         opset: int = 22,
          **kwargs,
      ) -> VerifyResult:
          """
@@ -257,6 +377,7 @@ class EagerDirectReplacementWithOnnx:
              :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
          :param dump_onnx_model: to dump the onnx model used to verify
              eager and onnx produce the same results
+         :param opset: onnx opset to use
          :param kwargs: additional arguments to the function
          :return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
          """
@@ -291,7 +412,7 @@ class EagerDirectReplacementWithOnnx:
          assert engine is None, f"Not implemented yet with engine={engine!r}"
          ags, kws = self._make_args_kwargs(*args, **kwargs)
          sess = OnnxruntimeEvaluator(
-             self.function_proto,
+             self.get_function_proto(opset, *args),
              whole=True,
              dump_onnx_model=dump_onnx_model,
              function_kwargs=kws,
@@ -324,16 +445,25 @@ class EagerDirectReplacementWithOnnx:
          *args,
          **kwargs,
      ) -> Any:
-         if not g.has_local_function(
-             self.function_proto.name, domain=self.function_proto.domain
-         ):
-             g.add_function(self.function_proto)
+         has_devices = [a for a in args if isinstance(a, str) and g.has_device(a)]
+         assert (
+             has_devices
+         ), f"Missing device for any of the inputs {args}{g.get_debug_msg()}"
+         arg_device = has_devices[0]
+         fake_tensor = torch.empty(
+             tuple([(_ if isinstance(_, int) else 2) for _ in g.get_shape(args[0])]),
+             dtype=onnx_dtype_to_torch_dtype(g.get_type(args[0])),
+             device=int_device_to_torch_device(g.get_device(arg_device)),
+         )
+         function_proto = self.get_function_proto(g.main_opset, fake_tensor)
+         if not g.has_local_function(function_proto.name, domain=function_proto.domain):
+             g.add_function(function_proto)
          ags, kws = self._make_args_kwargs(*args, **kwargs)
          res = g.make_node(
-             self.function_proto.name,
+             function_proto.name,
              ags,
              outputs,
-             domain=self.function_proto.domain,
+             domain=function_proto.domain,
              name=self.target_name,
              **kws,
          )
@@ -356,41 +486,46 @@ class EagerDirectReplacementWithOnnx:
          """
          import onnxscript
 
-         onnx_plug_op = onnxscript.values.Opset(domain=self.function_proto.domain, version=1)
-         schema = onnx_plug_op[self.function_proto.name]
-         if schema is None:
-             all_types = [
-                 "tensor(float)",
-                 "tensor(float16)",
-                 "tensor(bfloat16)",
-                 "tensor(double)",
-                 "tensor(int64)",
-                 "tensor(int32)",
-             ]
-             type_constraints = []
-             for i in range(self.n_inputs):
-                 type_constraints.append((f"T{i}", all_types, ""))
-             for i in range(self.n_outputs):
-                 type_constraints.append((f"U{i}", all_types, ""))
-             schema = onnx.defs.OpSchema(
-                 self.function_proto.name,
-                 self.function_proto.domain,
-                 1,
-                 inputs=[
-                     onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
-                     for i in range(self.n_inputs)
-                 ],
-                 outputs=[
-                     onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
-                     for i in range(self.n_outputs)
-                 ],
-                 type_constraints=type_constraints,
-             )
-             onnx.defs.register_schema(schema)
-         op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)
+         onnx_plug_op = onnxscript.values.Opset(domain=self.domain, version=1)
+
+         def get_proto(*args):
+             function_proto = self.get_function_proto(self.default_opset, *args)
+             schema = onnx_plug_op[function_proto.name]
+             if schema is None:
+                 all_types = [
+                     "tensor(float)",
+                     "tensor(float16)",
+                     "tensor(bfloat16)",
+                     "tensor(double)",
+                     "tensor(int64)",
+                     "tensor(int32)",
+                 ]
+                 type_constraints = []
+                 for i in range(self.n_inputs):
+                     type_constraints.append((f"T{i}", all_types, ""))
+                 for i in range(self.n_outputs):
+                     type_constraints.append((f"U{i}", all_types, ""))
+                 schema = onnx.defs.OpSchema(
+                     function_proto.name,
+                     function_proto.domain,
+                     1,
+                     inputs=[
+                         onnx.defs.OpSchema.FormalParameter(f"arg_{i}", f"T{i}")
+                         for i in range(self.n_inputs)
+                     ],
+                     outputs=[
+                         onnx.defs.OpSchema.FormalParameter(f"res_{i}", f"U{i}")
+                         for i in range(self.n_outputs)
+                     ],
+                     type_constraints=type_constraints,
+                 )
+                 onnx.defs.register_schema(schema)
+             op = onnxscript.values.Op(onnx_plug_op, function_proto.name, schema)
+             return op
 
          def converter(*cargs, **ckwargs):
              ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
+             op = get_proto(*cargs)
              return op(*ags, n_outputs=self.n_outputs, **kws)
 
          return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)