onnx-diagnostic 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +154 -3
- onnx_diagnostic/ci_models/__init__.py +0 -0
- onnx_diagnostic/ci_models/ci_helpers.py +435 -0
- onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +568 -0
- onnx_diagnostic/export/api.py +1 -0
- onnx_diagnostic/export/cf_simple_loop_for.py +537 -0
- onnx_diagnostic/export/control_flow_onnx.py +23 -17
- onnx_diagnostic/ext_test_case.py +23 -2
- onnx_diagnostic/helpers/bench_run.py +1 -1
- onnx_diagnostic/helpers/log_helper.py +1 -3
- onnx_diagnostic/helpers/optim_helper.py +116 -0
- onnx_diagnostic/tasks/image_text_to_text.py +15 -5
- onnx_diagnostic/tasks/text2text_generation.py +84 -48
- onnx_diagnostic/tasks/text_generation.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +44 -2
- onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +86 -3
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +23 -24
- onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
- onnx_diagnostic/torch_onnx/compare.py +357 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +33 -27
- onnx_diagnostic/export/control_flow.py +0 -214
- onnx_diagnostic/export/control_flow_research.py +0 -140
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/cf_simple_loop_for.py (new file)

```diff
@@ -0,0 +1,537 @@
+import contextlib
+from typing import Callable, List, Optional, Sequence, Tuple, Union
+import torch
+from torch._C import DispatchKey
+from torch._ops import HigherOrderOperator
+from torch._subclasses.fake_tensor import FakeTensorMode
+import torch.utils._pytree as pytree
+from torch._higher_order_ops.utils import (
+    check_input_alias_and_mutation_return_outputs,
+    reenter_make_fx,
+    unique_graph_id,
+    validate_subgraph_args_types,
+)
+import torch._dynamo.variables.higher_order_ops as hop
+from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
+from torch.utils._python_dispatch import _get_current_dispatch_mode
+
+
+class SimpleLoopForOp(HigherOrderOperator):
+    """Higher order op for :func:`simple_loop_for`."""
+
+    def __init__(self):
+        super().__init__("simple_loop_for")
+
+    def __call__(self, n_iter, body_fn, operands, concatenation_dims=None):
+        validate_subgraph_args_types(operands)
+        return super().__call__(n_iter, body_fn, operands, concatenation_dims)
+
+    def gen_schema(self, n_iter, body_fn, operands, concatenation_dims):
+        from torch._higher_order_ops.schema import HopSchemaGenerator
+        from torch._higher_order_ops.utils import materialize_as_graph
+
+        body_gm: torch.fx.GraphModule = materialize_as_graph(  # type: ignore[annotation-unchecked]
+            body_fn, (torch.tensor(0, dtype=torch.int64), *operands)
+        )
+        (
+            _,
+            _,
+            _,
+            body_mutated_inputs,
+            body_outputs,
+        ) = check_input_alias_and_mutation_return_outputs(body_gm)
+        mutated_inputs = body_mutated_inputs
+
+        schema_gen = HopSchemaGenerator(self)
+        schema_gen.add_arg("n_iter", n_iter)
+        schema_gen.add_arg("body_fn", body_gm)
+        for idx, arg in enumerate(operands):
+            schema_gen.add_arg(f"operand{idx}", arg, is_mutated=idx in mutated_inputs)
+
+        for out in body_outputs:
+            schema_gen.add_output(out)
+        assert concatenation_dims is None or len(concatenation_dims) == len(body_outputs), (
+            f"concatenation_dims={concatenation_dims} but its length should be equal to "
+            f"the number of outputs ({len(body_outputs)})"
+        )
+        schema_gen.add_schema_tree_spec(n_iter, body_fn, operands, concatenation_dims)
+        return schema_gen.gen_schema()
+
+
+simple_loop_for_op = SimpleLoopForOp()
+
+
+def _simple_loop_for_fn(
+    n_iter: torch.Tensor,
+    body_fn: Callable,
+    operands: Tuple[torch.Tensor, ...] = (),
+    concatenation_dims: Optional[Sequence[int]] = None,
+) -> Tuple[torch.Tensor, ...]:
+    """
+    Python implementation of the loop.
+
+    :param n_iter: number of iteration
+    :param body_fn: function implementing the body
+    :param concatenation_dims: dimension used to reduce the list produced by the loop
+    :param operands: arguments to the loop body
+    :return: results
+    """
+    torch._check(
+        isinstance(n_iter, (int, torch.Tensor)),
+        lambda: f"Unexpected type {type(n_iter)} for n_iter",
+    )
+    torch._check(callable(body_fn), lambda: f"Unexpected type {type(body_fn)} for body_fn")
+    torch._check(
+        concatenation_dims is None or isinstance(concatenation_dims, (list, tuple)),
+        lambda: f"Unexpected type {type(concatenation_dims)} for concatenation_dims",
+    )
+    torch._check(
+        isinstance(operands, tuple), lambda: f"Unexpected type {type(operands)} for operands"
+    )
+    res: List[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = []
+    for i in torch.arange(
+        n_iter, dtype=torch.int64 if isinstance(n_iter, int) else n_iter.dtype
+    ):
+        r = body_fn(i, *operands)
+        if isinstance(r, tuple):
+            assert not res or len(r) == len(res[-1]), (
+                f"Unexpected number of results {len(r)} for function {body_fn}, "
+                f"expected {len(res[-1])}"
+            )
+            assert all(isinstance(t, torch.Tensor) for t in r), (
+                f"Unexpected type {[type(_) for _ in r]} for returned by function {body_fn}, "
+                f"it must be a tuple of Tensor or a Tensor."
+            )
+            res.append(r)
+        else:
+            assert isinstance(r, torch.Tensor), (
+                f"Unexpected type {type(r)} coming from function {body_fn}, "
+                f"it must be a tuple of Tensor or a Tensor."
+            )
+            assert not res or len(res[-1]) == 1, (
+                f"Unexpected number of results {len(r)} coming from function {body_fn}, "
+                f"expected {len(res[-1])}"
+            )
+            res.append((r,))
+
+    if not res:
+        return torch.empty(tuple(), dtype=torch.float32, device=operands[0].device)
+
+    n_res = len(res[0])
+    return tuple(
+        torch.cat(
+            [r[i] for r in res],
+            dim=(
+                0
+                if concatenation_dims is None or i >= len(concatenation_dims)
+                else concatenation_dims[i]
+            ),
+        )
+        for i in range(n_res)
+    )
+
+
+def _simple_loop_for(
+    n_iter: Union[int, torch.Tensor],
+    body_fn: Callable,
+    operands: Tuple[torch.Tensor, ...] = (),
+    concatenation_dims: Optional[Sequence[int]] = None,
+) -> Tuple[torch.Tensor, ...]:
+    def _validate_input(n_iter, body_fn, operands, concatenation_dims):
+        assert isinstance(
+            n_iter, (int, torch.Tensor, torch.SymInt)
+        ), f"Expected pred to be bool or tensor, but got {n_iter}."
+        assert (
+            not isinstance(n_iter, torch.Tensor) or n_iter.numel() == 1
+        ), f"Expected pred to be bool or single-element tensor, but got {n_iter}."
+        assert callable(body_fn), "Expect both branches to be callable."
+        assert isinstance(operands, (tuple, list)) and pytree.tree_all(
+            lambda t: isinstance(t, torch.Tensor), operands
+        ), (
+            "Expect operands to be a tuple of possibly nested dict/list/tuple that only "
+            f"consists of tensor leaves, but got {operands}."
+        )
+        assert concatenation_dims is None or (
+            isinstance(concatenation_dims, (list, tuple))
+            and all(isinstance(i, int) for i in concatenation_dims)
+        ), (
+            f"concatenation_dims should be None or a list of integers but it is "
+            f"{concatenation_dims}. Its length should be equal to the number of outputs."
+        )
+        assert torch._dynamo.is_dynamo_supported(), "simple_loop_for requires dynamo support."
+
+    if torch.compiler.is_dynamo_compiling():
+        return simple_loop_for_op(
+            n_iter, body_fn, operands, concatenation_dims=concatenation_dims
+        )
+
+    if isinstance(n_iter, (bool, int, float)):
+        torch._check(
+            isinstance(n_iter, int),
+            lambda: f"n_iter must be an integer or a tensor not {type(n_iter)}",
+        )
+        return _simple_loop_for_fn(
+            n_iter, body_fn, operands, concatenation_dims=concatenation_dims
+        )
+
+    def _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims):
+        return simple_loop_for_op(n_iter, body_fn, operands, concatenation_dims)
+
+    _validate_input(n_iter, body_fn, operands, concatenation_dims)
+
+    # This requires torch>=2.10.
+    from torch._higher_order_ops.utils import setup_compilation_env
+
+    with setup_compilation_env() as _backend:
+        return _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims)
+    # This is needed to support function body using module weights or function body
+    # defined as a class method. This is yet to be implemented.
+    # cpl = torch.compile(_loop_for_op_wrapper, backend=_backend, fullgraph=True)
+    # return cpl(n_iter, body_fn, operands, concatenation_dims)
+
+
+def trace_simple_loop_for(
+    proxy_mode, func_overload, n_iter, body_fn, operands, concatenation_dims
+):
+    """See function ``simple_loop_for``."""
+    assert isinstance(operands, (list, tuple)) and (
+        concatenation_dims is None
+        or (
+            isinstance(concatenation_dims, (list, tuple))
+            and all(isinstance(i, int) for i in concatenation_dims)
+        )
+    ), (
+        f"simple_loop_for operands must be a list or tuple of tensors and SymInts and "
+        f"concatenation_dims must be None or a list of integer, "
+        f"operands={[type(o) for o in operands]}, "
+        f"concatenation_dims={concatenation_dims}"
+    )
+
+    body_graph = reenter_make_fx(body_fn)(n_iter, *operands)
+
+    body_outs = []
+    for node in body_graph.graph.nodes:
+        if node.op == "output":
+            body_outs.extend(node.args)
+
+    # flat_body_outs = pytree.arg_tree_leaves(*body_outs)
+    _i, body_name = unique_graph_id(proxy_mode, prefix="body_graph")
+    proxy_mode.tracer.root.register_module(body_name, body_graph)
+    args = (n_iter, body_graph, operands, concatenation_dims)
+    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
+    out_proxy = proxy_mode.tracer.create_proxy("call_function", func_overload, proxy_args, {})
+    out = func_overload(n_iter, body_graph, operands, concatenation_dims)
+    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
+
+
+@simple_loop_for_op.py_impl(DispatchKey.CompositeExplicitAutograd)
+def loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None):
+    """Registered eager mode implementation."""
+    assert all(isinstance(o, torch.Tensor) for o in operands) and (
+        concatenation_dims is None
+        or (
+            isinstance(concatenation_dims, (list, tuple))
+            and all(isinstance(i, int) for i in concatenation_dims)
+        )
+    ), (
+        f"simple_loop_for operands must be a list or tuple of tensors and SymInts and "
+        f"concatenation_dims must be None or a list of integer, "
+        f"operands={[type(o) for o in operands]}, "
+        f"concatenation_dims={concatenation_dims}"
+    )
+    mode = _get_current_dispatch_mode()
+    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
+    is_fake = isinstance(n_iter, torch._subclasses.fake_tensor.FakeTensor)
+    res = _simple_loop_for_fn(n_iter, body_fn, operands, concatenation_dims=concatenation_dims)
+    assert is_fake or not any(
+        isinstance(r, torch._subclasses.fake_tensor.FakeTensor) for r in res
+    ), (
+        f"One result is a fake tensor but the inputs were not, type(n_iter)={type(n_iter)}, "
+        f"operands: {[type(_) for _ in operands]}, res: {[type(_) for _ in res]}"
+    )
+    return res
+
+
+@simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
+def inner(mode, n_iter, body_fn, operands, concatenation_dims=None):
+    """Registered tracing implementation."""
+    return trace_simple_loop_for(
+        mode, simple_loop_for_op, n_iter, body_fn, operands, concatenation_dims
+    )
+
+
+@simple_loop_for_op.py_impl(FakeTensorMode)
+def simple_loop_for_fake_tensor_mode(mode, n_iter, body_fn, operands, concatenation_dims=None):
+    """Registered FakeMode implementation."""
+    ignore_fresh_unbacked = contextlib.nullcontext()
+    if mode.shape_env:
+        ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()
+
+    with mode, ignore_fresh_unbacked:
+        flat_body_outs, true_body_spec = pytree.tree_flatten(body_fn(n_iter, *operands))
+
+    return pytree.tree_unflatten(flat_body_outs, true_body_spec)
+
+
+# Registration for autograd.
+simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
+simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
+
+
+class SimpleLoopForHigherOrderVariable(hop.TorchHigherOrderOperatorVariable):
+    """
+    Replicates the same pattern found for other higher order operators.
+    This enables recursive compilation and the use of modules inside a function.
+    """
+
+    _HOP_NAME = "simple_loop_for"
+    _ALLOW_FALLBACK_TO_EAGER = False
+    supports_input_mutation = False
+    supports_aliasing = False
+
+    def _call_function(
+        self,
+        tx: torch._dynamo.symbolic_convert.InstructionTranslator,
+        args: list[hop.VariableTracker],
+        kwargs: dict[str, hop.VariableTracker],
+    ) -> hop.VariableTracker:
+        """Main function."""
+        args, kwargs = hop.LazyVariableTracker.realize_all((args, kwargs))
+
+        for i, k in enumerate(["n_iter", "body_fn", "operands", "concatenated_dims"]):
+            if v := kwargs.pop(k, None):
+                assert i == len(args), "did not provide the right number of non-keyword args"
+                args.append(v)
+
+        if len(args) != 4 or kwargs:
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper args/kwargs",
+                context=f"args: {args}, kwargs: {kwargs}",
+                explanation=f"torch.cond expects 4 positional arguments (got {len(args)}) "
+                f"and no keyword arguments (got {len(kwargs)})",
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        # Specialize into one of the branches since pred is constant
+        n_iter, body_fn, operands, _concatenated_dims = args
+        assert type(n_iter) is not hop.ConstantVariable, (
+            f"n_iter is a {type(n_iter)}. When used simple_loop_for, "
+            f"it unrolls the loop. A SymInt should be used."
+        )
+
+        # predicate
+        if type(n_iter.realize()) not in (
+            hop.ConstantVariable,
+            hop.TensorVariable,
+            hop.SymNodeVariable,
+        ):
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper predicate",
+                context=str(n_iter),
+                explanation=(
+                    f"Expected `n_iter` to be an int or a integer "
+                    f"tensor with a single item "
+                    f"but got {str(type(n_iter))} with original python type "
+                    f"{str(n_iter.python_type())}."
+                ),
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        # operands
+        if not isinstance(operands, (hop.ListVariable, hop.TupleVariable)):
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper operands",
+                context=str(operands),
+                explanation="Expected `operands` to be a list/tuple "
+                f"but got {operands.python_type()}.",
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        operands_seq = operands.unpack_var_sequence(tx)
+        if not hop.only_consist_of(
+            operands, (hop.TensorVariable, hop.ConstantVariable, hop.SymNodeVariable)
+        ):
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper operands contents",
+                context=str(operands),
+                explanation=(
+                    "Expected `operands` to be a list/tuple of pytrees "
+                    "that only consists of tensor leaves."
+                ),
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        # branches
+        hop._check_supported_callable_arg(tx, body_fn, "body_fn")
+
+        def speculate_body():
+            (
+                (ret_val, ret_spec),
+                ret_graph,
+                ret_lifted_freevars,
+            ) = hop.speculate_subgraph(
+                tx,
+                args[1],
+                (args[0], *operands_seq),
+                {},
+                self._HOP_NAME,
+                source_target=self.value,
+                should_flatten_outputs=True,
+                # TODO - removing consts from control flow ops need more work
+                remove_consts_from_outputs=False,
+                supports_input_mutation=self.supports_input_mutation,
+                supports_aliasing=self.supports_aliasing,
+            )
+
+            # need to ensure we increase epoch so we don't memoize unbacked bindings
+            # across different subgraphs which can interfere with runtime assertion
+            # generation.
+            tx.fake_mode.epoch += 1
+
+            if not hop.only_consist_of(ret_val, (hop.TensorVariable, hop.ConstantVariable)):
+                hop.unimplemented(
+                    gb_type="simple_loop_for: unsupported branch return type",
+                    context=str(ret_val),
+                    explanation=(
+                        "Expected branches to return a possibly nested "
+                        "pytree of tensors or constant ints."
+                    ),
+                    hints=[*hop.graph_break_hints.USER_ERROR],
+                )
+            for ret in ret_val.unpack_var_sequence(tx):
+                if ret.is_python_constant() and not isinstance(ret.as_python_constant(), int):
+                    hop.unimplemented(
+                        gb_type=(
+                            "simple_loop_for: unsupported branch return type "
+                            "(constant non-int)"
+                        ),
+                        context=str(ret_val),
+                        explanation="Constants returned from branches must be ints.",
+                        hints=[*hop.graph_break_hints.USER_ERROR],
+                    )
+            return ret_val, ret_spec, ret_graph, ret_lifted_freevars
+
+        body_r, body_spec, body_graph, body_lifted_freevars = speculate_body()
+        body_nn_modules = dict(tx.output.nn_modules)
+
+        same_spec = body_spec.treespec.as_python_constant()
+        if same_spec is not NotImplemented and not same_spec:
+            hop.unimplemented(
+                gb_type="simple_loop_for: differing branch outputs",
+                context=(
+                    f"body_spec: {body_spec.treespec}, false_spec: "
+                    f"{body_spec.treespec}, same_spec: {same_spec}"
+                ),
+                explanation="Expected branches to return the same pytree structure.",
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        body_name = tx.output.install_subgraph(
+            "loop_body", torch.fx.GraphModule(body_nn_modules, body_graph)
+        )
+        body_node = hop.make_attr(tx, body_name)
+        p_args = (
+            n_iter.as_proxy(),
+            body_node,
+            # We pick true_shared but it shouldn't matter
+            operands.as_proxy() + tuple(body_lifted_freevars.keys()),
+        )
+
+        return hop._call_function_and_unflatten_output(
+            tx,
+            simple_loop_for,
+            p_args,
+            {},
+            None,
+            body_spec,
+            body_r,
+        )
+
+
+hop._hop_name_to_variable_class["simple_loop_for"] = SimpleLoopForHigherOrderVariable
+
+
+# @torch._functorch.utils.exposed_in("torch")
+def simple_loop_for(
+    n_iter: Union[int, torch.Tensor],
+    body_fn: Callable,
+    operands: Tuple[torch.Tensor, ...] = (),
+    concatenation_dims: Optional[Union[int, Sequence[int]]] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+    """
+    Implements a simple loop for, the body is defined by a function which takes the
+    iteration number stored in a tensor, and other tensors.
+    It results one or several tensors in a tuple. All of them
+    are finally concatenated along the first dimension.
+
+    :param n_iter: iteration number
+    :param body: function
+    :param operands: bidy arguments
+    :param concatenation_dims: dimension or dimensions used to concatenate the output sequences
+    :return: contenated outputs, the output is a Tensor
+
+    An example with one output:
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for
+
+
+        class Model(torch.nn.Module):
+            def forward(self, n_iter, x):
+                def body(i, x):
+                    return (x[: i.item() + 1].unsqueeze(1),)
+
+                return simple_loop_for(n_iter, body, (x,))
+
+
+        model = Model()
+        n_iter = torch.tensor(4, dtype=torch.int64)
+        x = torch.arange(10, dtype=torch.float32)
+        ep = torch.export.export(
+            model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+        )
+        print(ep)
+
+    Another example with two outputs and a final concatenation on different axes.
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for
+
+
+        class Model(torch.nn.Module):
+            def forward(self, n_iter, x):
+                def body(i, x):
+                    return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0))
+
+                return simple_loop_for(n_iter, body, (x,), (0, 1))
+
+
+        model = Model()
+        n_iter = torch.tensor(4, dtype=torch.int64)
+        x = torch.arange(10, dtype=torch.float32)
+        ep = torch.export.export(
+            model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
+        )
+        print(ep)
+    """
+    res = _simple_loop_for(
+        n_iter,
+        body_fn,
+        operands,
+        concatenation_dims=(
+            (concatenation_dims,)
+            if isinstance(concatenation_dims, int)
+            else concatenation_dims
+        ),
+    )
+    torch._check(
+        isinstance(res, tuple),
+        lambda: f"Output of the loop should be a tuple not {type(res)}.",
+    )
+    return res[0] if len(res) == 1 else res
```
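For a sense of what the new operator computes at runtime, here is a minimal eager-mode sketch derived from the first docstring example above. It is untested here and, per the in-code comment, the non-compiled path relies on `setup_compilation_env` and therefore assumes a recent torch (>=2.10):

```python
import torch
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for


def body(i, x):
    # each iteration returns a slice of growing length, shape (i + 1, 1)
    return (x[: i.item() + 1].unsqueeze(1),)


x = torch.arange(10, dtype=torch.float32)
out = simple_loop_for(torch.tensor(4, dtype=torch.int64), body, (x,))
# per-iteration results are concatenated along dim 0: 1 + 2 + 3 + 4 rows
print(out.shape)  # expected: torch.Size([10, 1])
```

Note that with a single output the function returns the tensor directly rather than a one-element tuple (see the final `return res[0] if len(res) == 1 else res`).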
onnx_diagnostic/export/control_flow_onnx.py

```diff
@@ -55,13 +55,13 @@ def is_exporting() -> bool:
     return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()
 
 
-def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
+def _loop_for_onnx_fn(n_iter, body_fn, concatenation_dims, args):
     """
     Python implementation of the loop.
 
     :param n_iter: number of iteration
     :param body_fn: function implementing the body
-    :param
+    :param concatenation_dims: dimension used to reduce the list produced by the loop
     :param args: arguments to the loop body
     :return: results
     """
@@ -95,7 +95,9 @@ def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
         torch.cat(
             [r[i] for r in res],
             dim=(
-                0
+                0
+                if concatenation_dims is None or i >= len(concatenation_dims)
+                else concatenation_dims[i]
             ),
         )
         for i in range(n_res)
@@ -106,7 +108,7 @@ def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
 def make_custom_loop_for_onnx(
     n_iter: torch.Tensor,
     body_fn: Callable,
-
+    concatenation_dims: Optional[Sequence[int]],
     args: Sequence[torch.Tensor],
     body_gm: Optional[torch.fx.GraphModule] = None,
     body_mutated_inputs: Optional[List[Any]] = None,
@@ -120,7 +122,7 @@ def make_custom_loop_for_onnx(
 
     :param n_iter: number of iterations defined by a tensor of no dimension
     :param body_fn: the loop body defined as a function
-    :param
+    :param concatenation_dims: dimension used to concatenated the results
     :param args: list of tensors, input to the body
     :param body_gm: torch.fx.GraphModule equivalent to *body_gm*
     :param body_mutated_inputs: inputs to *body_gm*
@@ -133,7 +135,7 @@ def make_custom_loop_for_onnx(
     assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
     assert body_outputs is not None, "body_outputs cannot be None"
     srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
-    sred = "x".join(map(str,
+    sred = "x".join(map(str, concatenation_dims)) if concatenation_dims else ""
     full_name = (
         body_fn.__qualname__.replace("<locals>", "L")
         .replace("<lambda>", "l")
@@ -169,14 +171,14 @@ def make_custom_loop_for_onnx(
         custom_def,
         _make_onx,
         (
-            lambda g, sts, outputs, *args, bc=_make_onx, rd=
+            lambda g, sts, outputs, *args, bc=_make_onx, rd=concatenation_dims, name=name: (
                 convert_custom_loop_into_onnx(
                     g,
                     sts,
                     outputs,
                     *args,
                     body_callable=bc,
-
+                    concatenation_dims=rd,
                     name=name,
                 )
             )
@@ -196,7 +198,7 @@ def convert_custom_loop_into_onnx(
     outputs: List[str],
     *args: str,
     body_callable: Callable[..., onnx.ModelProto],
-
+    concatenation_dims: Optional[Sequence[int]] = None,
     name: str = "loop_for_onnx",
 ) -> Union[str, List[str]]:
     """
@@ -207,7 +209,7 @@ def convert_custom_loop_into_onnx(
     :param outputs: output names
     :param args: input argument known at export time
     :param body: GraphProto, the loop body
-    :param
+    :param concatenation_dims: the dimension to follow when aggregating the
         list of tensors after the loop ran
     :param name: to give the onnx nodes a name
     :return: output names
@@ -289,7 +291,11 @@ def convert_custom_loop_into_onnx(
             out,
             outputs=[o],
             name=name,
-            axis=
+            axis=(
+                0
+                if not concatenation_dims or i >= len(concatenation_dims)
+                else concatenation_dims[i]
+            ),
         )
         for i, (out, o) in enumerate(zip(outloop, outputs))
     ]
@@ -337,7 +343,7 @@ def loop_for_onnx(
     n_iter: Union[torch.SymInt, torch.Tensor],
     body_fn: Callable[..., Tuple[torch.Tensor]],
     args: Sequence[torch.Tensor],
-
+    concatenation_dims: Optional[Sequence[int]] = None,
 ) -> Tuple[torch.Tensor, ...]:
     """
     High operators used to easily export a loop in ONNX.
@@ -353,7 +359,7 @@ def loop_for_onnx(
         in a tensor with no dimension, all the others
         are not changed during the loop
     :param args: the available tensors at every loop
-    :param
+    :param concatenation_dims: the loop aggregated the results into list,
         one of each output, each of them is concatenated into one
         tensor along one dimension, by default, it is the first
         dimension, but it can be defined otherwise
@@ -449,7 +455,7 @@ def loop_for_onnx(
         )
         print(ep)
 
-    A last example with ``
+    A last example with ``concatenation_dims``:
 
     .. runpython::
         :showcode:
@@ -465,7 +471,7 @@ def loop_for_onnx(
                 def body(i, x):
                     return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
 
-                two = loop_for_onnx(n_iter, body, (x,),
+                two = loop_for_onnx(n_iter, body, (x,), concatenation_dims=[0, 1])
                 return two[0] + two[1].T
 
 
@@ -516,7 +522,7 @@ def loop_for_onnx(
     name, _custom_ops = make_custom_loop_for_onnx(
         n_iter,
         body_fn,
-
+        concatenation_dims,
         args,
         body_gm=body_gm,
         body_mutated_inputs=body_mutated_inputs,
@@ -525,4 +531,4 @@ def loop_for_onnx(
         fct = getattr(torch.ops.onnx_higher_ops, name)
         return fct(n_iter, *args)
 
-    return _loop_for_onnx_fn(n_iter, body_fn,
+    return _loop_for_onnx_fn(n_iter, body_fn, concatenation_dims, args)
```
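The hunks above only show the fragment of the `loop_for_onnx` docstring that changed. Pieced together from the visible context lines, the updated example using the new `concatenation_dims` argument reads roughly as follows (a sketch; the surrounding export call present in the full docstring is omitted here):

```python
import torch
from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx


class Model(torch.nn.Module):
    def forward(self, n_iter, x):
        def body(i, x):
            # two outputs per iteration, concatenated along dim 0 and dim 1 respectively
            return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1

        two = loop_for_onnx(n_iter, body, (x,), concatenation_dims=[0, 1])
        return two[0] + two[1].T
```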