onnx-diagnostic 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +154 -3
  3. onnx_diagnostic/ci_models/__init__.py +0 -0
  4. onnx_diagnostic/ci_models/ci_helpers.py +435 -0
  5. onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
  6. onnx_diagnostic/ci_models/export_qwen25_vl.py +568 -0
  7. onnx_diagnostic/export/api.py +1 -0
  8. onnx_diagnostic/export/cf_simple_loop_for.py +537 -0
  9. onnx_diagnostic/export/control_flow_onnx.py +23 -17
  10. onnx_diagnostic/ext_test_case.py +23 -2
  11. onnx_diagnostic/helpers/bench_run.py +1 -1
  12. onnx_diagnostic/helpers/log_helper.py +1 -3
  13. onnx_diagnostic/helpers/optim_helper.py +116 -0
  14. onnx_diagnostic/tasks/image_text_to_text.py +15 -5
  15. onnx_diagnostic/tasks/text2text_generation.py +84 -48
  16. onnx_diagnostic/tasks/text_generation.py +3 -0
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +44 -2
  18. onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
  19. onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
  20. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +86 -3
  22. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
  23. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +23 -24
  24. onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
  25. onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
  26. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
  27. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
  28. onnx_diagnostic/torch_onnx/compare.py +357 -0
  29. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
  30. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +33 -27
  31. onnx_diagnostic/export/control_flow.py +0 -214
  32. onnx_diagnostic/export/control_flow_research.py +0 -140
  33. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
  34. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
  35. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/cf_simple_loop_for.py
@@ -0,0 +1,537 @@
+ import contextlib
+ from typing import Callable, List, Optional, Sequence, Tuple, Union
+ import torch
+ from torch._C import DispatchKey
+ from torch._ops import HigherOrderOperator
+ from torch._subclasses.fake_tensor import FakeTensorMode
+ import torch.utils._pytree as pytree
+ from torch._higher_order_ops.utils import (
+     check_input_alias_and_mutation_return_outputs,
+     reenter_make_fx,
+     unique_graph_id,
+     validate_subgraph_args_types,
+ )
+ import torch._dynamo.variables.higher_order_ops as hop
+ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
+ from torch.utils._python_dispatch import _get_current_dispatch_mode
+
+
+ class SimpleLoopForOp(HigherOrderOperator):
+     """Higher order op for :func:`simple_loop_for`."""
+
+     def __init__(self):
+         super().__init__("simple_loop_for")
+
+     def __call__(self, n_iter, body_fn, operands, concatenation_dims=None):
+         validate_subgraph_args_types(operands)
+         return super().__call__(n_iter, body_fn, operands, concatenation_dims)
+
+     def gen_schema(self, n_iter, body_fn, operands, concatenation_dims):
+         from torch._higher_order_ops.schema import HopSchemaGenerator
+         from torch._higher_order_ops.utils import materialize_as_graph
+
+         body_gm: torch.fx.GraphModule = materialize_as_graph(  # type: ignore[annotation-unchecked]
+             body_fn, (torch.tensor(0, dtype=torch.int64), *operands)
+         )
+         (
+             _,
+             _,
+             _,
+             body_mutated_inputs,
+             body_outputs,
+         ) = check_input_alias_and_mutation_return_outputs(body_gm)
+         mutated_inputs = body_mutated_inputs
+
+         schema_gen = HopSchemaGenerator(self)
+         schema_gen.add_arg("n_iter", n_iter)
+         schema_gen.add_arg("body_fn", body_gm)
+         for idx, arg in enumerate(operands):
+             schema_gen.add_arg(f"operand{idx}", arg, is_mutated=idx in mutated_inputs)
+
+         for out in body_outputs:
+             schema_gen.add_output(out)
+         assert concatenation_dims is None or len(concatenation_dims) == len(body_outputs), (
+             f"concatenation_dims={concatenation_dims} but its length should be equal to "
+             f"the number of outputs ({len(body_outputs)})"
+         )
+         schema_gen.add_schema_tree_spec(n_iter, body_fn, operands, concatenation_dims)
+         return schema_gen.gen_schema()
+
+
+ simple_loop_for_op = SimpleLoopForOp()
+
+
+ def _simple_loop_for_fn(
+     n_iter: torch.Tensor,
+     body_fn: Callable,
+     operands: Tuple[torch.Tensor, ...] = (),
+     concatenation_dims: Optional[Sequence[int]] = None,
+ ) -> Tuple[torch.Tensor, ...]:
+     """
+     Python implementation of the loop.
+
+     :param n_iter: number of iterations
+     :param body_fn: function implementing the body
+     :param concatenation_dims: dimension used to reduce the list produced by the loop
+     :param operands: arguments to the loop body
+     :return: results
+     """
+     torch._check(
+         isinstance(n_iter, (int, torch.Tensor)),
+         lambda: f"Unexpected type {type(n_iter)} for n_iter",
+     )
+     torch._check(callable(body_fn), lambda: f"Unexpected type {type(body_fn)} for body_fn")
+     torch._check(
+         concatenation_dims is None or isinstance(concatenation_dims, (list, tuple)),
+         lambda: f"Unexpected type {type(concatenation_dims)} for concatenation_dims",
+     )
+     torch._check(
+         isinstance(operands, tuple), lambda: f"Unexpected type {type(operands)} for operands"
+     )
+     res: List[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = []
+     for i in torch.arange(
+         n_iter, dtype=torch.int64 if isinstance(n_iter, int) else n_iter.dtype
+     ):
+         r = body_fn(i, *operands)
+         if isinstance(r, tuple):
+             assert not res or len(r) == len(res[-1]), (
+                 f"Unexpected number of results {len(r)} for function {body_fn}, "
+                 f"expected {len(res[-1])}"
+             )
+             assert all(isinstance(t, torch.Tensor) for t in r), (
+                 f"Unexpected type {[type(_) for _ in r]} for returned by function {body_fn}, "
+                 f"it must be a tuple of Tensor or a Tensor."
+             )
+             res.append(r)
+         else:
+             assert isinstance(r, torch.Tensor), (
+                 f"Unexpected type {type(r)} coming from function {body_fn}, "
+                 f"it must be a tuple of Tensor or a Tensor."
+             )
+             assert not res or len(res[-1]) == 1, (
+                 f"Unexpected number of results {len(r)} coming from function {body_fn}, "
+                 f"expected {len(res[-1])}"
+             )
+             res.append((r,))
+
+     if not res:
+         return torch.empty(tuple(), dtype=torch.float32, device=operands[0].device)
+
+     n_res = len(res[0])
+     return tuple(
+         torch.cat(
+             [r[i] for r in res],
+             dim=(
+                 0
+                 if concatenation_dims is None or i >= len(concatenation_dims)
+                 else concatenation_dims[i]
+             ),
+         )
+         for i in range(n_res)
+     )
+
+
+ def _simple_loop_for(
+     n_iter: Union[int, torch.Tensor],
+     body_fn: Callable,
+     operands: Tuple[torch.Tensor, ...] = (),
+     concatenation_dims: Optional[Sequence[int]] = None,
+ ) -> Tuple[torch.Tensor, ...]:
+     def _validate_input(n_iter, body_fn, operands, concatenation_dims):
+         assert isinstance(
+             n_iter, (int, torch.Tensor, torch.SymInt)
+         ), f"Expected pred to be bool or tensor, but got {n_iter}."
+         assert (
+             not isinstance(n_iter, torch.Tensor) or n_iter.numel() == 1
+         ), f"Expected pred to be bool or single-element tensor, but got {n_iter}."
+         assert callable(body_fn), "Expect both branches to be callable."
+         assert isinstance(operands, (tuple, list)) and pytree.tree_all(
+             lambda t: isinstance(t, torch.Tensor), operands
+         ), (
+             "Expect operands to be a tuple of possibly nested dict/list/tuple that only "
+             f"consists of tensor leaves, but got {operands}."
+         )
+         assert concatenation_dims is None or (
+             isinstance(concatenation_dims, (list, tuple))
+             and all(isinstance(i, int) for i in concatenation_dims)
+         ), (
+             f"concatenation_dims should be None or a list of integers but it is "
+             f"{concatenation_dims}. Its length should be equal to the number of outputs."
+         )
+         assert torch._dynamo.is_dynamo_supported(), "simple_loop_for requires dynamo support."
+
+     if torch.compiler.is_dynamo_compiling():
+         return simple_loop_for_op(
+             n_iter, body_fn, operands, concatenation_dims=concatenation_dims
+         )
+
+     if isinstance(n_iter, (bool, int, float)):
+         torch._check(
+             isinstance(n_iter, int),
+             lambda: f"n_iter must be an integer or a tensor not {type(n_iter)}",
+         )
+         return _simple_loop_for_fn(
+             n_iter, body_fn, operands, concatenation_dims=concatenation_dims
+         )
+
+     def _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims):
+         return simple_loop_for_op(n_iter, body_fn, operands, concatenation_dims)
+
+     _validate_input(n_iter, body_fn, operands, concatenation_dims)
+
+     # This requires torch>=2.10.
+     from torch._higher_order_ops.utils import setup_compilation_env
+
+     with setup_compilation_env() as _backend:
+         return _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims)
+     # This is needed to support function body using module weights or function body
+     # defined as a class method. This is yet to be implemented.
+     # cpl = torch.compile(_loop_for_op_wrapper, backend=_backend, fullgraph=True)
+     # return cpl(n_iter, body_fn, operands, concatenation_dims)
+
+
+ def trace_simple_loop_for(
+     proxy_mode, func_overload, n_iter, body_fn, operands, concatenation_dims
+ ):
+     """See function ``simple_loop_for``."""
+     assert isinstance(operands, (list, tuple)) and (
+         concatenation_dims is None
+         or (
+             isinstance(concatenation_dims, (list, tuple))
+             and all(isinstance(i, int) for i in concatenation_dims)
+         )
+     ), (
+         f"simple_loop_for operands must be a list or tuple of tensors and SymInts and "
+         f"concatenation_dims must be None or a list of integer, "
+         f"operands={[type(o) for o in operands]}, "
+         f"concatenation_dims={concatenation_dims}"
+     )
+
+     body_graph = reenter_make_fx(body_fn)(n_iter, *operands)
+
+     body_outs = []
+     for node in body_graph.graph.nodes:
+         if node.op == "output":
+             body_outs.extend(node.args)
+
+     # flat_body_outs = pytree.arg_tree_leaves(*body_outs)
+     _i, body_name = unique_graph_id(proxy_mode, prefix="body_graph")
+     proxy_mode.tracer.root.register_module(body_name, body_graph)
+     args = (n_iter, body_graph, operands, concatenation_dims)
+     proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
+     out_proxy = proxy_mode.tracer.create_proxy("call_function", func_overload, proxy_args, {})
+     out = func_overload(n_iter, body_graph, operands, concatenation_dims)
+     return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
+
+
+ @simple_loop_for_op.py_impl(DispatchKey.CompositeExplicitAutograd)
+ def loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None):
+     """Registered eager mode implementation."""
+     assert all(isinstance(o, torch.Tensor) for o in operands) and (
+         concatenation_dims is None
+         or (
+             isinstance(concatenation_dims, (list, tuple))
+             and all(isinstance(i, int) for i in concatenation_dims)
+         )
+     ), (
+         f"simple_loop_for operands must be a list or tuple of tensors and SymInts and "
+         f"concatenation_dims must be None or a list of integer, "
+         f"operands={[type(o) for o in operands]}, "
+         f"concatenation_dims={concatenation_dims}"
+     )
+     mode = _get_current_dispatch_mode()
+     assert mode is None, "Mode should never be enabled for CPU/CUDA key"
+     is_fake = isinstance(n_iter, torch._subclasses.fake_tensor.FakeTensor)
+     res = _simple_loop_for_fn(n_iter, body_fn, operands, concatenation_dims=concatenation_dims)
+     assert is_fake or not any(
+         isinstance(r, torch._subclasses.fake_tensor.FakeTensor) for r in res
+     ), (
+         f"One result is a fake tensor but the inputs were not, type(n_iter)={type(n_iter)}, "
+         f"operands: {[type(_) for _ in operands]}, res: {[type(_) for _ in res]}"
+     )
+     return res
+
+
+ @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
+ def inner(mode, n_iter, body_fn, operands, concatenation_dims=None):
+     """Registered tracing implementation."""
+     return trace_simple_loop_for(
+         mode, simple_loop_for_op, n_iter, body_fn, operands, concatenation_dims
+     )
+
+
+ @simple_loop_for_op.py_impl(FakeTensorMode)
+ def simple_loop_for_fake_tensor_mode(mode, n_iter, body_fn, operands, concatenation_dims=None):
+     """Registered FakeMode implementation."""
+     ignore_fresh_unbacked = contextlib.nullcontext()
+     if mode.shape_env:
+         ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()
+
+     with mode, ignore_fresh_unbacked:
+         flat_body_outs, true_body_spec = pytree.tree_flatten(body_fn(n_iter, *operands))
+
+     return pytree.tree_unflatten(flat_body_outs, true_body_spec)
+
+
+ # Registration for autograd.
+ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
+ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
+
+
+ class SimpleLoopForHigherOrderVariable(hop.TorchHigherOrderOperatorVariable):
+     """
+     Replicates the same pattern found for other higher order operators.
+     This enables recursive compilation and the use of modules inside a function.
+     """
+
+     _HOP_NAME = "simple_loop_for"
+     _ALLOW_FALLBACK_TO_EAGER = False
+     supports_input_mutation = False
+     supports_aliasing = False
+
+     def _call_function(
+         self,
+         tx: torch._dynamo.symbolic_convert.InstructionTranslator,
+         args: list[hop.VariableTracker],
+         kwargs: dict[str, hop.VariableTracker],
+     ) -> hop.VariableTracker:
+         """Main function."""
+         args, kwargs = hop.LazyVariableTracker.realize_all((args, kwargs))
+
+         for i, k in enumerate(["n_iter", "body_fn", "operands", "concatenated_dims"]):
+             if v := kwargs.pop(k, None):
+                 assert i == len(args), "did not provide the right number of non-keyword args"
+                 args.append(v)
+
+         if len(args) != 4 or kwargs:
+             hop.unimplemented(
+                 gb_type="simple_loop_for: improper args/kwargs",
+                 context=f"args: {args}, kwargs: {kwargs}",
+                 explanation=f"torch.cond expects 4 positional arguments (got {len(args)}) "
+                 f"and no keyword arguments (got {len(kwargs)})",
+                 hints=[*hop.graph_break_hints.USER_ERROR],
+             )
+
+         # Specialize into one of the branches since pred is constant
+         n_iter, body_fn, operands, _concatenated_dims = args
+         assert type(n_iter) is not hop.ConstantVariable, (
+             f"n_iter is a {type(n_iter)}. When used simple_loop_for, "
+             f"it unrolls the loop. A SymInt should be used."
+         )
+
+         # predicate
+         if type(n_iter.realize()) not in (
+             hop.ConstantVariable,
+             hop.TensorVariable,
+             hop.SymNodeVariable,
+         ):
+             hop.unimplemented(
+                 gb_type="simple_loop_for: improper predicate",
+                 context=str(n_iter),
+                 explanation=(
+                     f"Expected `n_iter` to be an int or a integer "
+                     f"tensor with a single item "
+                     f"but got {str(type(n_iter))} with original python type "
+                     f"{str(n_iter.python_type())}."
+                 ),
+                 hints=[*hop.graph_break_hints.USER_ERROR],
+             )
+
+         # operands
+         if not isinstance(operands, (hop.ListVariable, hop.TupleVariable)):
+             hop.unimplemented(
+                 gb_type="simple_loop_for: improper operands",
+                 context=str(operands),
+                 explanation="Expected `operands` to be a list/tuple "
+                 f"but got {operands.python_type()}.",
+                 hints=[*hop.graph_break_hints.USER_ERROR],
+             )
+
+         operands_seq = operands.unpack_var_sequence(tx)
+         if not hop.only_consist_of(
+             operands, (hop.TensorVariable, hop.ConstantVariable, hop.SymNodeVariable)
+         ):
+             hop.unimplemented(
+                 gb_type="simple_loop_for: improper operands contents",
+                 context=str(operands),
+                 explanation=(
+                     "Expected `operands` to be a list/tuple of pytrees "
+                     "that only consists of tensor leaves."
+                 ),
+                 hints=[*hop.graph_break_hints.USER_ERROR],
+             )
+
+         # branches
+         hop._check_supported_callable_arg(tx, body_fn, "body_fn")
+
+         def speculate_body():
+             (
+                 (ret_val, ret_spec),
+                 ret_graph,
+                 ret_lifted_freevars,
+             ) = hop.speculate_subgraph(
+                 tx,
+                 args[1],
+                 (args[0], *operands_seq),
+                 {},
+                 self._HOP_NAME,
+                 source_target=self.value,
+                 should_flatten_outputs=True,
+                 # TODO - removing consts from control flow ops need more work
+                 remove_consts_from_outputs=False,
+                 supports_input_mutation=self.supports_input_mutation,
+                 supports_aliasing=self.supports_aliasing,
+             )
+
+             # need to ensure we increase epoch so we don't memoize unbacked bindings
+             # across different subgraphs which can interfere with runtime assertion
+             # generation.
+             tx.fake_mode.epoch += 1
+
+             if not hop.only_consist_of(ret_val, (hop.TensorVariable, hop.ConstantVariable)):
+                 hop.unimplemented(
+                     gb_type="simple_loop_for: unsupported branch return type",
+                     context=str(ret_val),
+                     explanation=(
+                         "Expected branches to return a possibly nested "
+                         "pytree of tensors or constant ints."
+                     ),
+                     hints=[*hop.graph_break_hints.USER_ERROR],
+                 )
+             for ret in ret_val.unpack_var_sequence(tx):
+                 if ret.is_python_constant() and not isinstance(ret.as_python_constant(), int):
+                     hop.unimplemented(
+                         gb_type=(
+                             "simple_loop_for: unsupported branch return type "
+                             "(constant non-int)"
+                         ),
+                         context=str(ret_val),
+                         explanation="Constants returned from branches must be ints.",
+                         hints=[*hop.graph_break_hints.USER_ERROR],
+                     )
+             return ret_val, ret_spec, ret_graph, ret_lifted_freevars
+
+         body_r, body_spec, body_graph, body_lifted_freevars = speculate_body()
+         body_nn_modules = dict(tx.output.nn_modules)
+
+         same_spec = body_spec.treespec.as_python_constant()
+         if same_spec is not NotImplemented and not same_spec:
+             hop.unimplemented(
+                 gb_type="simple_loop_for: differing branch outputs",
+                 context=(
+                     f"body_spec: {body_spec.treespec}, false_spec: "
+                     f"{body_spec.treespec}, same_spec: {same_spec}"
+                 ),
+                 explanation="Expected branches to return the same pytree structure.",
+                 hints=[*hop.graph_break_hints.USER_ERROR],
+             )
+
+         body_name = tx.output.install_subgraph(
+             "loop_body", torch.fx.GraphModule(body_nn_modules, body_graph)
+         )
+         body_node = hop.make_attr(tx, body_name)
+         p_args = (
+             n_iter.as_proxy(),
+             body_node,
+             # We pick true_shared but it shouldn't matter
+             operands.as_proxy() + tuple(body_lifted_freevars.keys()),
+         )
+
+         return hop._call_function_and_unflatten_output(
+             tx,
+             simple_loop_for,
+             p_args,
+             {},
+             None,
+             body_spec,
+             body_r,
+         )
+
+
+ hop._hop_name_to_variable_class["simple_loop_for"] = SimpleLoopForHigherOrderVariable
+
+
+ # @torch._functorch.utils.exposed_in("torch")
+ def simple_loop_for(
+     n_iter: Union[int, torch.Tensor],
+     body_fn: Callable,
+     operands: Tuple[torch.Tensor, ...] = (),
+     concatenation_dims: Optional[Union[int, Sequence[int]]] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+     """
+     Implements a simple for loop. The body is defined by a function which takes the
+     iteration number stored in a tensor, and other tensors.
+     It returns one or several tensors in a tuple. All of them
+     are finally concatenated along the first dimension.
+
+     :param n_iter: number of iterations
+     :param body_fn: function implementing the loop body
+     :param operands: body arguments
+     :param concatenation_dims: dimension or dimensions used to concatenate the output sequences
+     :return: concatenated outputs, a single Tensor if the body returns one output,
+         a tuple of Tensors otherwise
+
+     An example with one output:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for
+
+
+         class Model(torch.nn.Module):
+             def forward(self, n_iter, x):
+                 def body(i, x):
+                     return (x[: i.item() + 1].unsqueeze(1),)
+
+                 return simple_loop_for(n_iter, body, (x,))
+
+
+         model = Model()
+         n_iter = torch.tensor(4, dtype=torch.int64)
+         x = torch.arange(10, dtype=torch.float32)
+         ep = torch.export.export(
+             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+         )
+         print(ep)
+
+     Another example with two outputs and a final concatenation on different axes:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for
+
+
+         class Model(torch.nn.Module):
+             def forward(self, n_iter, x):
+                 def body(i, x):
+                     return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0))
+
+                 return simple_loop_for(n_iter, body, (x,), (0, 1))
+
+
+         model = Model()
+         n_iter = torch.tensor(4, dtype=torch.int64)
+         x = torch.arange(10, dtype=torch.float32)
+         ep = torch.export.export(
+             model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+         )
+         print(ep)
+     """
+     res = _simple_loop_for(
+         n_iter,
+         body_fn,
+         operands,
+         concatenation_dims=(
+             (concatenation_dims,)
+             if isinstance(concatenation_dims, int)
+             else concatenation_dims
+         ),
+     )
+     torch._check(
+         isinstance(res, tuple),
+         lambda: f"Output of the loop should be a tuple not {type(res)}.",
+     )
+     return res[0] if len(res) == 1 else res
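
In eager mode (outside export), the code above simply runs the body in a Python loop and concatenates the per-iteration outputs along `concatenation_dims` (dimension 0 by default). A minimal sketch, not part of the diff, using only the function and import path shown in the docstring above:

    import torch
    from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for

    # Each iteration i returns a slice of shape (i + 1, 1); with n_iter=4 the four
    # slices are concatenated along dim 0, so the result should have shape (10, 1).
    def body(i, x):
        return (x[: i.item() + 1].unsqueeze(1),)

    x = torch.arange(10, dtype=torch.float32)
    out = simple_loop_for(4, body, (x,))
    print(out.shape)  # expected: torch.Size([10, 1])
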
onnx_diagnostic/export/control_flow_onnx.py
@@ -55,13 +55,13 @@ def is_exporting() -> bool:
      return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()


- def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
+ def _loop_for_onnx_fn(n_iter, body_fn, concatenation_dims, args):
      """
      Python implementation of the loop.

      :param n_iter: number of iteration
      :param body_fn: function implementing the body
-     :param reduction_dim: dimension used to reduce the list produced by the loop
+     :param concatenation_dims: dimension used to reduce the list produced by the loop
      :param args: arguments to the loop body
      :return: results
      """
@@ -95,7 +95,9 @@ def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
          torch.cat(
              [r[i] for r in res],
              dim=(
-                 0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i]
+                 0
+                 if concatenation_dims is None or i >= len(concatenation_dims)
+                 else concatenation_dims[i]
              ),
          )
          for i in range(n_res)
@@ -106,7 +108,7 @@ def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
  def make_custom_loop_for_onnx(
      n_iter: torch.Tensor,
      body_fn: Callable,
-     reduction_dim: Optional[Sequence[int]],
+     concatenation_dims: Optional[Sequence[int]],
      args: Sequence[torch.Tensor],
      body_gm: Optional[torch.fx.GraphModule] = None,
      body_mutated_inputs: Optional[List[Any]] = None,
@@ -120,7 +122,7 @@ def make_custom_loop_for_onnx(

      :param n_iter: number of iterations defined by a tensor of no dimension
      :param body_fn: the loop body defined as a function
-     :param reduction_dim: dimension used to concatenated the results
+     :param concatenation_dims: dimension used to concatenated the results
      :param args: list of tensors, input to the body
      :param body_gm: torch.fx.GraphModule equivalent to *body_gm*
      :param body_mutated_inputs: inputs to *body_gm*
@@ -133,7 +135,7 @@ def make_custom_loop_for_onnx(
      assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
      assert body_outputs is not None, "body_outputs cannot be None"
      srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
-     sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
+     sred = "x".join(map(str, concatenation_dims)) if concatenation_dims else ""
      full_name = (
          body_fn.__qualname__.replace("<locals>", "L")
          .replace("<lambda>", "l")
@@ -169,14 +171,14 @@ def make_custom_loop_for_onnx(
          custom_def,
          _make_onx,
          (
-             lambda g, sts, outputs, *args, bc=_make_onx, rd=reduction_dim, name=name: (
+             lambda g, sts, outputs, *args, bc=_make_onx, rd=concatenation_dims, name=name: (
                  convert_custom_loop_into_onnx(
                      g,
                      sts,
                      outputs,
                      *args,
                      body_callable=bc,
-                     reduction_dim=rd,
+                     concatenation_dims=rd,
                      name=name,
                  )
              )
@@ -196,7 +198,7 @@ def convert_custom_loop_into_onnx(
      outputs: List[str],
      *args: str,
      body_callable: Callable[..., onnx.ModelProto],
-     reduction_dim: Optional[Sequence[int]] = None,
+     concatenation_dims: Optional[Sequence[int]] = None,
      name: str = "loop_for_onnx",
  ) -> Union[str, List[str]]:
      """
@@ -207,7 +209,7 @@ def convert_custom_loop_into_onnx(
      :param outputs: output names
      :param args: input argument known at export time
      :param body: GraphProto, the loop body
-     :param reduction_dim: the dimension to follow when aggregating the
+     :param concatenation_dims: the dimension to follow when aggregating the
          list of tensors after the loop ran
      :param name: to give the onnx nodes a name
      :return: output names
@@ -289,7 +291,11 @@ def convert_custom_loop_into_onnx(
              out,
              outputs=[o],
              name=name,
-             axis=0 if not reduction_dim or i >= len(reduction_dim) else reduction_dim[i],
+             axis=(
+                 0
+                 if not concatenation_dims or i >= len(concatenation_dims)
+                 else concatenation_dims[i]
+             ),
          )
          for i, (out, o) in enumerate(zip(outloop, outputs))
      ]
@@ -337,7 +343,7 @@ def loop_for_onnx(
      n_iter: Union[torch.SymInt, torch.Tensor],
      body_fn: Callable[..., Tuple[torch.Tensor]],
      args: Sequence[torch.Tensor],
-     reduction_dim: Optional[Sequence[int]] = None,
+     concatenation_dims: Optional[Sequence[int]] = None,
  ) -> Tuple[torch.Tensor, ...]:
      """
      High operators used to easily export a loop in ONNX.
@@ -353,7 +359,7 @@ def loop_for_onnx(
          in a tensor with no dimension, all the others
          are not changed during the loop
      :param args: the available tensors at every loop
-     :param reduction_dim: the loop aggregated the results into list,
+     :param concatenation_dims: the loop aggregated the results into list,
          one of each output, each of them is concatenated into one
          tensor along one dimension, by default, it is the first
          dimension, but it can be defined otherwise
@@ -449,7 +455,7 @@ def loop_for_onnx(
          )
          print(ep)

-     A last example with ``reduction_dim``:
+     A last example with ``concatenation_dims``:

      .. runpython::
          :showcode:
@@ -465,7 +471,7 @@ def loop_for_onnx(
              def body(i, x):
                  return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1

-             two = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
+             two = loop_for_onnx(n_iter, body, (x,), concatenation_dims=[0, 1])
              return two[0] + two[1].T

@@ -516,7 +522,7 @@ def loop_for_onnx(
          name, _custom_ops = make_custom_loop_for_onnx(
              n_iter,
              body_fn,
-             reduction_dim,
+             concatenation_dims,
              args,
              body_gm=body_gm,
              body_mutated_inputs=body_mutated_inputs,
@@ -525,4 +531,4 @@ def loop_for_onnx(
          fct = getattr(torch.ops.onnx_higher_ops, name)
          return fct(n_iter, *args)

-     return _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args)
+     return _loop_for_onnx_fn(n_iter, body_fn, concatenation_dims, args)
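
The hunks above rename the `reduction_dim` argument of `loop_for_onnx` and its helpers to `concatenation_dims`. A sketch of the renamed argument in use, assembled from the docstring fragments visible in these hunks; the import path is an assumption based on the file location onnx_diagnostic/export/control_flow_onnx.py:

    import torch
    from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx  # assumed path


    class Model(torch.nn.Module):
        def forward(self, n_iter, x):
            def body(i, x):
                # output 0 has shape (i + 1, 1), output 1 has shape (1, i + 1)
                return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1

            # output 0 is concatenated along dim 0, output 1 along dim 1
            two = loop_for_onnx(n_iter, body, (x,), concatenation_dims=[0, 1])
            return two[0] + two[1].T


    model = Model()
    n_iter = torch.tensor(4, dtype=torch.int64)
    x = torch.arange(10, dtype=torch.float32)
    ep = torch.export.export(
        model, (n_iter, x), dynamic_shapes=({}, {0: torch.export.Dim.DYNAMIC})
    )
    print(ep)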