onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/control_flow_onnx.py (new file)
@@ -0,0 +1,528 @@
+ import contextlib
+ import inspect
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+ import onnx
+ import onnx.helper as oh
+ import torch
+ from torch._higher_order_ops.utils import materialize_as_graph
+ from torch._higher_order_ops.utils import check_input_alias_and_mutation_return_outputs
+ from .api import to_onnx
+
+ _TEST_EXPORT = False
+ _REGISTERED_SCHEMA = {}  # type: ignore[var-annotated]
+ _DISPATCHER = None
+
+
+ def create_global_dispatcher():
+     global _DISPATCHER
+
+     if not _DISPATCHER:
+         from experimental_experiment.torch_interpreter import Dispatcher
+
+         class ControlFlowDispatcher(Dispatcher):
+             def __init__(self):
+                 super().__init__({})
+
+             def register(self, aten_name: str, converter: Callable):
+                 assert aten_name not in self.registered_functions, (
+                     f"Name {aten_name!r} is already registered in "
+                     f"{sorted(self.registered_functions)}"
+                 )
+                 self.registered_functions[aten_name] = converter
+
+         _DISPATCHER = ControlFlowDispatcher()
+     return _DISPATCHER
+
+
+ @contextlib.contextmanager
+ def enable_code_export_control_flow():
+     """Enables the code meant to be exported."""
+     global _TEST_EXPORT
+     old = _TEST_EXPORT
+     _TEST_EXPORT = True
+     try:
+         yield
+     finally:
+         _TEST_EXPORT = old
+
+
+ def is_exporting() -> bool:
+     """
+     Returns True if :func:`torch.compiler.is_exporting` or
+     :func:`torch.compiler.is_compiling` returns True, or if
+     ``_TEST_EXPORT`` was set through
+     :func:`enable_code_export_control_flow` to make it trigger.
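+
+     A minimal sketch of the toggle, assuming no export or compilation
+     is otherwise running:
+
+     .. runpython::
+         :showcode:
+
+         from onnx_diagnostic.export.control_flow_onnx import (
+             enable_code_export_control_flow,
+             is_exporting,
+         )
+
+         print(is_exporting())  # False: neither exporting nor compiling
+         with enable_code_export_control_flow():
+             print(is_exporting())  # True: _TEST_EXPORT is set
+         print(is_exporting())  # False again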
+     """
+     return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()
+
+
+ def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
+     """
+     Python implementation of the loop.
+
+     :param n_iter: number of iterations
+     :param body_fn: function implementing the loop body
+     :param reduction_dim: dimensions used to concatenate the lists produced by the loop
+     :param args: arguments to the loop body
+     :return: results
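+
+     A minimal sketch of the eager semantics (no export involved):
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.export.control_flow_onnx import _loop_for_onnx_fn
+
+         def body(i, x):
+             # one row per iteration, concatenated along the first dimension
+             return (x + i).unsqueeze(0)
+
+         x = torch.arange(3, dtype=torch.float32)
+         res = _loop_for_onnx_fn(torch.tensor(4, dtype=torch.int64), body, None, (x,))
+         print(res)  # shape (4, 3): rows x, x + 1, x + 2, x + 3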
+     """
+     res = []
+     for i in torch.arange(n_iter, dtype=n_iter.dtype):
+         r = body_fn(i, *args)
+         if isinstance(r, tuple):
+             assert not res or len(r) == len(res[-1]), (
+                 f"Unexpected number of results {len(r)} for function {body_fn}, "
+                 f"expected {len(res[-1])}"
+             )
+             res.append(r)
+         else:
+             assert isinstance(r, torch.Tensor), (
+                 f"Unexpected type {type(r)} for function {body_fn}, "
+                 f"it must be a tuple or a Tensor."
+             )
+             assert not res or len(res[-1]) == 1, (
+                 f"Unexpected number of results (1) for function {body_fn}, "
+                 f"expected {len(res[-1])}"
+             )
+             res.append((r,))
+
+     if not res:
+         return torch.empty(tuple(), dtype=torch.float32, device=args[0].device)
+     if len(res) == 1:
+         final = res[0]
+     else:
+         n_res = len(res[0])
+         final = [
+             torch.cat(
+                 [r[i] for r in res],
+                 dim=(
+                     0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i]
+                 ),
+             )
+             for i in range(n_res)
+         ]
+     return tuple(final) if len(final) > 1 else final[0]
+
+
+ def make_custom_loop_for_onnx(
+     n_iter: torch.Tensor,
+     body_fn: Callable,
+     reduction_dim: Optional[Sequence[int]],
+     args: Sequence[torch.Tensor],
+     body_gm: Optional[torch.fx.GraphModule] = None,
+     body_mutated_inputs: Optional[List[Any]] = None,
+     body_outputs: Optional[List[Any]] = None,
+ ) -> Tuple[str, torch.library.CustomOpDef]:
+     """
+     Defines a custom operator for a loop in order to avoid
+     :func:`torch.export.export` digging into it.
+     It registers the custom op and a custom conversion
+     to ONNX.
+
+     :param n_iter: number of iterations defined by a tensor with no dimension
+     :param body_fn: the loop body defined as a function
+     :param reduction_dim: dimensions used to concatenate the results
+     :param args: list of tensors, inputs to the body
+     :param body_gm: torch.fx.GraphModule equivalent to *body_fn*
+     :param body_mutated_inputs: mutated inputs of *body_gm*
+     :param body_outputs: outputs of *body_gm*
+     :return: a name and the custom op definition, the name
+         is used to cache the custom op
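+
+     For instance, for a body with signature ``(i, x)`` returning two
+     tensors, the registered schema is derived from the signature, as in
+     this standalone sketch of the code below:
+
+     .. runpython::
+         :showcode:
+
+         import inspect
+
+         def body(i, x):
+             return i, x
+
+         sig = inspect.signature(body)
+         inputs = ", ".join(f"Tensor {p}" for p in sig.parameters)
+         schema = f"({inputs}) -> Tensor" + "[]"  # "[]" because there are two outputs
+         print(schema)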
+     """
+     global _DISPATCHER
+     assert body_gm is not None, "body_gm cannot be None"
+     assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
+     assert body_outputs is not None, "body_outputs cannot be None"
+     srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
+     sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
+     full_name = (
+         body_fn.__qualname__.replace("<locals>", "L")
+         .replace("<lambda>", "l")
+         .replace(".", "_")
+     )
+     name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
+     if name in _REGISTERED_SCHEMA:
+         return name, _REGISTERED_SCHEMA[name][0]
+     sig = inspect.signature(body_fn)
+     inputs = ", ".join([f"Tensor {p}" for p in sig.parameters])
+     schema = f"({inputs}) -> Tensor"
+     if len(body_outputs) > 1:
+         schema += "[]"
+     custom_def = torch.library.CustomOpDef("onnx_higher_ops", name, schema, body_fn)
+     custom_def.register_kernel("cpu")(body_fn)
+
+     custom_def._abstract_fn = lambda *_args, _o=body_outputs: (
+         tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0])
+     )
+
+     def _make_onx(
+         body_gm=body_gm, args=args, target_opset=None, verbose=0, exporter_kwargs=None
+     ):
+         return convert_into_onnx(
+             body_gm,
+             args,
+             exporter_kwargs=exporter_kwargs,
+             target_opset=target_opset,
+             verbose=verbose,
+         )
+
+     to_register = (
+         custom_def,
+         _make_onx,
+         (
+             lambda g, sts, outputs, *args, bc=_make_onx, rd=reduction_dim, name=name: (
+                 convert_custom_loop_into_onnx(
+                     g,
+                     sts,
+                     outputs,
+                     *args,
+                     body_callable=bc,
+                     reduction_dim=rd,
+                     name=name,
+                 )
+             )
+         ),
+     )
+     if _DISPATCHER is None:
+         create_global_dispatcher()
+     assert _DISPATCHER
+     _DISPATCHER.register(f"onnx_higher_ops::{name}", to_register[-1])
+     _REGISTERED_SCHEMA[name] = to_register
+     return name, custom_def
+
+
+ def convert_custom_loop_into_onnx(
+     g: Any,  # "GraphBuilder"
+     sts: Dict[str, Any],
+     outputs: List[str],
+     *args: str,
+     body_callable: Callable[..., onnx.ModelProto],
+     reduction_dim: Optional[Sequence[int]] = None,
+     name: str = "loop_for_onnx",
+ ) -> Union[str, List[str]]:
+     """
+     Converts a custom op ``onnx_higher_ops::loop_for_onnx...`` into a sequence of nodes.
+
+     :param g: GraphBuilder
+     :param sts: if not defined, torch does not know the output shapes
+     :param outputs: output names
+     :param args: input arguments known at export time
+     :param body_callable: a callable returning the loop body as a ModelProto or GraphProto
+     :param reduction_dim: the dimensions to follow when aggregating the
+         lists of tensors after the loop has run
+     :param name: to give the onnx nodes a name
+     :return: output names
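+
+     The emitted pattern roughly follows this hand-written sketch
+     (SequenceEmpty, then a Loop whose body appends to the sequence with
+     SequenceInsert, then ConcatFromSequence), built here with plain
+     ``onnx.helper`` and independent of the GraphBuilder:
+
+     .. runpython::
+         :showcode:
+
+         import numpy as np
+         import onnx
+         import onnx.helper as oh
+         import onnxruntime
+
+         # loop body: inserts x + i as a new row into the sequence
+         body = oh.make_graph(
+             [
+                 oh.make_node("Cast", ["iter"], ["i_f"], to=onnx.TensorProto.FLOAT),
+                 # "x" comes from the outer scope
+                 oh.make_node("Add", ["x", "i_f"], ["xi"]),
+                 oh.make_node("Constant", [], ["axes"], value_ints=[0]),
+                 oh.make_node("Unsqueeze", ["xi", "axes"], ["row"]),
+                 oh.make_node("SequenceInsert", ["seq_in", "row"], ["seq_out"]),
+                 oh.make_node("Identity", ["cond_in"], ["cond_out"]),
+             ],
+             "body",
+             [
+                 oh.make_tensor_value_info("iter", onnx.TensorProto.INT64, []),
+                 oh.make_tensor_value_info("cond_in", onnx.TensorProto.BOOL, []),
+                 oh.make_tensor_sequence_value_info("seq_in", onnx.TensorProto.FLOAT, None),
+             ],
+             [
+                 oh.make_tensor_value_info("cond_out", onnx.TensorProto.BOOL, []),
+                 oh.make_tensor_sequence_value_info("seq_out", onnx.TensorProto.FLOAT, None),
+             ],
+         )
+         model = oh.make_model(
+             oh.make_graph(
+                 [
+                     oh.make_node("SequenceEmpty", [], ["seq0"], dtype=onnx.TensorProto.FLOAT),
+                     oh.make_node("Loop", ["n", "", "seq0"], ["seq_n"], body=body),
+                     oh.make_node("ConcatFromSequence", ["seq_n"], ["y"], axis=0),
+                 ],
+                 "main",
+                 [
+                     oh.make_tensor_value_info("n", onnx.TensorProto.INT64, []),
+                     oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [3]),
+                 ],
+                 [oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [None, 3])],
+             ),
+             opset_imports=[oh.make_opsetid("", 18)],
+         )
+         sess = onnxruntime.InferenceSession(
+             model.SerializeToString(), providers=["CPUExecutionProvider"]
+         )
+         n = np.array(4, dtype=np.int64)
+         x = np.arange(3).astype(np.float32)
+         print(sess.run(None, {"n": n, "x": x})[0])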
+     """
+     assert body_callable is not None, "body_callable cannot be None"
+     # This should be part of a public API.
+     body = body_callable(
+         target_opset=g.main_opset,
+         verbose=g.verbose,
+         exporter_kwargs={"options": g.optimization_options},
+     )
+
+     graph = body.graph if isinstance(body, onnx.ModelProto) else body
+     assert isinstance(
+         graph, onnx.GraphProto
+     ), f"Unexpected type {type(body)} for body{g.get_debug_msg()}"
+     assert len(outputs) == 1, f"Only one output is expected but outputs={outputs!r}"
+     if len(graph.output) != 1:
+         outputs = [f"{outputs[0]}#{i}" for i in range(len(graph.output))]
+     input_names = [i.name for i in graph.input]
+     inputs = [
+         *graph.input[:1],
+         oh.make_tensor_value_info("cond_unused", onnx.TensorProto.BOOL, []),
+         *[
+             oh.make_tensor_sequence_value_info(
+                 f"loop_in{i}", graph.output[i].type.tensor_type.elem_type, None
+             )
+             for i in range(len(graph.output))
+         ],
+         # hidden inputs are not added
+     ]
+     nodes = [
+         oh.make_node("Identity", ["cond_unused"], ["cond_out"]),
+         *[oh.make_node("Identity", [a], [r]) for a, r in zip(args[1:], input_names[1:])],
+         *graph.node,
+         *[
+             oh.make_node(
+                 "SequenceInsert",
+                 [f"loop_in{i}", graph.output[i].name],
+                 [f"loop_out{i}"],
+             )
+             for i in range(len(graph.output))
+         ],
+     ]
+     graph_outputs = [
+         oh.make_tensor_value_info("cond_out", onnx.TensorProto.BOOL, []),
+         *[
+             oh.make_tensor_sequence_value_info(
+                 f"loop_out{i}", graph.output[i].type.tensor_type.elem_type, None
+             )
+             for i in range(len(graph.output))
+         ],
+     ]
+     graph = oh.make_graph(
+         nodes, graph.name, inputs, graph_outputs, graph.initializer, graph.sparse_initializer
+     )
+
+     itypes = [
+         graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type
+         for i in range(1, len(graph.output))
+     ]
+     assert len(outputs) == len(
+         itypes
+     ), f"Length mismatch between outputs={outputs} and graph.output={graph.output}"
+     assert (
+         0 not in itypes
+     ), f"Undefined types are not allowed in itype={itypes}, graph.output={graph.output}"
+     sequences = [g.op.SequenceEmpty(dtype=itype) for itype in itypes]
+
+     outloop = [g.unique_name(f"loop_for_onnx{i}") for i in range(len(sequences))]
+
+     for i, s in enumerate(sequences):
+         g.set_sequence(s, graph.output[i].type.tensor_type.elem_type)
+     g.make_node("Loop", [args[0], "", *sequences], outloop, name=name, body=graph)
+     for i, o in enumerate(outloop):
+         g.set_sequence(o, graph.output[i].type.tensor_type.elem_type)
+     _res = [
+         g.op.ConcatFromSequence(
+             out,
+             outputs=[o],
+             name=name,
+             axis=0 if not reduction_dim or i >= len(reduction_dim) else reduction_dim[i],
+         )
+         for i, (out, o) in enumerate(zip(outloop, outputs))
+     ]
+     if not sts:
+         for i, o in enumerate(outputs):
+             g.set_type(o, graph.output[i].type.sequence_type.elem_type.tensor_type.elem_type)
+             g.set_rank(
+                 o, len(graph.output[i].type.sequence_type.elem_type.tensor_type.shape.dims)
+             )
+     return outputs if len(outputs) > 1 else outputs[0]
+
+
+ def convert_into_onnx(
+     body_gm: torch.fx.GraphModule,
+     args: Sequence[torch.Tensor],
+     target_opset: Optional[int] = None,
+     verbose: int = 0,
+     exporter_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> onnx.ModelProto:
+     """
+     Converts a torch.fx.GraphModule into ONNX.
+     It returns a ModelProto.
+
+     :param body_gm: a torch.fx.GraphModule
+     :param args: arguments known at export time
+     :param target_opset: targeted opset
+     :param verbose: verbosity level
+     :param exporter_kwargs: additional exporter arguments
+     :return: a ModelProto
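+
+     A minimal sketch, assuming the ``custom`` exporter (package
+     ``experimental-experiment``) is installed and accepts a symbolically
+     traced module:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.export.control_flow_onnx import convert_into_onnx
+
+
+         class Body(torch.nn.Module):
+             def forward(self, i, x):
+                 return (x + i).unsqueeze(0)
+
+
+         gm = torch.fx.symbolic_trace(Body())
+         onx = convert_into_onnx(
+             gm, (torch.tensor(0, dtype=torch.int64), torch.arange(3, dtype=torch.float32))
+         )
+         print([i.name for i in onx.graph.input])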
+     """
+     # This does not work with onnx-dynamo.
+     # The opset still needs to be defined.
+     container = to_onnx(
+         body_gm,
+         args,
+         exporter="custom",
+         exporter_kwargs=exporter_kwargs,
+         target_opset=target_opset,
+         verbose=verbose,
+     )
+     return container.model_proto
+
+
+ def loop_for_onnx(
+     n_iter: Union[torch.SymInt, torch.Tensor],
+     body_fn: Callable[..., Tuple[torch.Tensor]],
+     args: Sequence[torch.Tensor],
+     reduction_dim: Optional[Sequence[int]] = None,
+ ) -> Tuple[torch.Tensor, ...]:
+     """
+     Higher-order operator used to easily export a loop to ONNX.
+     It does not fully work with :func:`torch.export.export`:
+     the loop is exported as a custom op, which is replaced by a Loop
+     operator afterwards.
+     Every iteration produces tensors; they are gathered into lists,
+     and each list is concatenated into a tensor.
+
+     :param n_iter: number of iterations, it can be fixed or
+         variable, in either case it must be a tensor with no dimension
+     :param body_fn: loop body, it takes only tensors and returns
+         only tensors, the first argument is the iteration number
+         as a tensor with no dimension, all the others
+         are left unchanged during the loop
+     :param args: the tensors available at every iteration
+     :param reduction_dim: the loop aggregates the results into lists,
+         one for each output, and each of them is concatenated into one
+         tensor along one dimension; by default it is the first
+         dimension, but it can be defined otherwise
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         import onnxruntime
+         from onnx_diagnostic.export.api import to_onnx
+         from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx
+
+
+         class Model(torch.nn.Module):
+             def forward(self, n_iter, x):
+                 def body(i, x):
+                     return x[: i.item() + 1].unsqueeze(1)
+
+                 return loop_for_onnx(n_iter, body, (x,))
+
+
+         model = Model()
+         n_iter = torch.tensor(4, dtype=torch.int64)
+         x = torch.arange(10, dtype=torch.float32)
+         expected = model(n_iter, x)
+         print("expected:", expected)
+
+         onx = to_onnx(
+             model,
+             (n_iter, x),
+             dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
+             exporter="custom",
+             use_control_flow_dispatcher=True,
+         ).model_proto
+
+         sess = onnxruntime.InferenceSession(
+             onx.SerializeToString(), providers=["CPUExecutionProvider"]
+         )
+         got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
+         print("got:", got)
+
+
+         # The loop is exported as a custom op.
+         ep = torch.export.export(
+             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+         )
+         print(ep)
+
+     Another example with two outputs:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         import onnxruntime
+         from onnx_diagnostic.export.api import to_onnx
+         from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx
+
+
+         class Model(torch.nn.Module):
+             def forward(self, n_iter, x):
+                 def body(i, x):
+                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
+
+                 two = loop_for_onnx(n_iter, body, (x,))
+                 return two[0] + two[1]
+
+
+         model = Model()
+         n_iter = torch.tensor(4, dtype=torch.int64)
+         x = torch.arange(10, dtype=torch.float32)
+         expected = model(n_iter, x)
+         print("expected:", expected)
+
+         onx = to_onnx(
+             model,
+             (n_iter, x),
+             dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
+             exporter="custom",
+             use_control_flow_dispatcher=True,
+         ).model_proto
+
+         sess = onnxruntime.InferenceSession(
+             onx.SerializeToString(), providers=["CPUExecutionProvider"]
+         )
+         got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
+         print("got:", got)
+
+
+         # The loop is exported as a custom op.
+         ep = torch.export.export(
+             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+         )
+         print(ep)
+
+     A last example with ``reduction_dim``:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         import onnxruntime
+         from onnx_diagnostic.export.api import to_onnx
+         from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx
+
+
+         class Model(torch.nn.Module):
+             def forward(self, n_iter, x):
+                 def body(i, x):
+                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
+
+                 two = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
+                 return two[0] + two[1].T
+
+
+         model = Model()
+         n_iter = torch.tensor(4, dtype=torch.int64)
+         x = torch.arange(10, dtype=torch.float32)
+         expected = model(n_iter, x)
+         print("expected:", expected)
+
+         onx = to_onnx(
+             model,
+             (n_iter, x),
+             dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
+             exporter="custom",
+             use_control_flow_dispatcher=True,
+         ).model_proto
+
+         sess = onnxruntime.InferenceSession(
+             onx.SerializeToString(), providers=["CPUExecutionProvider"]
+         )
+         got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
+         print("got:", got)
+
+
+         # The loop is exported as a custom op.
+         ep = torch.export.export(
+             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+         )
+         print(ep)
+     """
+     assert args, "The function should have at least one arg."
+     assert (
+         isinstance(n_iter, torch.Tensor)
+         and n_iter.dtype == torch.int64
+         and len(n_iter.shape) == 0
+     ), f"Only a tensor holding one int64 is allowed for n_iter but it is equal to {n_iter}."
+     if is_exporting():
+         body_gm: torch.fx.GraphModule = materialize_as_graph(
+             body_fn, (torch.tensor(0, dtype=torch.int64), *args)
+         )
+         (
+             _1,
+             _2,
+             _3,
+             body_mutated_inputs,
+             body_outputs,
+         ) = check_input_alias_and_mutation_return_outputs(body_gm)
+         name, _custom_ops = make_custom_loop_for_onnx(
+             n_iter,
+             body_fn,
+             reduction_dim,
+             args,
+             body_gm=body_gm,
+             body_mutated_inputs=body_mutated_inputs,
+             body_outputs=body_outputs,
+         )
+         fct = getattr(torch.ops.onnx_higher_ops, name)
+         return fct(n_iter, *args)
+
+     return _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args)
onnx_diagnostic/export/control_flow_research.py
@@ -14,7 +14,7 @@ from torch._higher_order_ops.utils import (
  )
  from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
  from torch.utils._python_dispatch import _get_current_dispatch_mode
- from .control_flow import _loop_for_fn
+ from .control_flow_onnx import _loop_for_onnx_fn


  class SimpleLoopForOp(HigherOrderOperator):
@@ -66,7 +66,7 @@ def simple_loop_for(
          return simple_loop_for_op(n_iter, body_fn, (n_iter, *operands))

      if isinstance(n_iter, (bool, int, float)):
-         return _loop_for_fn(body_fn, n_iter, None, *operands)
+         return _loop_for_onnx_fn(n_iter, body_fn, None, operands)

      def _validate_input(n_iter, body_fn, operands):
          assert isinstance(
@@ -92,10 +92,11 @@ def simple_loop_for(

      from torch._higher_order_ops.utils import setup_compilation_env

-     with setup_compilation_env() as backend:
-         return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
-             n_iter, body_fn, operands
-         )
+     with setup_compilation_env() as _backend:
+         return _loop_for_op_wrapper(n_iter, body_fn, *operands)
+         # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
+         #     n_iter, body_fn, operands
+         # )


  def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands):
@@ -127,9 +128,13 @@ def loop_for_op_dense(n_iter, body_fn, operands):
      ), f"Dense implementation operands must be a list of tensors and ints {operands}"
      mode = _get_current_dispatch_mode()
      assert mode is None, "Mode should never be enabled for CPU/CUDA key"
-     return _loop_for_fn(body_fn, n_iter, None, *operands)
+     return _loop_for_onnx_fn(n_iter, body_fn, None, operands)


  @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
  def inner(mode, n_iter, body_fn, operands):
      return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands)
+
+
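+ # Let the autograd keys fall through so dispatch reaches the dense
+ # implementation; the op does not define an autograd formula.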
+ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
+ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)