onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
  Functions, classes to dig into a model when this one is right, slow, wrong...
  """
 
- __version__ = "0.8.1"
+ __version__ = "0.8.2"
  __author__ = "Xavier Dupré"
@@ -1,4 +1,4 @@
- from typing import Any, Dict, List, Sequence, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
  import torch
 
 
@@ -14,6 +14,10 @@ def to_onnx(
      output_names: Optional[List[str]] = None,
      output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
      exporter: str = "onnx-dynamo",
+     exporter_kwargs: Optional[Dict[str, Any]] = None,
+     save_ep: Optional[str] = None,
+     optimize: bool = True,
+     use_control_flow_dispatcher: bool = False,
  ) -> Any:
      """
      Common API for exporters. By default, the models are optimized to use the
@@ -32,6 +36,11 @@ def to_onnx(
      :param output_names: to change the output of the onnx model
      :param output_dynamic_shapes: to overwrite the dynamic shapes names
      :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
+     :param exporter_kwargs: additional parameters sent to the exporter
+     :param save_ep: saves the exported program
+     :param optimize: optimizes the model
+     :param use_control_flow_dispatcher: use the dispatcher created to support
+         custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
      :return: the output of the selected exporter, usually a structure including
          an onnx model
 
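For illustration, a minimal sketch of how the new arguments compose with the
``custom`` exporter (the model and file name are placeholders; ``"options"`` is
the key the code below pops from ``exporter_kwargs``, and the pattern set is
assumed valid):

    import torch
    from experimental_experiment.xbuilder import OptimizationOptions
    from onnx_diagnostic.export.api import to_onnx

    class Model(torch.nn.Module):
        def forward(self, x):
            return x * 2

    to_onnx(
        Model(),
        (torch.randn(3),),
        exporter="custom",
        # "options" is popped from exporter_kwargs and overrides the default
        # OptimizationOptions(patterns="default+onnxruntime").
        exporter_kwargs={"options": OptimizationOptions(patterns="default")},
        save_ep="model.ep",  # also saves the exported program
    )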
@@ -48,9 +57,23 @@ def to_onnx(
          )
      """
      if exporter == "custom":
-         from experimental_experiment.torch_interpreter import to_onnx as _to_onnx
+         from experimental_experiment.torch_interpreter import (
+             to_onnx as _to_onnx,
+             ExportOptions,
+         )
          from experimental_experiment.xbuilder import OptimizationOptions
 
+         if use_control_flow_dispatcher:
+             from .control_flow import create_global_dispatcher
+
+             dispatcher = create_global_dispatcher()
+
+         options = None
+         if exporter_kwargs is not None:
+             options = exporter_kwargs.pop("options", None)
+         if options is None:
+             options = OptimizationOptions(patterns="default+onnxruntime")
+
          return _to_onnx(
              mod,
              args=args,
@@ -63,7 +86,10 @@ def to_onnx(
              dynamic_shapes=dynamic_shapes,
              large_model=True,
              output_dynamic_shapes=output_dynamic_shapes,
-             options=OptimizationOptions(patterns="default+onnxruntime"),
+             export_options=ExportOptions(save_ep=save_ep),
+             options=options,
+             **(exporter_kwargs or {}),
+             dispatcher=dispatcher if use_control_flow_dispatcher else None,
          )
      if exporter in ("dynamo", "onnx-dynamo"):
          import onnxscript.rewriter.ort_fusions as ort_fusions
@@ -80,9 +106,12 @@ def to_onnx(
              opset_version=target_opset,
              dynamic_shapes=dynamic_shapes,
              dynamo=True,
+             **(exporter_kwargs or {}),
          )
-         ort_fusions.optimize_for_ort(epo.model)
-         epo.save(filename)
+         if optimize:
+             ort_fusions.optimize_for_ort(epo.model)
+         if filename:
+             epo.save(filename, external_data=True)
          return epo
 
      if exporter == "modelbuilder":
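With the ``onnx-dynamo`` exporter, the fusion pass and the save are now both
guarded; a minimal sketch (``model``, ``x`` and the file name are placeholders):

    # Skip ort_fusions.optimize_for_ort and keep the result in memory.
    epo = to_onnx(model, (x,), exporter="onnx-dynamo", optimize=False)
    # With a filename, the model is saved with external data.
    to_onnx(model, (x,), exporter="onnx-dynamo", filename="model.onnx")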
@@ -117,6 +146,7 @@ def to_onnx(
              precision=str(first_float[0].dtype).split(".")[-1],
              execution_provider="cuda" if first.is_cuda else "cpu",
              cache_dir=os.path.dirname(filename),
+             **(exporter_kwargs or {}),
          )
          save_model_builder(onx, os.path.dirname(filename))
          return onx
@@ -0,0 +1,511 @@
+ import contextlib
+ import inspect
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+ import onnx
+ import onnx.helper as oh
+ import torch
+ from torch._higher_order_ops.utils import materialize_as_graph
+ from torch._higher_order_ops.utils import check_input_alias_and_mutation_return_outputs
+ from .api import to_onnx
+
+ _TEST_EXPORT = False
+ _REGISTERED_SCHEMA = {}  # type: ignore[var-annotated]
+ _DISPATCHER = None
+
+
+ def create_global_dispatcher():
+     global _DISPATCHER
+
+     if not _DISPATCHER:
+         from experimental_experiment.torch_interpreter import Dispatcher
+
+         class ControlFlowDispatcher(Dispatcher):
+             def __init__(self):
+                 super().__init__({})
+
+             def register(self, aten_name: str, converter: Callable):
+                 assert aten_name not in self.registered_functions, (
+                     f"Name {aten_name!r} is already registered in "
+                     f"{sorted(self.registered_functions)}"
+                 )
+                 self.registered_functions[aten_name] = converter
+
+         _DISPATCHER = ControlFlowDispatcher()
+     return _DISPATCHER
+
+
+ @contextlib.contextmanager
+ def enable_code_export_control_flow():
+     """Enables the code meant to be exported."""
+     global _TEST_EXPORT
+     old = _TEST_EXPORT
+     _TEST_EXPORT = True
+     try:
+         yield
+     finally:
+         _TEST_EXPORT = old
+
+
+ def is_exporting() -> bool:
+     """
+     Returns :func:`torch.compiler.is_exporting` or
+     :func:`torch.compiler.is_compiling`.
+     ``_TEST_EXPORT`` can be changed (see
+     :func:`enable_code_export_control_flow`) to make it trigger.
+     """
+     return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()
+
+
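The context manager above flips ``_TEST_EXPORT`` so that ``is_exporting``
returns True outside of an actual export, which sends ``loop_for`` (defined
below) down the custom-op path; a minimal sketch of how it could be used:

    from onnx_diagnostic.export.control_flow import (
        enable_code_export_control_flow,
        is_exporting,
    )

    assert not is_exporting()
    with enable_code_export_control_flow():
        # Inside the context, loop_for registers and calls its custom op
        # instead of running the plain Python loop.
        assert is_exporting()
    assert not is_exporting()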
+ def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
+     """
+     Python implementation of the loop.
+
+     :param n_iter: number of iterations
+     :param body_fn: function implementing the body
+     :param reduction_dim: dimension used to reduce the list produced by the loop
+     :param args: arguments to the loop body
+     :return: results
+     """
+     res = []
+     for i in torch.arange(n_iter, dtype=n_iter.dtype):
+         r = body_fn(i, *args)
+         if isinstance(r, tuple):
+             assert not res or len(r) == len(res[-1]), (
+                 f"Unexpected number of results {len(r)} for function {body_fn}, "
+                 f"expected {len(res[-1])}"
+             )
+             res.append(r)
+         else:
+             assert isinstance(r, torch.Tensor), (
+                 f"Unexpected type {r} for function {body_fn}, "
+                 f"it must be a tuple or a Tensor."
+             )
+             assert not res or len(res[-1]) == 1, (
+                 f"Unexpected number of results {len(r)} for function {body_fn}, "
+                 f"expected {len(res[-1])}"
+             )
+             res.append((r,))
+
+     if not res:
+         return torch.empty(tuple(), dtype=torch.float32, device=args[0].device)
+     if len(res) == 1:
+         final = res[0]
+     else:
+         n_res = len(res[0])
+         final = [
+             torch.cat(
+                 [r[i] for r in res],
+                 dim=(
+                     0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i]
+                 ),
+             )
+             for i in range(n_res)
+         ]
+     return tuple(final) if len(final) > 1 else final[0]
+
+
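As a concrete illustration of these semantics (not taken from the package's
tests; assumes ``_loop_for_fn`` is imported from the module): a body returning
one tensor per iteration has its results concatenated along dimension 0:

    import torch

    def body(i, x):
        # each iteration returns a (1,)-shaped slice
        return x[i : i + 1]

    x = torch.arange(4, dtype=torch.float32)
    out = _loop_for_fn(torch.tensor(3), body, None, (x,))
    # out == torch.cat([x[0:1], x[1:2], x[2:3]]) == tensor([0., 1., 2.])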
+ def make_custom_loop_for(
+     n_iter: torch.Tensor,
+     body_fn: Callable,
+     reduction_dim: Optional[Sequence[int]],
+     args: Sequence[torch.Tensor],
+     body_gm: Optional[torch.fx.GraphModule] = None,
+     body_mutated_inputs: Optional[List[Any]] = None,
+     body_outputs: Optional[List[Any]] = None,
+ ) -> Tuple[str, torch.library.CustomOpDef]:
+     """
+     Defines a custom operator for a loop in order to avoid
+     :func:`torch.export.export` digging into it.
+     It registers the custom op and a custom conversion
+     to ONNX.
+
+     :param n_iter: number of iterations defined by a tensor of no dimension
+     :param body_fn: the loop body defined as a function
+     :param reduction_dim: dimension used to concatenate the results
+     :param args: list of tensors, input to the body
+     :param body_gm: torch.fx.GraphModule equivalent to *body_fn*
+     :param body_mutated_inputs: mutated inputs of *body_gm*
+     :param body_outputs: outputs of *body_gm*
+     :return: a name and the custom op definition, the name
+         is used to cache the custom op
+     """
+     global _DISPATCHER
+     assert body_gm is not None, "body_gm cannot be None"
+     assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
+     assert body_outputs is not None, "body_outputs cannot be None"
+     srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
+     sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
+     name = f"loop_for_{body_fn.__name__}_{id(body_fn)}_{srank}_{sred}"
+     if name in _REGISTERED_SCHEMA:
+         return name, _REGISTERED_SCHEMA[name][0]
+     sig = inspect.signature(body_fn)
+     inputs = ", ".join([f"Tensor {p}" for p in sig.parameters])
+     schema = f"({inputs}) -> Tensor"
+     if len(body_outputs) > 1:
+         schema += "[]"
+     custom_def = torch.library.CustomOpDef("onnx_higher_ops", name, schema, body_fn)
+     custom_def.register_kernel("cpu")(body_fn)
+
+     custom_def._abstract_fn = lambda *_args, _o=body_outputs: (
+         tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0])
+     )
+
+     def _make_onx(
+         body_gm=body_gm, args=args, target_opset=None, verbose=0, exporter_kwargs=None
+     ):
+         return convert_into_onnx(
+             body_gm,
+             args,
+             exporter_kwargs=exporter_kwargs,
+             target_opset=target_opset,
+             verbose=verbose,
+         )
+
+     to_register = (
+         custom_def,
+         _make_onx,
+         (
+             lambda g, sts, outputs, *args, bc=_make_onx, rd=reduction_dim, name=name: (
+                 convert_custom_loop_into_onnx(
+                     g,
+                     sts,
+                     outputs,
+                     *args,
+                     body_callable=bc,
+                     reduction_dim=rd,
+                     name=name,
+                 )
+             )
+         ),
+     )
+     if _DISPATCHER is None:
+         create_global_dispatcher()
+     assert _DISPATCHER
+     _DISPATCHER.register(f"onnx_higher_ops::{name}", to_register[-1])
+     _REGISTERED_SCHEMA[name] = to_register
+     return name, custom_def
+
+
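Following the code above, for a body ``def body(i, x)`` returning two tensors,
the registered name and schema would look like (shapes and id vary per call):

    # name:   "loop_for_body_<id(body)>_<output shapes>_<reduction dims>"
    # schema: "(Tensor i, Tensor x) -> Tensor[]"
    #          parameter names come from the body's signature; "[]" is
    #          appended when the body produces more than one output.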
+ def convert_custom_loop_into_onnx(
+     g: Any,  # "GraphBuilder"
+     sts: Dict[str, Any],
+     outputs: List[str],
+     *args: str,
+     body_callable: Callable[..., onnx.ModelProto],
+     reduction_dim: Optional[Sequence[int]] = None,
+     name: str = "loop_for",
+ ) -> Union[str, List[str]]:
+     """
+     Converts a custom op ``higher_ops::loop_for...`` into a sequence of nodes.
+
+     :param g: GraphBuilder
+     :param sts: if not defined, torch does not know the output shapes
+     :param outputs: output names
+     :param args: input arguments known at export time
+     :param body_callable: returns the loop body (a ModelProto or a GraphProto)
+     :param reduction_dim: the dimension to follow when aggregating the
+         list of tensors after the loop ran
+     :param name: to give the onnx nodes a name
+     :return: output names
+     """
+     assert body_callable is not None, "body_callable cannot be None"
+     # This should be part of a public API.
+     body = body_callable(
+         target_opset=g.main_opset,
+         verbose=g.verbose,
+         exporter_kwargs={"options": g.optimization_options},
+     )
+
+     graph = body.graph if isinstance(body, onnx.ModelProto) else body
+     assert isinstance(
+         graph, onnx.GraphProto
+     ), f"Unexpected type {type(body)} for body{g.get_debug_msg()}"
+     assert len(outputs) == 1, f"Only one output is expected but outputs={outputs!r}"
+     if len(graph.output) != 1:
+         outputs = [f"{outputs[0]}#{i}" for i in range(len(graph.output))]
+     input_names = [i.name for i in graph.input]
+     inputs = [
+         *graph.input[:1],
+         oh.make_tensor_value_info("cond_unused", onnx.TensorProto.BOOL, []),
+         *[
+             oh.make_tensor_sequence_value_info(
+                 f"loop_in{i}", graph.output[i].type.tensor_type.elem_type, None
+             )
+             for i in range(len(graph.output))
+         ],
+         # hidden inputs are not added
+     ]
+     nodes = [
+         oh.make_node("Identity", ["cond_unused"], ["cond_out"]),
+         *[oh.make_node("Identity", [a], [r]) for a, r in zip(args[1:], input_names[1:])],
+         *graph.node,
+         *[
+             oh.make_node(
+                 "SequenceInsert",
+                 [f"loop_in{i}", graph.output[i].name],
+                 [f"loop_out{i}"],
+             )
+             for i in range(len(graph.output))
+         ],
+     ]
+     graph_outputs = [
+         oh.make_tensor_value_info("cond_out", onnx.TensorProto.BOOL, []),
+         *[
+             oh.make_tensor_sequence_value_info(
+                 f"loop_out{i}", graph.output[i].type.tensor_type.elem_type, None
+             )
+             for i in range(len(graph.output))
+         ],
+     ]
+     graph = oh.make_graph(
+         nodes, graph.name, inputs, graph_outputs, graph.initializer, graph.sparse_initializer
+     )
+
+     sequences = [g.op.SequenceEmpty() for _ in outputs]
+
+     outloop = [g.unique_name(f"loop_for{i}") for i in range(len(sequences))]
+
+     for i, s in enumerate(sequences):
+         g.set_sequence(s, graph.output[i].type.tensor_type.elem_type)
+     g.make_node("Loop", [args[0], "", *sequences], outloop, name=name, body=graph)
+     for i, o in enumerate(outloop):
+         g.set_sequence(o, graph.output[i].type.tensor_type.elem_type)
+     _res = [
+         g.op.ConcatFromSequence(
+             out,
+             outputs=[o],
+             name=name,
+             axis=0 if not reduction_dim or i >= len(reduction_dim) else reduction_dim[i],
+         )
+         for i, (out, o) in enumerate(zip(outloop, outputs))
+     ]
+     if not sts:
+         for i, o in enumerate(outputs):
+             g.set_type(o, graph.output[i].type.tensor_type.elem_type)
+             g.set_rank(o, len(graph.output[i].type.tensor_type.shape.dims))
+     return outputs if len(outputs) > 1 else outputs[0]
+
+
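Schematically, the Loop body assembled above threads one sequence per output
through the iterations, and each sequence is concatenated back into a tensor
once the loop ends (a sketch, not a literal dump of the generated graph):

    # Loop(n_iter, "", seq_0, ...) -> loop_for0, ...
    #   body(iter_num, cond_unused, loop_in0, ...):
    #       cond_out  = Identity(cond_unused)
    #       ... export-time args wired in via Identity, then the body nodes ...
    #       loop_out0 = SequenceInsert(loop_in0, body_output_0)
    #   -> (cond_out, loop_out0, ...)
    # out_0 = ConcatFromSequence(loop_for0, axis=reduction_dim[0] if given else 0)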
+ def convert_into_onnx(
+     body_gm: torch.fx.GraphModule,
+     args: Sequence[torch.Tensor],
+     target_opset: Optional[int] = None,
+     verbose: int = 0,
+     exporter_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> onnx.ModelProto:
+     """
+     Converts a torch.fx.GraphModule into ONNX.
+     It returns a ModelProto.
+
+     :param body_gm: a torch.fx.GraphModule
+     :param args: arguments known at export time
+     :param target_opset: targeted opset
+     :param verbose: verbosity level
+     :param exporter_kwargs: additional exporter arguments
+     :return: a ModelProto
+     """
+     # This does not work with onnx-dynamo.
+     # The opset still needs to be defined.
+     container = to_onnx(
+         body_gm,
+         args,
+         exporter="custom",
+         exporter_kwargs=exporter_kwargs,
+         target_opset=target_opset,
+         verbose=verbose,
+     )
+     return container.model_proto
+
+
+ def loop_for(
+     n_iter: Union[torch.SymInt, torch.Tensor],
+     body_fn: Callable[..., Tuple[torch.Tensor]],
+     args: Sequence[torch.Tensor],
+     reduction_dim: Optional[Sequence[int]] = None,
+ ) -> Tuple[torch.Tensor, ...]:
+     """
+     Higher-order operator used to easily export a loop to ONNX.
+     It does not fully work with :func:`torch.export.export` alone:
+     the loop is exported as a custom op, which is replaced by a Loop
+     operator afterwards.
+     Every iteration produces tensors; they are gathered into lists,
+     and these lists are concatenated into tensors.
+
+     :param n_iter: number of iterations, fixed or variable; it should
+         be a tensor with no dimension
+     :param body_fn: loop body, a function taking only tensors and
+         returning only tensors; the first argument is the iteration
+         number as a tensor with no dimension, all the others are left
+         unchanged during the loop
+     :param args: the tensors available at every iteration
+     :param reduction_dim: the loop aggregates the results into lists,
+         one for each output; each list is concatenated into one tensor
+         along one dimension, the first one by default, but it can be
+         defined otherwise
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         import onnxruntime
+         from onnx_diagnostic.export.api import to_onnx
+         from onnx_diagnostic.export.control_flow import loop_for
+
+
+         class Model(torch.nn.Module):
+             def forward(self, n_iter, x):
+                 def body(i, x):
+                     return x[: i.item() + 1].unsqueeze(1)
+
+                 return loop_for(n_iter, body, (x,))
+
+
+         model = Model()
+         n_iter = torch.tensor(4, dtype=torch.int64)
+         x = torch.arange(10, dtype=torch.float32)
+         expected = model(n_iter, x)
+         print("expected:", expected)
+
+         onx = to_onnx(
+             model,
+             (n_iter, x),
+             dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
+             exporter="custom",
+             use_control_flow_dispatcher=True,
+         ).model_proto
+
+         sess = onnxruntime.InferenceSession(
+             onx.SerializeToString(), providers=["CPUExecutionProvider"]
+         )
+         got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
+         print("got:", got)
+
+
+         # The loop is exported as a custom op.
+         ep = torch.export.export(
+             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+         )
+         print(ep)
+
+     Another example with two outputs:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         import onnxruntime
+         from onnx_diagnostic.export.api import to_onnx
+         from onnx_diagnostic.export.control_flow import loop_for
+
+
+         class Model(torch.nn.Module):
+             def forward(self, n_iter, x):
+                 def body(i, x):
+                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
+
+                 two = loop_for(n_iter, body, (x,))
+                 return two[0] + two[1]
+
+
+         model = Model()
+         n_iter = torch.tensor(4, dtype=torch.int64)
+         x = torch.arange(10, dtype=torch.float32)
+         expected = model(n_iter, x)
+         print("expected:", expected)
+
+         onx = to_onnx(
+             model,
+             (n_iter, x),
+             dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
+             exporter="custom",
+             use_control_flow_dispatcher=True,
+         ).model_proto
+
+         sess = onnxruntime.InferenceSession(
+             onx.SerializeToString(), providers=["CPUExecutionProvider"]
+         )
+         got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
+         print("got:", got)
+
+
+         # The loop is exported as a custom op.
+         ep = torch.export.export(
+             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+         )
+         print(ep)
+
+     A final example with ``reduction_dim``:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         import onnxruntime
+         from onnx_diagnostic.export.api import to_onnx
+         from onnx_diagnostic.export.control_flow import loop_for
+
+
+         class Model(torch.nn.Module):
+             def forward(self, n_iter, x):
+                 def body(i, x):
+                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
+
+                 two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
+                 return two[0] + two[1].T
+
+
+         model = Model()
+         n_iter = torch.tensor(4, dtype=torch.int64)
+         x = torch.arange(10, dtype=torch.float32)
+         expected = model(n_iter, x)
+         print("expected:", expected)
+
+         onx = to_onnx(
+             model,
+             (n_iter, x),
+             dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
+             exporter="custom",
+             use_control_flow_dispatcher=True,
+         ).model_proto
+
+         sess = onnxruntime.InferenceSession(
+             onx.SerializeToString(), providers=["CPUExecutionProvider"]
+         )
+         got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
+         print("got:", got)
+
+
+         # The loop is exported as a custom op.
+         ep = torch.export.export(
+             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+         )
+         print(ep)
+     """
+     assert args, "The function should have at least one arg."
+     assert (
+         isinstance(n_iter, torch.Tensor)
+         and n_iter.dtype == torch.int64
+         and len(n_iter.shape) == 0
+     ), f"Only a tensor for one int64 is allowed for n_iter but it is equal to {n_iter}."
+     if is_exporting():
+         body_gm: torch.fx.GraphModule = materialize_as_graph(
+             body_fn, (torch.tensor(0, dtype=torch.int64), *args)
+         )
+         (
+             _1,
+             _2,
+             _3,
+             body_mutated_inputs,
+             body_outputs,
+         ) = check_input_alias_and_mutation_return_outputs(body_gm)
+         name, _custom_ops = make_custom_loop_for(
+             n_iter,
+             body_fn,
+             reduction_dim,
+             args,
+             body_gm=body_gm,
+             body_mutated_inputs=body_mutated_inputs,
+             body_outputs=body_outputs,
+         )
+         fct = getattr(torch.ops.onnx_higher_ops, name)
+         return fct(n_iter, *args)
+
+     return _loop_for_fn(n_iter, body_fn, reduction_dim, args)