onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +91 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +3 -3
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +92 -23
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +90 -26
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +103 -1
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +103 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +43 -24
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/control_flow.py

@@ -1,42 +1,18 @@
 import contextlib
-import
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
-import onnx
-import onnx.helper as oh
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 import torch
-from torch._higher_order_ops.utils import
-
-
+from torch._higher_order_ops.utils import (
+    materialize_as_graph,
+    check_input_alias_and_mutation_return_outputs,
+    # _maybe_reenter_make_fx,
+)

 _TEST_EXPORT = False
-_REGISTERED_SCHEMA = {}  # type: ignore[var-annotated]
-_DISPATCHER = None
-
-
-def create_global_dispatcher():
-    global _DISPATCHER
-
-    if not _DISPATCHER:
-        from experimental_experiment.torch_interpreter import Dispatcher
-
-        class ControlFlowDispatcher(Dispatcher):
-            def __init__(self):
-                super().__init__({})
-
-            def register(self, aten_name: str, converter: Callable):
-                assert aten_name not in self.registered_functions, (
-                    f"Name {aten_name!r} is already registered in "
-                    f"{sorted(self.registered_functions)}"
-                )
-                self.registered_functions[aten_name] = converter
-
-        _DISPATCHER = ControlFlowDispatcher()
-    return _DISPATCHER


 @contextlib.contextmanager
 def enable_code_export_control_flow():
-    """Enables the code
+    """Enables the code meant to be exported."""
     global _TEST_EXPORT
     old = _TEST_EXPORT
     _TEST_EXPORT = True
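The hunk above truncates `enable_code_export_control_flow` right after the flag is raised. A minimal sketch of how such a toggle usually completes (the `try`/`finally` restore is an assumption; the diff does not show it):

import contextlib

_TEST_EXPORT = False


@contextlib.contextmanager
def enable_code_export_control_flow():
    """Enables the code meant to be exported."""
    global _TEST_EXPORT
    old = _TEST_EXPORT
    _TEST_EXPORT = True
    try:
        yield  # code inside the with-block sees _TEST_EXPORT = True
    finally:
        _TEST_EXPORT = old  # assumed restore on exit

Used as `with enable_code_export_control_flow(): ...` around the code paths that should take the export-friendly branch.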
@@ -128,194 +104,31 @@ def make_custom_loop_for(
     :return: a name and the custom op definition, the name
         is used to cache the custom op
     """
-    global _DISPATCHER
     assert body_gm is not None, "body_gm cannot be None"
     assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
     assert body_outputs is not None, "body_outputs cannot be None"
+
     srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
     sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
-
-
-
-
-
-
+    full_name = (
+        body_fn.__qualname__.replace("<locals>", "L")
+        .replace("<lambda>", "l")
+        .replace(".", "_")
+    )
+    name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
+
+    schema = "(str body_fn, Tensor n_iter, Tensor[] body_inputs) -> Tensor"
     if len(body_outputs) > 1:
         schema += "[]"
-    custom_def = torch.library.CustomOpDef("onnx_higher_ops",
+    custom_def = torch.library.CustomOpDef("onnx_higher_ops", "loop_for", schema, body_fn)
     custom_def.register_kernel("cpu")(body_fn)

-    custom_def._abstract_fn = lambda *_args, _o=body_outputs: (
+    custom_def._abstract_fn = lambda _fn_id, *_args, _o=body_outputs: (
         tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0])
     )
-
-    def _make_onx(
-        body_gm=body_gm, args=args, target_opset=None, verbose=0, exporter_kwargs=None
-    ):
-        return convert_into_onnx(
-            body_gm,
-            args,
-            exporter_kwargs=exporter_kwargs,
-            target_opset=target_opset,
-            verbose=verbose,
-        )
-
-    to_register = (
-        custom_def,
-        _make_onx,
-        (
-            lambda g, sts, outputs, *args, bc=_make_onx, rd=reduction_dim, name=name: (
-                convert_custom_loop_into_onnx(
-                    g,
-                    sts,
-                    outputs,
-                    *args,
-                    body_callable=bc,
-                    reduction_dim=rd,
-                    name=name,
-                )
-            )
-        ),
-    )
-    if _DISPATCHER is None:
-        create_global_dispatcher()
-    assert _DISPATCHER
-    _DISPATCHER.register(f"onnx_higher_ops::{name}", to_register[-1])
-    _REGISTERED_SCHEMA[name] = to_register
     return name, custom_def


|
-
def convert_custom_loop_into_onnx(
|
|
189
|
-
g: Any, # "GreaphBuilder"
|
|
190
|
-
sts: Dict[str, Any],
|
|
191
|
-
outputs: List[str],
|
|
192
|
-
*args: str,
|
|
193
|
-
body_callable: Callable[..., onnx.ModelProto],
|
|
194
|
-
reduction_dim: Optional[Sequence[int]] = None,
|
|
195
|
-
name: str = "loop_for",
|
|
196
|
-
) -> Union[str, List[str]]:
|
|
197
|
-
"""
|
|
198
|
-
Converts a custom op ``higher_ops::loop_for...`` into e sequence of node.
|
|
199
|
-
|
|
200
|
-
:param g: GreaphBuilder
|
|
201
|
-
:param sts: if not defined, torch does not know the output shapes
|
|
202
|
-
:param outputs: output names
|
|
203
|
-
:param args: input argument known at export time
|
|
204
|
-
:param body: GraphProto, the loop body
|
|
205
|
-
:param reduction_dim: the dimension to follow when aggregating the
|
|
206
|
-
list of tensors after the loop ran
|
|
207
|
-
:param name: to give the onnx nodes a name
|
|
208
|
-
:return: output names
|
|
209
|
-
"""
|
|
210
|
-
assert body_callable is not None, "body_callable cannot be None"
|
|
211
|
-
# This should be part of a public API.
|
|
212
|
-
body = body_callable(
|
|
213
|
-
target_opset=g.main_opset,
|
|
214
|
-
verbose=g.verbose,
|
|
215
|
-
exporter_kwargs={"options": g.optimization_options},
|
|
216
|
-
)
|
|
217
|
-
|
|
218
|
-
graph = body.graph if isinstance(body, onnx.ModelProto) else body
|
|
219
|
-
assert isinstance(
|
|
220
|
-
graph, onnx.GraphProto
|
|
221
|
-
), f"Unexpected type {type(body)} for body{g.get_debug_msg()}"
|
|
222
|
-
assert len(outputs) == 1, f"Only one outputs is expected but outputs={outputs!r}"
|
|
223
|
-
if len(graph.output) != 1:
|
|
224
|
-
outputs = [f"{outputs[0]}#{i}" for i in range(len(graph.output))]
|
|
225
|
-
input_names = [i.name for i in graph.input]
|
|
226
|
-
inputs = [
|
|
227
|
-
*graph.input[:1],
|
|
228
|
-
oh.make_tensor_value_info("cond_unused", onnx.TensorProto.BOOL, []),
|
|
229
|
-
*[
|
|
230
|
-
oh.make_tensor_sequence_value_info(
|
|
231
|
-
f"loop_in{i}", graph.output[i].type.tensor_type.elem_type, None
|
|
232
|
-
)
|
|
233
|
-
for i in range(len(graph.output))
|
|
234
|
-
],
|
|
235
|
-
# hidden inputs are not added
|
|
236
|
-
]
|
|
237
|
-
nodes = [
|
|
238
|
-
oh.make_node("Identity", ["cond_unused"], ["cond_out"]),
|
|
239
|
-
*[oh.make_node("Identity", [a], [r]) for a, r in zip(args[1:], input_names[1:])],
|
|
240
|
-
*graph.node,
|
|
241
|
-
*[
|
|
242
|
-
oh.make_node(
|
|
243
|
-
"SequenceInsert",
|
|
244
|
-
[f"loop_in{i}", graph.output[i].name],
|
|
245
|
-
[f"loop_out{i}"],
|
|
246
|
-
)
|
|
247
|
-
for i in range(len(graph.output))
|
|
248
|
-
],
|
|
249
|
-
]
|
|
250
|
-
graph_outputs = [
|
|
251
|
-
oh.make_tensor_value_info("cond_out", onnx.TensorProto.BOOL, []),
|
|
252
|
-
*[
|
|
253
|
-
oh.make_tensor_sequence_value_info(
|
|
254
|
-
f"loop_out{i}", graph.output[i].type.tensor_type.elem_type, None
|
|
255
|
-
)
|
|
256
|
-
for i in range(len(graph.output))
|
|
257
|
-
],
|
|
258
|
-
]
|
|
259
|
-
graph = oh.make_graph(
|
|
260
|
-
nodes, graph.name, inputs, graph_outputs, graph.initializer, graph.sparse_initializer
|
|
261
|
-
)
|
|
262
|
-
|
|
263
|
-
sequences = [g.op.SequenceEmpty() for _ in outputs]
|
|
264
|
-
|
|
265
|
-
outloop = [g.unique_name(f"loop_for{i}") for i in range(len(sequences))]
|
|
266
|
-
|
|
267
|
-
for i, s in enumerate(sequences):
|
|
268
|
-
g.set_sequence(s, graph.output[i].type.tensor_type.elem_type)
|
|
269
|
-
g.make_node("Loop", [args[0], "", *sequences], outloop, name=name, body=graph)
|
|
270
|
-
for i, o in enumerate(outloop):
|
|
271
|
-
g.set_sequence(o, graph.output[i].type.tensor_type.elem_type)
|
|
272
|
-
_res = [
|
|
273
|
-
g.op.ConcatFromSequence(
|
|
274
|
-
out,
|
|
275
|
-
outputs=[o],
|
|
276
|
-
name=name,
|
|
277
|
-
axis=0 if not reduction_dim or i >= len(reduction_dim) else reduction_dim[i],
|
|
278
|
-
)
|
|
279
|
-
for i, (out, o) in enumerate(zip(outloop, outputs))
|
|
280
|
-
]
|
|
281
|
-
if not sts:
|
|
282
|
-
for i, o in enumerate(outputs):
|
|
283
|
-
g.set_type(o, graph.output[i].type.tensor_type.elem_type)
|
|
284
|
-
g.set_rank(o, len(graph.output[i].type.tensor_type.shape.dims))
|
|
285
|
-
return outputs if len(outputs) > 1 else outputs[0]
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
def convert_into_onnx(
|
|
289
|
-
body_gm: torch.fx.GraphModule,
|
|
290
|
-
args: Sequence[torch.Tensor],
|
|
291
|
-
target_opset: Optional[int] = None,
|
|
292
|
-
verbose: int = 0,
|
|
293
|
-
exporter_kwargs: Optional[Dict[str, Any]] = None,
|
|
294
|
-
) -> onnx.ModelProto:
|
|
295
|
-
"""
|
|
296
|
-
Converts a torch.fx.GraphModule into ONNX.
|
|
297
|
-
It returns a ModelProto.
|
|
298
|
-
|
|
299
|
-
:param body_gm: a torch.fx.GraphModule
|
|
300
|
-
:param args: arguments known at export time
|
|
301
|
-
:param target_opset: targeted opset
|
|
302
|
-
:param verbose: verbosity level
|
|
303
|
-
:param exporter_kwargs: additional exporter arguments
|
|
304
|
-
:return: a ModelProto
|
|
305
|
-
"""
|
|
306
|
-
# This does not work with onnx-dynamo.
|
|
307
|
-
# opset still needs to be defined
|
|
308
|
-
container = to_onnx(
|
|
309
|
-
body_gm,
|
|
310
|
-
args,
|
|
311
|
-
exporter="custom",
|
|
312
|
-
exporter_kwargs=exporter_kwargs,
|
|
313
|
-
target_opset=target_opset,
|
|
314
|
-
verbose=verbose,
|
|
315
|
-
)
|
|
316
|
-
return container.model_proto
|
|
317
|
-
|
|
318
|
-
|
|
319
132
|
def loop_for(
|
|
320
133
|
n_iter: Union[torch.SymInt, torch.Tensor],
|
|
321
134
|
body_fn: Callable[..., Tuple[torch.Tensor]],
|
|
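The deleted `convert_custom_loop_into_onnx` is still the clearest statement of the ONNX idiom this feature relies on: start with `SequenceEmpty`, run a `Loop` whose body appends each iteration's tensor with `SequenceInsert`, then merge with `ConcatFromSequence` along the reduction axis. A self-contained sketch of that idiom, independent of the library (a toy loop collecting `[0., 1., 2., 3.]`):

import numpy as np
import onnx
import onnx.helper as oh
import onnxruntime

# Loop body: (iteration i, incoming condition, carried sequence) ->
# (condition, sequence with one more element).
body = oh.make_graph(
    [
        oh.make_node("Identity", ["cond_in"], ["cond_out"]),
        oh.make_node("Cast", ["i"], ["i_f"], to=onnx.TensorProto.FLOAT),
        oh.make_node("Unsqueeze", ["i_f", "axes"], ["row"]),
        oh.make_node("SequenceInsert", ["seq_in", "row"], ["seq_out"]),
    ],
    "body",
    [
        oh.make_tensor_value_info("i", onnx.TensorProto.INT64, []),
        oh.make_tensor_value_info("cond_in", onnx.TensorProto.BOOL, []),
        oh.make_tensor_sequence_value_info("seq_in", onnx.TensorProto.FLOAT, None),
    ],
    [
        oh.make_tensor_value_info("cond_out", onnx.TensorProto.BOOL, []),
        oh.make_tensor_sequence_value_info("seq_out", onnx.TensorProto.FLOAT, None),
    ],
    [oh.make_tensor("axes", onnx.TensorProto.INT64, [1], [0])],
)

graph = oh.make_graph(
    [
        oh.make_node("SequenceEmpty", [], ["seq0"], dtype=onnx.TensorProto.FLOAT),
        # "" skips the optional condition input, exactly as the deleted code did
        oh.make_node("Loop", ["n_iter", "", "seq0"], ["seq_final"], body=body),
        oh.make_node("ConcatFromSequence", ["seq_final"], ["y"], axis=0),
    ],
    "loop_for_sketch",
    [oh.make_tensor_value_info("n_iter", onnx.TensorProto.INT64, [])],
    [oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [None])],
)
model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])
onnx.checker.check_model(model)

sess = onnxruntime.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)
print(sess.run(None, {"n_iter": np.array(4, dtype=np.int64)})[0])  # [0. 1. 2. 3.]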
@@ -340,144 +153,6 @@ def loop_for(
         one of each output, each of them is concatenated into one
         tensor along one dimension, by default, it is the first
         dimension, but it can be defined otherwise
-
-    .. runpython::
-        :showcode:
-
-        import torch
-        import onnxruntime
-        from onnx_diagnostic.export.api import to_onnx
-        from onnx_diagnostic.export.control_flow import loop_for
-
-
-        class Model(torch.nn.Module):
-            def forward(self, n_iter, x):
-                def body(i, x):
-                    return x[: i.item() + 1].unsqueeze(1)
-
-                return loop_for(n_iter, body, (x,))
-
-
-        model = Model()
-        n_iter = torch.tensor(4, dtype=torch.int64)
-        x = torch.arange(10, dtype=torch.float32)
-        expected = model(n_iter, x)
-        print("expected:", expected)
-
-        onx = to_onnx(
-            model,
-            (n_iter, x),
-            dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
-            exporter="custom",
-            use_control_flow_dispatcher=True,
-        ).model_proto
-
-        sess = onnxruntime.InferenceSession(
-            onx.SerializeToString(), providers=["CPUExecutionProvider"]
-        )
-        got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
-        print("got:", got)
-
-
-        # The loop is exported as a custom op.
-        ep = torch.export.export(
-            model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
-        )
-        print(ep)
-
-    Another example with two outputs:
-
-    .. runpython::
-        :showcode:
-
-        import torch
-        import onnxruntime
-        from onnx_diagnostic.export.api import to_onnx
-        from onnx_diagnostic.export.control_flow import loop_for
-
-
-        class Model(torch.nn.Module):
-            def forward(self, n_iter, x):
-                def body(i, x):
-                    return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
-
-                two = loop_for(n_iter, body, (x,))
-                return two[0] + two[1]
-
-
-        model = Model()
-        n_iter = torch.tensor(4, dtype=torch.int64)
-        x = torch.arange(10, dtype=torch.float32)
-        expected = model(n_iter, x)
-        print("expected:", expected)
-
-        onx = to_onnx(
-            model,
-            (n_iter, x),
-            dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
-            exporter="custom",
-            use_control_flow_dispatcher=True,
-        ).model_proto
-
-        sess = onnxruntime.InferenceSession(
-            onx.SerializeToString(), providers=["CPUExecutionProvider"]
-        )
-        got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
-        print("got:", got)
-
-
-        # The loop is exported as a custom op.
-        ep = torch.export.export(
-            model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
-        )
-        print(ep)
-
-    A last example with ``reduction_dim``:
-
-    .. runpython::
-        :showcode:
-
-        import torch
-        import onnxruntime
-        from onnx_diagnostic.export.api import to_onnx
-        from onnx_diagnostic.export.control_flow import loop_for
-
-
-        class Model(torch.nn.Module):
-            def forward(self, n_iter, x):
-                def body(i, x):
-                    return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
-
-                two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
-                return two[0] + two[1].T
-
-
-        model = Model()
-        n_iter = torch.tensor(4, dtype=torch.int64)
-        x = torch.arange(10, dtype=torch.float32)
-        expected = model(n_iter, x)
-        print("expected:", expected)
-
-        onx = to_onnx(
-            model,
-            (n_iter, x),
-            dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
-            exporter="custom",
-            use_control_flow_dispatcher=True,
-        ).model_proto
-
-        sess = onnxruntime.InferenceSession(
-            onx.SerializeToString(), providers=["CPUExecutionProvider"]
-        )
-        got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
-        print("got:", got)
-
-
-        # The loop is exported as a custom op.
-        ep = torch.export.export(
-            model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
-        )
-        print(ep)
     """
     assert args, "The function should have at least one arg."
     assert (
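The removed examples also documented `reduction_dim`: each output of the body is collected across iterations and concatenated along dim 0 unless `reduction_dim` says otherwise. A pure-eager sketch of those semantics, written independently of `_loop_for_fn` (which the library actually calls):

import torch


def loop_for_eager(n_iter, body_fn, args, reduction_dim=None):
    # run the body once per iteration, collecting its outputs
    results = [body_fn(torch.tensor(i), *args) for i in range(int(n_iter))]
    if not isinstance(results[0], tuple):
        return torch.cat(results, dim=reduction_dim[0] if reduction_dim else 0)
    # one concat per output; same axis rule as the deleted converter:
    # axis=0 if not reduction_dim or i >= len(reduction_dim) else reduction_dim[i]
    return tuple(
        torch.cat(
            col, dim=0 if not reduction_dim or j >= len(reduction_dim) else reduction_dim[j]
        )
        for j, col in enumerate(zip(*results))
    )


x = torch.arange(10, dtype=torch.float32)
print(loop_for_eager(torch.tensor(4), lambda i, x: x[: i.item() + 1].unsqueeze(1), (x,)))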
@@ -486,6 +161,12 @@ def loop_for(
         and len(n_iter.shape) == 0
     ), f"Only a tensor for one int64 is allowed for n_iter but it equal to {n_iter}."
     if is_exporting():
+        from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER
+
+        # tracer = _CURRENT_MAKE_FX_TRACER.fx_tracer
+        root = _CURRENT_MAKE_FX_TRACER.fx_tracer.root
+        # graph = _CURRENT_MAKE_FX_TRACER.fx_tracer.graph
+
         body_gm: torch.fx.GraphModule = materialize_as_graph(
             body_fn, (torch.tensor(0, dtype=torch.int64), *args)
         )
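`materialize_as_graph` turns the Python body into a `torch.fx.GraphModule` by tracing it on example inputs, here `(torch.tensor(0, dtype=torch.int64), *args)`. A rough stand-in with the public tracer, only to show the kind of object involved (`torch.fx.symbolic_trace` is not what the library calls):

import torch


def body(i, x):
    # a loop body: iteration counter and one carried tensor in, one tensor out
    return (x * i).unsqueeze(0)


gm = torch.fx.symbolic_trace(body)  # a torch.fx.GraphModule, like body_gm above
print(gm.graph)  # the recorded mul/unsqueeze ops
print(gm(torch.tensor(2), torch.arange(3.0)))  # still callable eagerly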
@@ -505,7 +186,29 @@ def loop_for(
             body_mutated_inputs=body_mutated_inputs,
             body_outputs=body_outputs,
         )
-
-
+        root.register_module(name, body_gm)
+        # body_graph = _maybe_reenter_make_fx(body_fn)(n_iter, *args)
+        return torch.ops.onnx_higher_ops.loop_for(name, n_iter, args)

     return _loop_for_fn(n_iter, body_fn, reduction_dim, args)
+
+
+"""
+proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
+proxy_mode.tracer.root.register_module(body_graph_name, body_graph)
+
+args = (cond_graph, body_graph, carried_inputs, additional_inputs)
+
+proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
+
+out_proxy = proxy_mode.tracer.create_proxy(
+    "call_function", op, proxy_args, {}, name=op._name
+)
+
+out = op(
+    cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
+)
+return track_tensor_tree(
+    out, out_proxy, constant=None, tracer=proxy_mode.tracer
+)
+"""
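The quoted block at the end is reference code from torch's higher-order-op tracing, kept around as a string comment. The design point of this hunk is the string indirection: the custom op's schema only accepts `str body_fn`, so the traced body is parked on the tracer root under `name` and has to be resolved again later. A minimal illustration of that indirection (names are hypothetical, not the library's):

import torch

# name -> traced body; plays the role of root.register_module(name, body_gm)
_BODY_REGISTRY: dict = {}


def register_body(name: str, gm: torch.fx.GraphModule) -> str:
    _BODY_REGISTRY[name] = gm
    return name


def resolve_body(name: str) -> torch.fx.GraphModule:
    # the ONNX converter receives only the string and looks the body up again
    return _BODY_REGISTRY[name]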