onnx-diagnostic 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +108 -3
- onnx_diagnostic/ci_models/ci_helpers.py +12 -7
- onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +12 -4
- onnx_diagnostic/export/api.py +295 -5
- onnx_diagnostic/export/cf_simple_loop_for.py +195 -10
- onnx_diagnostic/export/dynamic_shapes.py +45 -3
- onnx_diagnostic/export/shape_helper.py +1 -0
- onnx_diagnostic/ext_test_case.py +9 -2
- onnx_diagnostic/helpers/bench_run.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +0 -8
- onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
- onnx_diagnostic/helpers/helper.py +30 -1
- onnx_diagnostic/helpers/log_helper.py +1 -3
- onnx_diagnostic/helpers/optim_helper.py +116 -0
- onnx_diagnostic/helpers/ort_session.py +5 -0
- onnx_diagnostic/tasks/image_text_to_text.py +19 -9
- onnx_diagnostic/tasks/text2text_generation.py +84 -48
- onnx_diagnostic/tasks/text_generation.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +28 -2
- onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
- onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +12 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +22 -24
- onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
- onnx_diagnostic/torch_models/validate.py +48 -0
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/METADATA +3 -1
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/RECORD +39 -36
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/top_level.txt +0 -0
@@ -8,10 +8,10 @@ Requirements

 ::

     git+https://github.com/sdpython/experimental-experiment.git # optional
-    huggingface_hub
+    huggingface_hub
     onnx-diagnostic>=0.8.6
     onnxruntime>=1.23
-    torch>=2.
+    torch>=2.10 # weekly is better
     tqdm
     transformers>=4.57

@@ -59,6 +59,7 @@ It is possible to overwrite this by by setting environment variable
 import os
 import sys
 import time
+import warnings
 from typing import Any, Dict, List, Tuple
 from .ci_helpers import (
     check_for_discrepancies_and_log_everything_into_a_json_file,
@@ -97,7 +98,6 @@ def get_untrained_model(model_id: str, second_input: bool, verbose: int) -> Dict
         },
         # "_attn_implementation": "flash_attention_2",
         "_attn_implementation": "sdpa",
-        "dtype": "float16",
     }

     config_reduction = _config_reduction
@@ -281,6 +281,10 @@ def main(
         ).eval()
         data = dict(model=model)
         config = model.config
+        if not hasattr(config, "bos_token_id") or not config.bos_token_id:
+            config.bos_token_id = 151643
+        if not hasattr(config, "eos_token_id") or not config.eos_token_id:
+            config.eos_token_id = 151645
     else:
         print("-- random model")
         data = get_untrained_model(model_id, second_input=second_input, verbose=1)
@@ -298,7 +302,11 @@ def main(
     print(f"-- config._attn_implementation={model.config._attn_implementation}")
     print(f"-- model.dtype={model.dtype}")
     print(f"-- model.device={model.device}")
-
+    try:
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+    except OSError as e:
+        warnings.warn(f"Unable to access internet due to {e!r}", ResourceWarning, stacklevel=0)
+        return
     print(f"-- processor={type(processor)}")

     export_inputs, other_inputs = None, None
onnx_diagnostic/export/api.py CHANGED
@@ -1,6 +1,11 @@
-
+import inspect
+import os
+import textwrap
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
 import torch
+from .dynamic_shapes import ModelInputs
 from .onnx_plug import EagerDirectReplacementWithOnnx
+from ..helpers import string_type


 def get_main_dispatcher(
@@ -70,6 +75,7 @@ def to_onnx(
     inline: bool = True,
 ) -> Any:
     """
+    Exports one model into ONNX.
     Common API for exporters. By default, the models are optimized to use the
     most efficient kernels implemented in :epkg:`onnxruntime`.

@@ -126,8 +132,12 @@ def to_onnx(
         from experimental_experiment.xbuilder import OptimizationOptions

         options = None
+        export_options = None
         if exporter_kwargs is not None:
             options = exporter_kwargs.pop("options", None)
+            export_options = exporter_kwargs.pop("export_options", None)
+        if export_options is None:
+            export_options = ExportOptions(save_ep=save_ep)
         if options is None and optimize:
             options = OptimizationOptions(
                 patterns="default+onnxruntime" if optimizer_for_ort else "default"
@@ -138,7 +148,7 @@ def to_onnx(
             else None
         )

-
+        proto, opt_stats = _to_onnx(
             mod,
             args=args,
             kwargs=kwargs,
@@ -150,15 +160,52 @@ def to_onnx(
             dynamic_shapes=dynamic_shapes,
             large_model=True,
             output_dynamic_shapes=output_dynamic_shapes,
-            export_options=
+            export_options=export_options,
             options=options,
             inline=inline,
             dispatcher=main_dispatcher,
+            optimize=optimize,
+            return_optimize_report=True,
             **(exporter_kwargs or {}),
         )
+        if opt_stats and filename and os.path.exists(filename):
+            import pandas
+
+            stat_filename = f"{os.path.splitext(filename)[0]}.opt.xlsx"
+            pattern_stats = []
+            for k, v in opt_stats.items():
+                if "time" in k:
+                    pattern_stats.append(dict(level="main", pattern=k, time_in=v))
+            pattern_stats.extend(
+                [{**obs, "level": "detailed"} for obs in opt_stats["optimization"]]
+            )
+            df = pandas.DataFrame(pattern_stats)
+            df.to_excel(stat_filename, index=False)
+            cols = [
+                c
+                for c in [
+                    "level",
+                    "pattern",
+                    "time_in",
+                    "iteration",
+                    "inlined",
+                    "removed",
+                    "added",
+                    "instances",
+                    "changed",
+                    "scale",
+                ]
+                if c in df.columns
+            ]
+            agg = {k: "sum" for k in cols if k not in ("level", "pattern")}
+            agg.update(dict(iteration="max", instances="mean"))
+            agg = {k: v for k, v in agg.items() if k in df.columns}
+            stat_filename = f"{os.path.splitext(filename)[0]}.opt.agg.xlsx"
+            df[cols].groupby(["level", "pattern"]).agg(agg).to_excel(stat_filename)
+
+        return proto

     if exporter in ("dynamo", "onnx-dynamo"):
-        import os
         from ..helpers import flatten_object
         import onnxscript.rewriter.ort_fusions as ort_fusions

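In the exporter branch built on experimental-experiment (the ``custom`` exporter), ``to_onnx`` now pops ``export_options`` from ``exporter_kwargs`` (falling back to ``ExportOptions(save_ep=save_ep)``), requests the optimization report, and, when ``filename`` points to an existing file, writes per-pattern optimization statistics next to it as ``<name>.opt.xlsx`` and ``<name>.opt.agg.xlsx``. A minimal usage sketch; the toy model, input tensor, and dynamic-shape specification are illustrative assumptions, not taken from the package, and the ``custom`` exporter relies on the optional experimental-experiment dependency listed in the requirements::

    import torch
    from onnx_diagnostic.export.api import to_onnx

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    # with exporter="custom" and a filename, the optimization statistics are also
    # dumped next to the model as tiny.opt.xlsx and tiny.opt.agg.xlsx
    to_onnx(
        TinyModel(),
        args=(torch.randn(2, 4),),
        dynamic_shapes=({0: "batch"},),
        filename="tiny.onnx",
        exporter="custom",
        optimize=True,
    )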
@@ -225,7 +272,6 @@ def to_onnx(
         return epo

     if exporter == "modelbuilder":
-        import os
         from ..helpers import flatten_object, string_type
         from ..helpers.model_builder_helper import create_model_builder, save_model_builder

@@ -266,3 +312,247 @@ def to_onnx(
         return onx

     raise ValueError(f"Unknown exporter={exporter!r}")
+
+
+class _WrapperToExportMethodToOnnx(torch.nn.Module):
+    """
+    Wraps an existing models in order to spy on inputs.
+    This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`.
+    """
+
+    def __init__(
+        self,
+        mod: "torch.nn.Module",
+        method_name: str = "forward",
+        input_names: Optional[Sequence[str]] = None,
+        target_opset: Optional[Union[int, Dict[str, int]]] = None,
+        verbose: int = 0,
+        filename: Optional[str] = None,
+        output_names: Optional[List[str]] = None,
+        output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+        exporter: str = "onnx-dynamo",
+        exporter_kwargs: Optional[Dict[str, Any]] = None,
+        save_ep: Optional[str] = None,
+        optimize: bool = True,
+        optimizer_for_ort: bool = True,
+        use_control_flow_dispatcher: bool = False,
+        onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+        inline: bool = True,
+        convert_after_n_calls: int = 2,
+        patch_kwargs: Optional[Dict[str, Any]] = None,
+        skip_kwargs_names: Optional[Set[str]] = None,
+        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    ):
+        super().__init__()
+        self._model_to_call = mod
+        self._method_name = method_name
+        self._method_call = (
+            self._model_to_call.forward
+            if method_name == "forward"
+            else getattr(mod, method_name)
+        )
+        self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
+        self._convert_after_n_calls = convert_after_n_calls
+        self._patch_kwargs = patch_kwargs
+        self._method_src = None
+        self.verbose = verbose
+        self.skip_kwargs_names = skip_kwargs_names
+        self.dynamic_shapes = dynamic_shapes
+        self._to_onnx_kwargs = dict(
+            input_names=input_names,
+            target_opset=target_opset,
+            verbose=verbose,
+            filename=filename,
+            output_names=output_names,
+            output_dynamic_shapes=output_dynamic_shapes,
+            exporter=exporter,
+            exporter_kwargs=exporter_kwargs,
+            save_ep=save_ep,
+            optimize=optimize,
+            optimizer_for_ort=optimizer_for_ort,
+            use_control_flow_dispatcher=use_control_flow_dispatcher,
+            onnx_plugs=onnx_plugs,
+            inline=inline,
+        )
+        self._export_done = False
+
+    def __str__(self) -> str:
+        return self.__repr__()
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}({self._model_to_call.__class__.__name__}."
+            f"{self._method_name})"
+        )
+
+    def forward(self, *args, **kwargs):
+        if not self._export_done:
+            self._inputs.append(
+                (
+                    args,
+                    (
+                        kwargs
+                        if not kwargs or not self.skip_kwargs_names
+                        else {
+                            k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names
+                        }
+                    ),
+                )
+            )
+            if self.verbose:
+                print(
+                    f"[method_to_onnx] input[{len(self._inputs)-1}]: "
+                    f"{string_type(self._inputs[-1], with_shape=True)}"
+                )
+            if len(self._inputs) >= self._convert_after_n_calls:
+                self._convert_method_to_onnx()
+                del self._inputs[:]
+                self._export_done = True
+        return self._method_call(*args, **kwargs)
+
+    def _convert_method_to_onnx(self):
+
+        def make_method(self):
+            inner_sig = inspect.signature(self._method_call)
+            params = [
+                p.replace(annotation=inspect._empty) for p in inner_sig.parameters.values()
+            ]
+            simple_sig = inspect.Signature(params, return_annotation=inspect._empty)
+            args = str(simple_sig)[1:-1]
+            calls_args = ", ".join(f"{p}={p}" for p in simple_sig.parameters)
+            src = textwrap.dedent(
+                f"""
+                def f(self, {args}):
+                    return self._method_call({calls_args})
+                """
+            )
+            self._method_src = src
+            ns = {}
+            try:
+                exec(src, ns)
+            except NameError as e:
+                raise NameError(f"Unable to compile due to {e}\n{src}") from e
+            return ns["f"]
+
+        class WrapWithExactSignature(torch.nn.Module):
+            def __init__(self, parent):
+                super().__init__()
+                self._model_to_call = parent._model_to_call
+                self._method_call = parent._method_call
+
+            forward = make_method(self)
+
+        compiled_model = WrapWithExactSignature(self)
+
+        if self.dynamic_shapes is None:
+            mi = ModelInputs(compiled_model, self._inputs)
+            ds = mi.guess_dynamic_shapes()
+            if self.verbose:
+                print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}")
+            a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds)
+        else:
+            a, kw = self._inputs[-1]
+            nds = [self.dynamic_shapes]
+        if self.verbose:
+            print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}")
+            print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}")
+            print(f"[method_to_onnx] dynamic_shapes={string_type(nds)}")
+        if self._patch_kwargs is None:
+            to_onnx(
+                compiled_model,
+                args=a,
+                kwargs=kw,
+                dynamic_shapes=nds[-1],
+                **self._to_onnx_kwargs,
+            )
+            return
+        from ..torch_export_patches import torch_export_patches
+
+        with torch_export_patches(**self._patch_kwargs):
+            to_onnx(
+                compiled_model,
+                args=a,
+                kwargs=kw,
+                dynamic_shapes=nds[-1],
+                **self._to_onnx_kwargs,
+            )
+
+
+def method_to_onnx(
+    mod: "torch.nn.Module",
+    method_name: str = "forward",
+    input_names: Optional[Sequence[str]] = None,
+    target_opset: Optional[Union[int, Dict[str, int]]] = None,
+    verbose: int = 0,
+    filename: Optional[str] = None,
+    output_names: Optional[List[str]] = None,
+    output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    exporter: str = "onnx-dynamo",
+    exporter_kwargs: Optional[Dict[str, Any]] = None,
+    save_ep: Optional[str] = None,
+    optimize: bool = True,
+    optimizer_for_ort: bool = True,
+    use_control_flow_dispatcher: bool = False,
+    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+    inline: bool = True,
+    convert_after_n_calls: int = 2,
+    patch_kwargs: Optional[Dict[str, Any]] = None,
+    skip_kwargs_names: Optional[Set[str]] = None,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+) -> Callable:
+    """
+    Exports one method into ONNX for a module into ONNX.
+    It returns a new method which must be called by the user
+    at least twice with different values for the dynamic dimension
+    between triggering the conversion into ONNX.
+
+    :param mod_meth: function to export into ONNX
+    :param input_names: input names for the onnx model (optional)
+    :param target_opset: opset to target, if not specified, each converter
+        keeps its default value
+    :param verbose: verbosity level
+    :param filename: output filename, mandatory, the onnx model is saved on disk
+    :param output_names: to change the output of the onnx model
+    :param output_dynamic_shapes: to overwrite the dynamic shapes names
+    :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
+    :param exporter_kwargs: additional parameters sent to the exporter
+    :param save_ep: saves the exported program
+    :param optimize: optimizes the model
+    :param optimizer_for_ort: optimizes the model for onnxruntime
+    :param use_control_flow_dispatcher: use the dispatcher created to supported
+        custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
+    :param onnx_plugs: the code was modified to replace some parts with onnx translation
+    :param inline: inline local functions
+    :param convert_after_n_calls: converts the model after this number of calls.
+    :param patch_kwargs: patch arguments
+    :param skip_kwargs_names: use default values for these parameters part of
+        the signature of the method to export
+    :param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
+    :return: the output of the selected exporter, usually a structure including
+        an onnx model
+
+    See :ref:`l-plot-tiny-llm-export-method-generate` for an example.
+    """
+    wrapped_model = _WrapperToExportMethodToOnnx(
+        mod=mod,
+        method_name=method_name,
+        input_names=input_names,
+        target_opset=target_opset,
+        verbose=verbose,
+        filename=filename,
+        output_names=output_names,
+        output_dynamic_shapes=output_dynamic_shapes,
+        exporter=exporter,
+        exporter_kwargs=exporter_kwargs,
+        save_ep=save_ep,
+        optimize=optimize,
+        optimizer_for_ort=optimizer_for_ort,
+        use_control_flow_dispatcher=use_control_flow_dispatcher,
+        onnx_plugs=onnx_plugs,
+        inline=inline,
+        convert_after_n_calls=convert_after_n_calls,
+        patch_kwargs=patch_kwargs,
+        skip_kwargs_names=skip_kwargs_names,
+        dynamic_shapes=dynamic_shapes,
+    )
+    return wrapped_model
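``method_to_onnx`` wraps a module so that calls to the chosen method are recorded; after ``convert_after_n_calls`` calls (two by default) the wrapper guesses dynamic shapes from the recorded inputs with ``ModelInputs.guess_dynamic_shapes`` and triggers ``to_onnx``. A hedged sketch of how the wrapper might be used; the toy module and tensor shapes are assumptions, only the keyword arguments come from the signature above::

    import torch
    from onnx_diagnostic.export.api import method_to_onnx

    class Doubler(torch.nn.Module):
        def run(self, x):
            return x * 2

    model = Doubler()
    # the returned wrapper behaves like the original method and spies on its inputs
    wrapped = method_to_onnx(model, method_name="run", filename="doubler.run.onnx", verbose=1)

    # two calls with different values of the dynamic dimension trigger the export
    wrapped(torch.randn(2, 8))
    wrapped(torch.randn(3, 8))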
@@ -11,6 +11,7 @@ from torch._higher_order_ops.utils import (
     unique_graph_id,
     validate_subgraph_args_types,
 )
+import torch._dynamo.variables.higher_order_ops as hop
 from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
 from torch.utils._python_dispatch import _get_current_dispatch_mode

@@ -97,14 +98,18 @@ def _simple_loop_for_fn(
                 f"Unexpected number of results {len(r)} for function {body_fn}, "
                 f"expected {len(res[-1])}"
             )
+            assert all(isinstance(t, torch.Tensor) for t in r), (
+                f"Unexpected type {[type(_) for _ in r]} for returned by function {body_fn}, "
+                f"it must be a tuple of Tensor or a Tensor."
+            )
             res.append(r)
         else:
             assert isinstance(r, torch.Tensor), (
-                f"Unexpected type {r}
-                f"it must be a tuple or a Tensor."
+                f"Unexpected type {type(r)} coming from function {body_fn}, "
+                f"it must be a tuple of Tensor or a Tensor."
             )
             assert not res or len(res[-1]) == 1, (
-                f"Unexpected number of results {len(r)}
+                f"Unexpected number of results {len(r)} coming from function {body_fn}, "
                 f"expected {len(res[-1])}"
             )
             res.append((r,))
@@ -126,8 +131,6 @@ def _simple_loop_for_fn(
     )


-# from torch._functorch.utils import exposed_in
-# @exposed_in("torch")
 def _simple_loop_for(
     n_iter: Union[int, torch.Tensor],
     body_fn: Callable,
@@ -159,7 +162,7 @@ def _simple_loop_for(

     if torch.compiler.is_dynamo_compiling():
         return simple_loop_for_op(
-            n_iter, body_fn,
+            n_iter, body_fn, operands, concatenation_dims=concatenation_dims
         )

     if isinstance(n_iter, (bool, int, float)):
@@ -181,8 +184,10 @@ def _simple_loop_for(

     with setup_compilation_env() as _backend:
         return _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims)
-    #
-    #
+    # This is needed to support function body using module weights or function body
+    # defined as a class method. This is yet to be implemented.
+    # cpl = torch.compile(_loop_for_op_wrapper, backend=_backend, fullgraph=True)
+    # return cpl(n_iter, body_fn, operands, concatenation_dims)


 def trace_simple_loop_for(
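The comment preserved above spells out the motivation for the rest of this file's changes: compiling the wrapper would be needed for loop bodies that use module weights or are bound methods, and since that path is not implemented yet, a dedicated Dynamo variable (``SimpleLoopForHigherOrderVariable``, further below) teaches ``torch.compile`` how to speculate the loop body instead. A heavily hedged sketch of the operator's eager contract as it can be read off these hunks; the public signature and the body's calling convention (iteration value first, then the operands) are inferences from the internal calls, not documented facts::

    import torch
    from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for

    def body(i, x):
        # assumption: the body receives the iteration value and the operands and
        # must return a torch.Tensor or a tuple of Tensors (enforced by the asserts
        # in _simple_loop_for_fn); results are collected over the iterations
        return x * (i + 1)

    x = torch.arange(6.0).reshape(2, 3)
    # assumption: n_iter, body_fn and the operands tuple are the first positional
    # arguments, mirroring the internal _simple_loop_for / simple_loop_for_op calls
    out = simple_loop_for(3, body, (x,))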
@@ -236,9 +241,15 @@ def loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None):
     )
     mode = _get_current_dispatch_mode()
     assert mode is None, "Mode should never be enabled for CPU/CUDA key"
-
-
+    is_fake = isinstance(n_iter, torch._subclasses.fake_tensor.FakeTensor)
+    res = _simple_loop_for_fn(n_iter, body_fn, operands, concatenation_dims=concatenation_dims)
+    assert is_fake or not any(
+        isinstance(r, torch._subclasses.fake_tensor.FakeTensor) for r in res
+    ), (
+        f"One result is a fake tensor but the inputs were not, type(n_iter)={type(n_iter)}, "
+        f"operands: {[type(_) for _ in operands]}, res: {[type(_) for _ in res]}"
     )
+    return res


 @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
@@ -267,6 +278,180 @@ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
 simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)


+class SimpleLoopForHigherOrderVariable(hop.TorchHigherOrderOperatorVariable):
+    """
+    Replicates the same pattern found for other higher order operators.
+    This enables recursive compilation and the use of modules inside a function.
+    """
+
+    _HOP_NAME = "simple_loop_for"
+    _ALLOW_FALLBACK_TO_EAGER = False
+    supports_input_mutation = False
+    supports_aliasing = False
+
+    def _call_function(
+        self,
+        tx: torch._dynamo.symbolic_convert.InstructionTranslator,
+        args: list[hop.VariableTracker],
+        kwargs: dict[str, hop.VariableTracker],
+    ) -> hop.VariableTracker:
+        """Main function."""
+        args, kwargs = hop.LazyVariableTracker.realize_all((args, kwargs))
+
+        for i, k in enumerate(["n_iter", "body_fn", "operands", "concatenated_dims"]):
+            if v := kwargs.pop(k, None):
+                assert i == len(args), "did not provide the right number of non-keyword args"
+                args.append(v)
+
+        if len(args) != 4 or kwargs:
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper args/kwargs",
+                context=f"args: {args}, kwargs: {kwargs}",
+                explanation=f"torch.cond expects 4 positional arguments (got {len(args)}) "
+                f"and no keyword arguments (got {len(kwargs)})",
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        # Specialize into one of the branches since pred is constant
+        n_iter, body_fn, operands, _concatenated_dims = args
+        assert type(n_iter) is not hop.ConstantVariable, (
+            f"n_iter is a {type(n_iter)}. When used simple_loop_for, "
+            f"it unrolls the loop. A SymInt should be used."
+        )
+
+        # predicate
+        if type(n_iter.realize()) not in (
+            hop.ConstantVariable,
+            hop.TensorVariable,
+            hop.SymNodeVariable,
+        ):
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper predicate",
+                context=str(n_iter),
+                explanation=(
+                    f"Expected `n_iter` to be an int or a integer "
+                    f"tensor with a single item "
+                    f"but got {str(type(n_iter))} with original python type "
+                    f"{str(n_iter.python_type())}."
+                ),
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        # operands
+        if not isinstance(operands, (hop.ListVariable, hop.TupleVariable)):
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper operands",
+                context=str(operands),
+                explanation="Expected `operands` to be a list/tuple "
+                f"but got {operands.python_type()}.",
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        operands_seq = operands.unpack_var_sequence(tx)
+        if not hop.only_consist_of(
+            operands, (hop.TensorVariable, hop.ConstantVariable, hop.SymNodeVariable)
+        ):
+            hop.unimplemented(
+                gb_type="simple_loop_for: improper operands contents",
+                context=str(operands),
+                explanation=(
+                    "Expected `operands` to be a list/tuple of pytrees "
+                    "that only consists of tensor leaves."
+                ),
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        # branches
+        hop._check_supported_callable_arg(tx, body_fn, "body_fn")
+
+        def speculate_body():
+            (
+                (ret_val, ret_spec),
+                ret_graph,
+                ret_lifted_freevars,
+            ) = hop.speculate_subgraph(
+                tx,
+                args[1],
+                (args[0], *operands_seq),
+                {},
+                self._HOP_NAME,
+                source_target=self.value,
+                should_flatten_outputs=True,
+                # TODO - removing consts from control flow ops need more work
+                remove_consts_from_outputs=False,
+                supports_input_mutation=self.supports_input_mutation,
+                supports_aliasing=self.supports_aliasing,
+            )
+
+            # need to ensure we increase epoch so we don't memoize unbacked bindings
+            # across different subgraphs which can interfere with runtime assertion
+            # generation.
+            tx.fake_mode.epoch += 1
+
+            if not hop.only_consist_of(ret_val, (hop.TensorVariable, hop.ConstantVariable)):
+                hop.unimplemented(
+                    gb_type="simple_loop_for: unsupported branch return type",
+                    context=str(ret_val),
+                    explanation=(
+                        "Expected branches to return a possibly nested "
+                        "pytree of tensors or constant ints."
+                    ),
+                    hints=[*hop.graph_break_hints.USER_ERROR],
+                )
+            for ret in ret_val.unpack_var_sequence(tx):
+                if ret.is_python_constant() and not isinstance(ret.as_python_constant(), int):
+                    hop.unimplemented(
+                        gb_type=(
+                            "simple_loop_for: unsupported branch return type "
+                            "(constant non-int)"
+                        ),
+                        context=str(ret_val),
+                        explanation="Constants returned from branches must be ints.",
+                        hints=[*hop.graph_break_hints.USER_ERROR],
+                    )
+            return ret_val, ret_spec, ret_graph, ret_lifted_freevars
+
+        body_r, body_spec, body_graph, body_lifted_freevars = speculate_body()
+        body_nn_modules = dict(tx.output.nn_modules)
+
+        same_spec = body_spec.treespec.as_python_constant()
+        if same_spec is not NotImplemented and not same_spec:
+            hop.unimplemented(
+                gb_type="simple_loop_for: differing branch outputs",
+                context=(
+                    f"body_spec: {body_spec.treespec}, false_spec: "
+                    f"{body_spec.treespec}, same_spec: {same_spec}"
+                ),
+                explanation="Expected branches to return the same pytree structure.",
+                hints=[*hop.graph_break_hints.USER_ERROR],
+            )
+
+        body_name = tx.output.install_subgraph(
+            "loop_body", torch.fx.GraphModule(body_nn_modules, body_graph)
+        )
+        body_node = hop.make_attr(tx, body_name)
+        p_args = (
+            n_iter.as_proxy(),
+            body_node,
+            # We pick true_shared but it shouldn't matter
+            operands.as_proxy() + tuple(body_lifted_freevars.keys()),
+        )
+
+        return hop._call_function_and_unflatten_output(
+            tx,
+            simple_loop_for,
+            p_args,
+            {},
+            None,
+            body_spec,
+            body_r,
+        )
+
+
+hop._hop_name_to_variable_class["simple_loop_for"] = SimpleLoopForHigherOrderVariable
+
+
+# @torch._functorch.utils.exposed_in("torch")
 def simple_loop_for(
     n_iter: Union[int, torch.Tensor],
     body_fn: Callable,