onnx-diagnostic 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,352 @@
+import contextlib
+from typing import Callable, List, Optional, Sequence, Tuple, Union
+import torch
+from torch._C import DispatchKey
+from torch._ops import HigherOrderOperator
+from torch._subclasses.fake_tensor import FakeTensorMode
+import torch.utils._pytree as pytree
+from torch._higher_order_ops.utils import (
+    check_input_alias_and_mutation_return_outputs,
+    reenter_make_fx,
+    unique_graph_id,
+    validate_subgraph_args_types,
+)
+from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
+from torch.utils._python_dispatch import _get_current_dispatch_mode
+
+
+class SimpleLoopForOp(HigherOrderOperator):
+    """Higher order op for :func:`simple_loop_for`."""
+
+    def __init__(self):
+        super().__init__("simple_loop_for")
+
+    def __call__(self, n_iter, body_fn, operands, concatenation_dims=None):
+        validate_subgraph_args_types(operands)
+        return super().__call__(n_iter, body_fn, operands, concatenation_dims)
+
+    def gen_schema(self, n_iter, body_fn, operands, concatenation_dims):
+        from torch._higher_order_ops.schema import HopSchemaGenerator
+        from torch._higher_order_ops.utils import materialize_as_graph
+
+        body_gm: torch.fx.GraphModule = materialize_as_graph(  # type: ignore[annotation-unchecked]
+            body_fn, (torch.tensor(0, dtype=torch.int64), *operands)
+        )
+        (
+            _,
+            _,
+            _,
+            body_mutated_inputs,
+            body_outputs,
+        ) = check_input_alias_and_mutation_return_outputs(body_gm)
+        mutated_inputs = body_mutated_inputs
+
+        schema_gen = HopSchemaGenerator(self)
+        schema_gen.add_arg("n_iter", n_iter)
+        schema_gen.add_arg("body_fn", body_gm)
+        for idx, arg in enumerate(operands):
+            schema_gen.add_arg(f"operand{idx}", arg, is_mutated=idx in mutated_inputs)
+
+        for out in body_outputs:
+            schema_gen.add_output(out)
+        assert concatenation_dims is None or len(concatenation_dims) == len(body_outputs), (
+            f"concatenation_dims={concatenation_dims} but its length should be equal to "
+            f"the number of outputs ({len(body_outputs)})"
+        )
+        schema_gen.add_schema_tree_spec(n_iter, body_fn, operands, concatenation_dims)
+        return schema_gen.gen_schema()
+
+
+simple_loop_for_op = SimpleLoopForOp()
+
+
+def _simple_loop_for_fn(
+    n_iter: torch.Tensor,
+    body_fn: Callable,
+    operands: Tuple[torch.Tensor, ...] = (),
+    concatenation_dims: Optional[Sequence[int]] = None,
+) -> Tuple[torch.Tensor, ...]:
+    """
+    Python implementation of the loop.
+
+    :param n_iter: number of iterations
+    :param body_fn: function implementing the body
+    :param concatenation_dims: dimensions used to concatenate the lists produced by the loop
+    :param operands: arguments to the loop body
+    :return: results
+    """
+    torch._check(
+        isinstance(n_iter, (int, torch.Tensor)),
+        lambda: f"Unexpected type {type(n_iter)} for n_iter",
+    )
+    torch._check(callable(body_fn), lambda: f"Unexpected type {type(body_fn)} for body_fn")
+    torch._check(
+        concatenation_dims is None or isinstance(concatenation_dims, (list, tuple)),
+        lambda: f"Unexpected type {type(concatenation_dims)} for concatenation_dims",
+    )
+    torch._check(
+        isinstance(operands, tuple), lambda: f"Unexpected type {type(operands)} for operands"
+    )
+    res: List[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = []
+    for i in torch.arange(
+        n_iter, dtype=torch.int64 if isinstance(n_iter, int) else n_iter.dtype
+    ):
+        r = body_fn(i, *operands)
+        if isinstance(r, tuple):
+            assert not res or len(r) == len(res[-1]), (
+                f"Unexpected number of results {len(r)} for function {body_fn}, "
+                f"expected {len(res[-1])}"
+            )
+            res.append(r)
+        else:
+            assert isinstance(r, torch.Tensor), (
+                f"Unexpected type {type(r)} for function {body_fn}, "
+                f"it must be a tuple or a Tensor."
+            )
+            assert not res or len(res[-1]) == 1, (
+                f"Unexpected number of results (a single tensor) for function {body_fn}, "
+                f"expected {len(res[-1])}"
+            )
+            res.append((r,))
+
+    if not res:
+        return torch.empty(tuple(), dtype=torch.float32, device=operands[0].device)
+
+    n_res = len(res[0])
+    return tuple(
+        torch.cat(
+            [r[i] for r in res],
+            dim=(
+                0
+                if concatenation_dims is None or i >= len(concatenation_dims)
+                else concatenation_dims[i]
+            ),
+        )
+        for i in range(n_res)
+    )
+
+
+# from torch._functorch.utils import exposed_in
+# @exposed_in("torch")
+def _simple_loop_for(
+    n_iter: Union[int, torch.Tensor],
+    body_fn: Callable,
+    operands: Tuple[torch.Tensor, ...] = (),
+    concatenation_dims: Optional[Sequence[int]] = None,
+) -> Tuple[torch.Tensor, ...]:
+    def _validate_input(n_iter, body_fn, operands, concatenation_dims):
+        assert isinstance(
+            n_iter, (int, torch.Tensor, torch.SymInt)
+        ), f"Expected n_iter to be an int, SymInt or tensor, but got {n_iter}."
+        assert (
+            not isinstance(n_iter, torch.Tensor) or n_iter.numel() == 1
+        ), f"Expected n_iter to be an int or a single-element tensor, but got {n_iter}."
+        assert callable(body_fn), "Expected body_fn to be callable."
+        assert isinstance(operands, (tuple, list)) and pytree.tree_all(
+            lambda t: isinstance(t, torch.Tensor), operands
+        ), (
+            "Expected operands to be a tuple of possibly nested dict/list/tuple that only "
+            f"consists of tensor leaves, but got {operands}."
+        )
+        assert concatenation_dims is None or (
+            isinstance(concatenation_dims, (list, tuple))
+            and all(isinstance(i, int) for i in concatenation_dims)
+        ), (
+            f"concatenation_dims should be None or a list of integers but it is "
+            f"{concatenation_dims}. Its length should be equal to the number of outputs."
+        )
+        assert torch._dynamo.is_dynamo_supported(), "simple_loop_for requires dynamo support."
+
+    if torch.compiler.is_dynamo_compiling():
+        return simple_loop_for_op(
+            n_iter, body_fn, (n_iter, *operands), concatenation_dims=concatenation_dims
+        )
+
+    if isinstance(n_iter, (bool, int, float)):
+        torch._check(
+            isinstance(n_iter, int),
+            lambda: f"n_iter must be an integer or a tensor not {type(n_iter)}",
+        )
+        return _simple_loop_for_fn(
+            n_iter, body_fn, operands, concatenation_dims=concatenation_dims
+        )
+
+    def _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims):
+        return simple_loop_for_op(n_iter, body_fn, operands, concatenation_dims)
+
+    _validate_input(n_iter, body_fn, operands, concatenation_dims)
+
+    # This requires torch>=2.10.
+    from torch._higher_order_ops.utils import setup_compilation_env
+
+    with setup_compilation_env() as _backend:
+        return _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims)
+    # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
+    #     n_iter, body_fn, operands, concatenation_dims)
+
+
+def trace_simple_loop_for(
+    proxy_mode, func_overload, n_iter, body_fn, operands, concatenation_dims
+):
+    """See function ``simple_loop_for``."""
+    assert isinstance(operands, (list, tuple)) and (
+        concatenation_dims is None
+        or (
+            isinstance(concatenation_dims, (list, tuple))
+            and all(isinstance(i, int) for i in concatenation_dims)
+        )
+    ), (
+        f"simple_loop_for operands must be a list or tuple of tensors and SymInts and "
+        f"concatenation_dims must be None or a list of integers, "
+        f"operands={[type(o) for o in operands]}, "
+        f"concatenation_dims={concatenation_dims}"
+    )
+
+    body_graph = reenter_make_fx(body_fn)(n_iter, *operands)
+
+    body_outs = []
+    for node in body_graph.graph.nodes:
+        if node.op == "output":
+            body_outs.extend(node.args)
+
+    # flat_body_outs = pytree.arg_tree_leaves(*body_outs)
+    _i, body_name = unique_graph_id(proxy_mode, prefix="body_graph")
+    proxy_mode.tracer.root.register_module(body_name, body_graph)
+    args = (n_iter, body_graph, operands, concatenation_dims)
+    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
+    out_proxy = proxy_mode.tracer.create_proxy("call_function", func_overload, proxy_args, {})
+    out = func_overload(n_iter, body_graph, operands, concatenation_dims)
+    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
+
+
+@simple_loop_for_op.py_impl(DispatchKey.CompositeExplicitAutograd)
+def loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None):
+    """Registered eager mode implementation."""
+    assert all(isinstance(o, torch.Tensor) for o in operands) and (
+        concatenation_dims is None
+        or (
+            isinstance(concatenation_dims, (list, tuple))
+            and all(isinstance(i, int) for i in concatenation_dims)
+        )
+    ), (
+        f"simple_loop_for operands must be a list or tuple of tensors and SymInts and "
+        f"concatenation_dims must be None or a list of integers, "
+        f"operands={[type(o) for o in operands]}, "
+        f"concatenation_dims={concatenation_dims}"
+    )
+    mode = _get_current_dispatch_mode()
+    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
+    return _simple_loop_for_fn(
+        n_iter, body_fn, operands, concatenation_dims=concatenation_dims
+    )
+
+
+@simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
+def inner(mode, n_iter, body_fn, operands, concatenation_dims=None):
+    """Registered tracing implementation."""
+    return trace_simple_loop_for(
+        mode, simple_loop_for_op, n_iter, body_fn, operands, concatenation_dims
+    )
+
+
+@simple_loop_for_op.py_impl(FakeTensorMode)
+def simple_loop_for_fake_tensor_mode(mode, n_iter, body_fn, operands, concatenation_dims=None):
+    """Registered FakeMode implementation."""
+    ignore_fresh_unbacked = contextlib.nullcontext()
+    if mode.shape_env:
+        ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()
+
+    with mode, ignore_fresh_unbacked:
+        flat_body_outs, true_body_spec = pytree.tree_flatten(body_fn(n_iter, *operands))
+
+    return pytree.tree_unflatten(flat_body_outs, true_body_spec)
+
+
+# Registration for autograd.
+simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
+simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
+
+
+def simple_loop_for(
+    n_iter: Union[int, torch.Tensor],
+    body_fn: Callable,
+    operands: Tuple[torch.Tensor, ...] = (),
+    concatenation_dims: Optional[Union[int, Sequence[int]]] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+    """
+    Implements a simple for loop. The body is defined by a function which takes the
+    iteration number stored in a tensor, and other tensors.
+    It returns one or several tensors in a tuple. The results of all iterations
+    are finally concatenated, by default along the first dimension.
+
+    :param n_iter: number of iterations
+    :param body_fn: function implementing the loop body
+    :param operands: body arguments
+    :param concatenation_dims: dimension or dimensions used to concatenate the output sequences
+    :return: concatenated outputs; a single tensor if the body returns one output, a tuple otherwise
+
+    An example with one output:
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for
+
+
+        class Model(torch.nn.Module):
+            def forward(self, n_iter, x):
+                def body(i, x):
+                    return (x[: i.item() + 1].unsqueeze(1),)
+
+                return simple_loop_for(n_iter, body, (x,))
+
+
+        model = Model()
+        n_iter = torch.tensor(4, dtype=torch.int64)
+        x = torch.arange(10, dtype=torch.float32)
+        ep = torch.export.export(
+            model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+        )
+        print(ep)
+
+    Another example with two outputs and a final concatenation on different axes.
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for
+
+
+        class Model(torch.nn.Module):
+            def forward(self, n_iter, x):
+                def body(i, x):
+                    return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0))
+
+                return simple_loop_for(n_iter, body, (x,), (0, 1))
+
+
+        model = Model()
+        n_iter = torch.tensor(4, dtype=torch.int64)
+        x = torch.arange(10, dtype=torch.float32)
+        ep = torch.export.export(
+            model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+        )
+        print(ep)
+    """
+    res = _simple_loop_for(
+        n_iter,
+        body_fn,
+        operands,
+        concatenation_dims=(
+            (concatenation_dims,)
+            if isinstance(concatenation_dims, int)
+            else concatenation_dims
+        ),
+    )
+    torch._check(
+        isinstance(res, tuple),
+        lambda: f"Output of the loop should be a tuple not {type(res)}.",
+    )
+    return res[0] if len(res) == 1 else res
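
For reference, the new ``simple_loop_for`` can also be exercised eagerly: when ``n_iter`` is a plain int, ``_simple_loop_for`` dispatches straight to the Python implementation ``_simple_loop_for_fn``. A minimal sketch (not part of the diff), assuming the module path ``onnx_diagnostic.export.cf_simple_loop_for`` shown in the docstring above:

    import torch
    from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for


    def body(i, x):
        # i is a 0-dim int64 tensor holding the iteration index.
        return (x.unsqueeze(0) * i,)


    x = torch.arange(6, dtype=torch.float32)
    # A python int for n_iter takes the eager path, no export or compilation involved.
    out = simple_loop_for(4, body, (x,))
    print(out.shape)  # torch.Size([4, 6]); iterations are stacked along dim 0 by default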
@@ -55,13 +55,13 @@ def is_exporting() -> bool:
     return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()
 
 
-def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
+def _loop_for_onnx_fn(n_iter, body_fn, concatenation_dims, args):
     """
     Python implementation of the loop.
 
     :param n_iter: number of iteration
     :param body_fn: function implementing the body
-    :param reduction_dim: dimension used to reduce the list produced by the loop
+    :param concatenation_dims: dimensions used to concatenate the lists produced by the loop
     :param args: arguments to the loop body
     :return: results
     """
@@ -95,7 +95,9 @@ def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
         torch.cat(
             [r[i] for r in res],
             dim=(
-                0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i]
+                0
+                if concatenation_dims is None or i >= len(concatenation_dims)
+                else concatenation_dims[i]
             ),
         )
         for i in range(n_res)
@@ -106,7 +108,7 @@ def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
 def make_custom_loop_for_onnx(
     n_iter: torch.Tensor,
     body_fn: Callable,
-    reduction_dim: Optional[Sequence[int]],
+    concatenation_dims: Optional[Sequence[int]],
     args: Sequence[torch.Tensor],
     body_gm: Optional[torch.fx.GraphModule] = None,
     body_mutated_inputs: Optional[List[Any]] = None,
@@ -120,7 +122,7 @@ make_custom_loop_for_onnx(
 
     :param n_iter: number of iterations defined by a tensor of no dimension
     :param body_fn: the loop body defined as a function
-    :param reduction_dim: dimension used to concatenated the results
+    :param concatenation_dims: dimensions used to concatenate the results
     :param args: list of tensors, input to the body
     :param body_gm: torch.fx.GraphModule equivalent to *body_gm*
     :param body_mutated_inputs: inputs to *body_gm*
@@ -133,7 +135,7 @@ make_custom_loop_for_onnx(
     assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
     assert body_outputs is not None, "body_outputs cannot be None"
     srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
-    sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
+    sred = "x".join(map(str, concatenation_dims)) if concatenation_dims else ""
     full_name = (
         body_fn.__qualname__.replace("<locals>", "L")
         .replace("<lambda>", "l")
@@ -169,14 +171,14 @@ make_custom_loop_for_onnx(
         custom_def,
         _make_onx,
         (
-            lambda g, sts, outputs, *args, bc=_make_onx, rd=reduction_dim, name=name: (
+            lambda g, sts, outputs, *args, bc=_make_onx, rd=concatenation_dims, name=name: (
                 convert_custom_loop_into_onnx(
                     g,
                     sts,
                     outputs,
                     *args,
                     body_callable=bc,
-                    reduction_dim=rd,
+                    concatenation_dims=rd,
                     name=name,
                 )
             )
@@ -196,7 +198,7 @@ def convert_custom_loop_into_onnx(
     outputs: List[str],
     *args: str,
     body_callable: Callable[..., onnx.ModelProto],
-    reduction_dim: Optional[Sequence[int]] = None,
+    concatenation_dims: Optional[Sequence[int]] = None,
     name: str = "loop_for_onnx",
 ) -> Union[str, List[str]]:
     """
@@ -207,7 +209,7 @@ def convert_custom_loop_into_onnx(
     :param outputs: output names
     :param args: input argument known at export time
     :param body: GraphProto, the loop body
-    :param reduction_dim: the dimension to follow when aggregating the
+    :param concatenation_dims: the dimensions to follow when aggregating the
         list of tensors after the loop ran
     :param name: to give the onnx nodes a name
     :return: output names
@@ -289,7 +291,11 @@ def convert_custom_loop_into_onnx(
             out,
             outputs=[o],
             name=name,
-            axis=0 if not reduction_dim or i >= len(reduction_dim) else reduction_dim[i],
+            axis=(
+                0
+                if not concatenation_dims or i >= len(concatenation_dims)
+                else concatenation_dims[i]
+            ),
         )
         for i, (out, o) in enumerate(zip(outloop, outputs))
     ]
@@ -337,7 +343,7 @@ def loop_for_onnx(
     n_iter: Union[torch.SymInt, torch.Tensor],
     body_fn: Callable[..., Tuple[torch.Tensor]],
     args: Sequence[torch.Tensor],
-    reduction_dim: Optional[Sequence[int]] = None,
+    concatenation_dims: Optional[Sequence[int]] = None,
 ) -> Tuple[torch.Tensor, ...]:
     """
     High operators used to easily export a loop in ONNX.
@@ -353,7 +359,7 @@ def loop_for_onnx(
         in a tensor with no dimension, all the others
         are not changed during the loop
     :param args: the available tensors at every loop
-    :param reduction_dim: the loop aggregated the results into list,
+    :param concatenation_dims: the loop aggregates the results into lists,
         one of each output, each of them is concatenated into one
         tensor along one dimension, by default, it is the first
         dimension, but it can be defined otherwise
@@ -449,7 +455,7 @@ def loop_for_onnx(
         )
         print(ep)
 
-    A last example with ``reduction_dim``:
+    A last example with ``concatenation_dims``:
 
     .. runpython::
        :showcode:
@@ -465,7 +471,7 @@ def loop_for_onnx(
                def body(i, x):
                    return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
 
-                two = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
+                two = loop_for_onnx(n_iter, body, (x,), concatenation_dims=[0, 1])
                return two[0] + two[1].T
 
 
@@ -516,7 +522,7 @@ def loop_for_onnx(
         name, _custom_ops = make_custom_loop_for_onnx(
             n_iter,
             body_fn,
-            reduction_dim,
+            concatenation_dims,
             args,
             body_gm=body_gm,
             body_mutated_inputs=body_mutated_inputs,
@@ -525,4 +531,4 @@ def loop_for_onnx(
         fct = getattr(torch.ops.onnx_higher_ops, name)
         return fct(n_iter, *args)
 
-    return _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args)
+    return _loop_for_onnx_fn(n_iter, body_fn, concatenation_dims, args)
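
The renamed ``concatenation_dims`` argument threaded through these hunks selects, for each output position ``i``, the axis along which the per-iteration results are concatenated: ``concatenation_dims[i]`` when it is provided and long enough, otherwise axis 0. A small self-contained sketch of that reduction (illustrative only, not code from the package):

    import torch


    def concat_loop_results(res, concatenation_dims=None):
        # res is a list of per-iteration tuples of tensors; one torch.cat per output position.
        n_res = len(res[0])
        return tuple(
            torch.cat(
                [r[i] for r in res],
                dim=(
                    0
                    if concatenation_dims is None or i >= len(concatenation_dims)
                    else concatenation_dims[i]
                ),
            )
            for i in range(n_res)
        )


    # Two outputs per iteration, concatenated along axis 0 and axis 1 respectively.
    res = [(torch.ones(1, 2), torch.ones(3, 1)), (torch.zeros(2, 2), torch.zeros(3, 2))]
    out = concat_loop_results(res, concatenation_dims=[0, 1])
    print([t.shape for t in out])  # [torch.Size([3, 2]), torch.Size([3, 3])]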
@@ -128,7 +128,61 @@ class EagerDirectReplacementWithOnnx:
 
         print(pretty_onnx(onx))
 
-        # And with :func:`torch.onnx.export`:
+    We do the same with :func:`torch.onnx.export`:
+
+    .. runpython::
+        :showcode:
+
+        import onnx.helper as oh
+        import torch
+        from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
+        from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
+        from onnx_diagnostic.export.api import to_onnx
+        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+        def demo_customsub(x, y):
+            return x - y
+
+
+        def demo_customsub_shape(x, y):
+            return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
+
+
+        def make_function_proto():
+            return oh.make_function(
+                "onnx_plug",
+                "demo_customsub",
+                ["x", "y"],
+                ["z"],
+                [oh.make_node("Sub", ["x", "y"], ["z"])],
+                opset_imports=[oh.make_opsetid("", 22)],
+            )
+
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                y = x.sum(axis=1, keepdim=True)
+                d = torch.ops.onnx_plug.demo_customsub(x, y)
+                return torch.abs(d)
+
+
+        replacements = [
+            EagerDirectReplacementWithOnnx(
+                demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
+            )
+        ]
+
+        x = torch.randn((3, 4), dtype=torch.float32)
+        model = Model()
+        ds = ({0: "d1", 1: "d2"},)
+
+        # The exported program shows a custom op.
+        ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
+        print(ep)
+
+        # As the exporter knows how to replace this custom op,
+        # let's export.
 
         onx = to_onnx(
@@ -152,8 +206,8 @@ class EagerDirectReplacementWithOnnx:
     dtype = first_tensor.dtype
     itype = torch_dtype_to_onnx_dtype(dtype)
     if dtype == torch.float32:
-        if opset >= 24:
-            return "LOOPA24", itype
+        if opset >= 23:
+            return "LOOPA23", itype
         return "LOOPMHA", itype
     if dtype == torch.float16:
         if first_tensor.is_cuda:
@@ -175,9 +229,9 @@ class EagerDirectReplacementWithOnnx:
     ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
         PackedAttention.to_function_proto()
     ),
-    ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
-    ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
-        onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
+    ("LOOPA23", onnx.TensorProto.FLOAT): LoopAttention23.to_function_proto(),
+    ("LOOPA23", onnx.TensorProto.FLOAT16): _update_sequence_type(
+        onnx.TensorProto.FLOAT16, LoopAttention23.to_function_proto()
     ),
     ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
         LoopMHAAttention.to_function_proto()
@@ -700,6 +700,19 @@ def requires_onnx(version: str, msg: str = "") -> Callable:
     return lambda x: x
 
 
+def requires_experimental_experiment(version: str, msg: str = "") -> Callable:
+    """Skips a unit test if :epkg:`experimental-experiment` is not recent enough."""
+    import packaging.version as pv
+    import experimental_experiment
+
+    if pv.Version(experimental_experiment.__version__) < pv.Version(version):
+        msg = (
+            f"experimental-experiment version {experimental_experiment.__version__} < {version}: {msg}"
+        )
+        return unittest.skip(msg)
+    return lambda x: x
+
+
 def requires_onnx_array_api(version: str, msg: str = "") -> Callable:
     """Skips a unit test if :epkg:`onnx-array-api` is not recent enough."""
     import packaging.version as pv
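
The new ``requires_experimental_experiment`` helper mirrors ``requires_onnx_array_api`` just below it. A hypothetical usage sketch, assuming the helpers are exposed by ``onnx_diagnostic.ext_test_case`` (the module name is not shown in this hunk):

    import unittest

    # Assumed import path; adjust to wherever ExtTestCase and the requires_* helpers live.
    from onnx_diagnostic.ext_test_case import ExtTestCase, requires_experimental_experiment


    class TestLoopExport(ExtTestCase):
        # Skipped when experimental-experiment is older than the requested version.
        @requires_experimental_experiment("0.1.0", "relies on a recent exporter feature")
        def test_loop(self):
            self.assertTrue(True)


    if __name__ == "__main__":
        unittest.main()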
@@ -774,6 +787,7 @@ class ExtTestCase(unittest.TestCase):
     def setUpClass(cls):
         logger = logging.getLogger("onnxscript.optimizer.constant_folding")
         logger.setLevel(logging.ERROR)
+        warnings.filterwarnings("ignore", category=DeprecationWarning)
         unittest.TestCase.setUpClass()
 
     @classmethod