onnx-diagnostic 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/ci_models/export_phi4_mm.py +1 -1
- onnx_diagnostic/doc.py +258 -8
- onnx_diagnostic/export/api.py +755 -5
- onnx_diagnostic/export/dynamic_shapes.py +61 -4
- onnx_diagnostic/export/shape_helper.py +1 -8
- onnx_diagnostic/helpers/cache_helper.py +98 -21
- onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
- onnx_diagnostic/helpers/helper.py +36 -6
- onnx_diagnostic/helpers/onnx_helper.py +7 -0
- onnx_diagnostic/helpers/ort_session.py +5 -0
- onnx_diagnostic/helpers/rt_helper.py +14 -1
- onnx_diagnostic/helpers/torch_helper.py +22 -9
- onnx_diagnostic/tasks/image_text_to_text.py +8 -5
- onnx_diagnostic/tasks/text_generation.py +17 -17
- onnx_diagnostic/torch_export_patches/eval/__init__.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +62 -38
- onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -9
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +42 -30
- onnx_diagnostic/torch_models/validate.py +48 -0
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/METADATA +3 -1
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/RECORD +28 -28
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/api.py
CHANGED
|
@@ -1,6 +1,18 @@
|
|
|
1
|
-
|
|
1
|
+
import inspect
|
|
2
|
+
import os
|
|
3
|
+
import textwrap
|
|
4
|
+
import time
|
|
5
|
+
from collections.abc import Mapping, Iterable
|
|
6
|
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
|
|
2
7
|
import torch
|
|
8
|
+
from .dynamic_shapes import ModelInputs
|
|
3
9
|
from .onnx_plug import EagerDirectReplacementWithOnnx
|
|
10
|
+
from ..helpers import flatten_object, max_diff, string_diff, string_type
|
|
11
|
+
from ..helpers.cache_helper import CacheKeyValue
|
|
12
|
+
from ..helpers.torch_helper import torch_deepcopy
|
|
13
|
+
from ..helpers.rt_helper import make_feeds
|
|
14
|
+
from ..helpers.onnx_helper import pretty_onnx
|
|
15
|
+
from ..reference import OnnxruntimeEvaluator
|
|
4
16
|
|
|
5
17
|
|
|
6
18
|
def get_main_dispatcher(
|
|
@@ -70,6 +82,7 @@ def to_onnx(
|
|
|
70
82
|
inline: bool = True,
|
|
71
83
|
) -> Any:
|
|
72
84
|
"""
|
|
85
|
+
Exports one model into ONNX.
|
|
73
86
|
Common API for exporters. By default, the models are optimized to use the
|
|
74
87
|
most efficient kernels implemented in :epkg:`onnxruntime`.
|
|
75
88
|
|
|
@@ -126,8 +139,12 @@ def to_onnx(
|
|
|
126
139
|
from experimental_experiment.xbuilder import OptimizationOptions
|
|
127
140
|
|
|
128
141
|
options = None
|
|
142
|
+
export_options = None
|
|
129
143
|
if exporter_kwargs is not None:
|
|
130
144
|
options = exporter_kwargs.pop("options", None)
|
|
145
|
+
export_options = exporter_kwargs.pop("export_options", None)
|
|
146
|
+
if export_options is None:
|
|
147
|
+
export_options = ExportOptions(save_ep=save_ep)
|
|
131
148
|
if options is None and optimize:
|
|
132
149
|
options = OptimizationOptions(
|
|
133
150
|
patterns="default+onnxruntime" if optimizer_for_ort else "default"
|
|
@@ -138,7 +155,7 @@ def to_onnx(
|
|
|
138
155
|
else None
|
|
139
156
|
)
|
|
140
157
|
|
|
141
|
-
|
|
158
|
+
proto, opt_stats = _to_onnx(
|
|
142
159
|
mod,
|
|
143
160
|
args=args,
|
|
144
161
|
kwargs=kwargs,
|
|
@@ -150,16 +167,52 @@ def to_onnx(
|
|
|
150
167
|
dynamic_shapes=dynamic_shapes,
|
|
151
168
|
large_model=True,
|
|
152
169
|
output_dynamic_shapes=output_dynamic_shapes,
|
|
153
|
-
export_options=
|
|
170
|
+
export_options=export_options,
|
|
154
171
|
options=options,
|
|
155
172
|
inline=inline,
|
|
156
173
|
dispatcher=main_dispatcher,
|
|
157
174
|
optimize=optimize,
|
|
175
|
+
return_optimize_report=True,
|
|
158
176
|
**(exporter_kwargs or {}),
|
|
159
177
|
)
|
|
178
|
+
if opt_stats and filename and os.path.exists(filename):
|
|
179
|
+
import pandas
|
|
180
|
+
|
|
181
|
+
stat_filename = f"{os.path.splitext(filename)[0]}.opt.xlsx"
|
|
182
|
+
pattern_stats = []
|
|
183
|
+
for k, v in opt_stats.items():
|
|
184
|
+
if "time" in k:
|
|
185
|
+
pattern_stats.append(dict(level="main", pattern=k, time_in=v))
|
|
186
|
+
pattern_stats.extend(
|
|
187
|
+
[{**obs, "level": "detailed"} for obs in opt_stats["optimization"]]
|
|
188
|
+
)
|
|
189
|
+
df = pandas.DataFrame(pattern_stats)
|
|
190
|
+
df.to_excel(stat_filename, index=False)
|
|
191
|
+
cols = [
|
|
192
|
+
c
|
|
193
|
+
for c in [
|
|
194
|
+
"level",
|
|
195
|
+
"pattern",
|
|
196
|
+
"time_in",
|
|
197
|
+
"iteration",
|
|
198
|
+
"inlined",
|
|
199
|
+
"removed",
|
|
200
|
+
"added",
|
|
201
|
+
"instances",
|
|
202
|
+
"changed",
|
|
203
|
+
"scale",
|
|
204
|
+
]
|
|
205
|
+
if c in df.columns
|
|
206
|
+
]
|
|
207
|
+
agg = {k: "sum" for k in cols if k not in ("level", "pattern")}
|
|
208
|
+
agg.update(dict(iteration="max", instances="mean"))
|
|
209
|
+
agg = {k: v for k, v in agg.items() if k in df.columns}
|
|
210
|
+
stat_filename = f"{os.path.splitext(filename)[0]}.opt.agg.xlsx"
|
|
211
|
+
df[cols].groupby(["level", "pattern"]).agg(agg).to_excel(stat_filename)
|
|
212
|
+
|
|
213
|
+
return proto
|
|
160
214
|
|
|
161
215
|
if exporter in ("dynamo", "onnx-dynamo"):
|
|
162
|
-
import os
|
|
163
216
|
from ..helpers import flatten_object
|
|
164
217
|
import onnxscript.rewriter.ort_fusions as ort_fusions
|
|
165
218
|
|
|
@@ -226,7 +279,6 @@ def to_onnx(
|
|
|
226
279
|
return epo
|
|
227
280
|
|
|
228
281
|
if exporter == "modelbuilder":
|
|
229
|
-
import os
|
|
230
282
|
from ..helpers import flatten_object, string_type
|
|
231
283
|
from ..helpers.model_builder_helper import create_model_builder, save_model_builder
|
|
232
284
|
|
|
@@ -267,3 +319,701 @@ def to_onnx(
|
|
|
267
319
|
return onx
|
|
268
320
|
|
|
269
321
|
raise ValueError(f"Unknown exporter={exporter!r}")
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
class WrapperToExportMethodToOnnx(torch.nn.Module):
|
|
325
|
+
"""
|
|
326
|
+
Wraps an existing models in order to spy on inputs.
|
|
327
|
+
This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`
|
|
328
|
+
or :ref:`l-plot-tiny-llm-export-method-generate` for an example.
|
|
329
|
+
"""
|
|
330
|
+
|
|
331
|
+
def __init__(
|
|
332
|
+
self,
|
|
333
|
+
mod: "torch.nn.Module",
|
|
334
|
+
method_name: str = "forward",
|
|
335
|
+
input_names: Optional[Sequence[str]] = None,
|
|
336
|
+
target_opset: Optional[Union[int, Dict[str, int]]] = None,
|
|
337
|
+
verbose: int = 0,
|
|
338
|
+
filename: Optional[str] = None,
|
|
339
|
+
output_names: Optional[List[str]] = None,
|
|
340
|
+
output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
|
341
|
+
exporter: str = "onnx-dynamo",
|
|
342
|
+
exporter_kwargs: Optional[Dict[str, Any]] = None,
|
|
343
|
+
save_ep: Optional[str] = None,
|
|
344
|
+
optimize: bool = True,
|
|
345
|
+
optimizer_for_ort: bool = True,
|
|
346
|
+
use_control_flow_dispatcher: bool = False,
|
|
347
|
+
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
|
|
348
|
+
inline: bool = True,
|
|
349
|
+
convert_after_n_calls: int = 2,
|
|
350
|
+
patch_kwargs: Optional[Dict[str, Any]] = None,
|
|
351
|
+
skip_kwargs_names: Optional[Set[str]] = None,
|
|
352
|
+
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
|
353
|
+
dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
|
|
354
|
+
expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
|
|
355
|
+
):
|
|
356
|
+
super().__init__()
|
|
357
|
+
self._model_to_call = mod
|
|
358
|
+
self._method_name = method_name
|
|
359
|
+
self._method_call = (
|
|
360
|
+
self._model_to_call.forward
|
|
361
|
+
if method_name == "forward"
|
|
362
|
+
else getattr(mod, method_name)
|
|
363
|
+
)
|
|
364
|
+
self._signature = inspect.signature(self._method_call)
|
|
365
|
+
self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
|
|
366
|
+
self._outputs: List[Any] = []
|
|
367
|
+
self._convert_after_n_calls = convert_after_n_calls
|
|
368
|
+
self._patch_kwargs = patch_kwargs
|
|
369
|
+
self._method_src = None
|
|
370
|
+
self.verbose = verbose
|
|
371
|
+
self.skip_kwargs_names = skip_kwargs_names
|
|
372
|
+
self.dynamic_shapes = dynamic_shapes
|
|
373
|
+
self.expand_batch_for = expand_batch_for
|
|
374
|
+
self.dynamic_batch_for = dynamic_batch_for
|
|
375
|
+
self._to_onnx_kwargs = dict(
|
|
376
|
+
input_names=input_names,
|
|
377
|
+
target_opset=target_opset,
|
|
378
|
+
verbose=verbose,
|
|
379
|
+
filename=filename,
|
|
380
|
+
output_names=output_names,
|
|
381
|
+
output_dynamic_shapes=output_dynamic_shapes,
|
|
382
|
+
exporter=exporter,
|
|
383
|
+
exporter_kwargs=exporter_kwargs,
|
|
384
|
+
save_ep=save_ep,
|
|
385
|
+
optimize=optimize,
|
|
386
|
+
optimizer_for_ort=optimizer_for_ort,
|
|
387
|
+
use_control_flow_dispatcher=use_control_flow_dispatcher,
|
|
388
|
+
onnx_plugs=onnx_plugs,
|
|
389
|
+
inline=inline,
|
|
390
|
+
)
|
|
391
|
+
self._export_done = False
|
|
392
|
+
self._serialization_classes: Set[type] = set()
|
|
393
|
+
|
|
394
|
+
def __str__(self) -> str:
|
|
395
|
+
"usual"
|
|
396
|
+
return self.__repr__()
|
|
397
|
+
|
|
398
|
+
def __repr__(self) -> str:
|
|
399
|
+
"usual"
|
|
400
|
+
return (
|
|
401
|
+
f"{self.__class__.__name__}({self._model_to_call.__class__.__name__}."
|
|
402
|
+
f"{self._method_name})"
|
|
403
|
+
)
|
|
404
|
+
|
|
405
|
+
def _collect_classes(self, obj):
|
|
406
|
+
if obj is None or isinstance(obj, torch.Tensor):
|
|
407
|
+
return
|
|
408
|
+
cls = type(obj)
|
|
409
|
+
if cls.__module__ not in ("builtins",):
|
|
410
|
+
self._serialization_classes.add(cls)
|
|
411
|
+
if hasattr(obj, "__dict__"):
|
|
412
|
+
for v in vars(obj).values():
|
|
413
|
+
self._collect_classes(v)
|
|
414
|
+
return
|
|
415
|
+
if isinstance(obj, Mapping):
|
|
416
|
+
for v in obj.values():
|
|
417
|
+
self._collect_classes(v)
|
|
418
|
+
return
|
|
419
|
+
if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
|
|
420
|
+
for v in obj:
|
|
421
|
+
self._collect_classes(v)
|
|
422
|
+
return
|
|
423
|
+
|
|
424
|
+
def _reorder_kwargs(self, kwargs):
|
|
425
|
+
new_kwargs = {k: kwargs[k] for k in self._signature.parameters if k in kwargs}
|
|
426
|
+
for k, v in kwargs.items():
|
|
427
|
+
if k not in new_kwargs:
|
|
428
|
+
new_kwargs[k] = v
|
|
429
|
+
return new_kwargs
|
|
430
|
+
|
|
431
|
+
def forward(self, *args, **kwargs):
|
|
432
|
+
if not self._export_done:
|
|
433
|
+
inp_args = args
|
|
434
|
+
# filters out the inputs not desired, int, float, bool, None
|
|
435
|
+
# are considered as constant for the exporter, they are removed
|
|
436
|
+
# from the named arguments.
|
|
437
|
+
inp_kwargs = (
|
|
438
|
+
kwargs
|
|
439
|
+
if not kwargs
|
|
440
|
+
else {
|
|
441
|
+
k: v
|
|
442
|
+
for k, v in kwargs.items()
|
|
443
|
+
if v is not None
|
|
444
|
+
and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
|
|
445
|
+
and not isinstance(v, (bool, int, float))
|
|
446
|
+
}
|
|
447
|
+
)
|
|
448
|
+
if self.expand_batch_for:
|
|
449
|
+
# extends the inputs to artificially create a batch dimension != 1.
|
|
450
|
+
inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
|
|
451
|
+
inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
|
|
452
|
+
inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
|
|
453
|
+
# reorders the parameter following the method signature.
|
|
454
|
+
inp_kwargs = self._reorder_kwargs(inp_kwargs)
|
|
455
|
+
# stores the inputs
|
|
456
|
+
self._inputs.append((inp_args, inp_kwargs))
|
|
457
|
+
|
|
458
|
+
if self.verbose:
|
|
459
|
+
print(
|
|
460
|
+
f"[method_to_onnx] input[{len(self._inputs)-1}]: "
|
|
461
|
+
f"{string_type(self._inputs[-1], with_shape=True)}"
|
|
462
|
+
)
|
|
463
|
+
|
|
464
|
+
if len(self._inputs) >= self._convert_after_n_calls:
|
|
465
|
+
# conversion starts after _convert_after_n_calls calls to the forward method
|
|
466
|
+
name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
|
|
467
|
+
input_file = f"{name}.inputs.pt"
|
|
468
|
+
self._input_file = input_file
|
|
469
|
+
if self.verbose:
|
|
470
|
+
print(
|
|
471
|
+
f"[method_to_onnx] save {len(self._inputs)} inputs in {input_file!r}"
|
|
472
|
+
)
|
|
473
|
+
torch.save(self._inputs, input_file)
|
|
474
|
+
self._convert_method_to_onnx()
|
|
475
|
+
self._export_done = True
|
|
476
|
+
|
|
477
|
+
# calls the inner method (no change here)
|
|
478
|
+
begin = time.perf_counter()
|
|
479
|
+
res = self._method_call(*args, **kwargs)
|
|
480
|
+
duration = time.perf_counter() - begin
|
|
481
|
+
self._collect_classes([args, kwargs, res])
|
|
482
|
+
if self._inputs:
|
|
483
|
+
# stores the outputs if discrepancies need to be checked
|
|
484
|
+
self._outputs.append((torch_deepcopy(res), duration))
|
|
485
|
+
assert len(self._inputs) == len(self._outputs), (
|
|
486
|
+
f"Number of inputs {len(self._inputs)} and "
|
|
487
|
+
f"outputs {len(self._outputs)} are different."
|
|
488
|
+
)
|
|
489
|
+
if self._export_done:
|
|
490
|
+
name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
|
|
491
|
+
output_file = f"{name}.outputs.pt"
|
|
492
|
+
if self.verbose:
|
|
493
|
+
print(
|
|
494
|
+
f"[method_to_onnx] save {len(self._outputs)} "
|
|
495
|
+
f"outputs in {output_file!r}"
|
|
496
|
+
)
|
|
497
|
+
torch.save(self._outputs, output_file)
|
|
498
|
+
self._output_file = output_file
|
|
499
|
+
del self._inputs[:]
|
|
500
|
+
del self._outputs[:]
|
|
501
|
+
return res
|
|
502
|
+
|
|
503
|
+
def _convert_method_to_onnx(self):
|
|
504
|
+
for args, kwargs in self._inputs:
|
|
505
|
+
self._serialization_classes |= {type(a) for a in args}
|
|
506
|
+
self._serialization_classes |= {type(a) for a in kwargs.values()}
|
|
507
|
+
|
|
508
|
+
def make_method(self):
|
|
509
|
+
inner_sig = inspect.signature(self._method_call)
|
|
510
|
+
params = [
|
|
511
|
+
p.replace(annotation=inspect._empty) for p in inner_sig.parameters.values()
|
|
512
|
+
]
|
|
513
|
+
simple_sig = inspect.Signature(params, return_annotation=inspect._empty)
|
|
514
|
+
args = str(simple_sig)[1:-1]
|
|
515
|
+
calls_args = ", ".join(f"{p}={p}" for p in simple_sig.parameters)
|
|
516
|
+
src = textwrap.dedent(
|
|
517
|
+
f"""
|
|
518
|
+
def f(self, {args}):
|
|
519
|
+
return self._method_call({calls_args})
|
|
520
|
+
"""
|
|
521
|
+
)
|
|
522
|
+
self._method_src = src
|
|
523
|
+
ns = {}
|
|
524
|
+
try:
|
|
525
|
+
exec(src, ns)
|
|
526
|
+
except NameError as e:
|
|
527
|
+
raise NameError(f"Unable to compile due to {e}\n{src}") from e
|
|
528
|
+
return ns["f"]
|
|
529
|
+
|
|
530
|
+
class WrapWithExactSignature(torch.nn.Module):
|
|
531
|
+
def __init__(self, parent):
|
|
532
|
+
super().__init__()
|
|
533
|
+
self._model_to_call = parent._model_to_call
|
|
534
|
+
self._method_call = parent._method_call
|
|
535
|
+
|
|
536
|
+
forward = make_method(self)
|
|
537
|
+
|
|
538
|
+
compiled_model = WrapWithExactSignature(self)
|
|
539
|
+
|
|
540
|
+
if self.dynamic_shapes is None:
|
|
541
|
+
mi = ModelInputs(compiled_model, self._inputs)
|
|
542
|
+
ds = mi.guess_dynamic_shapes()
|
|
543
|
+
if self.verbose:
|
|
544
|
+
print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}")
|
|
545
|
+
a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds)
|
|
546
|
+
if self.dynamic_batch_for:
|
|
547
|
+
nds = (
|
|
548
|
+
self._dynamic_batch_dimension(nds[0], self.dynamic_batch_for),
|
|
549
|
+
self.rename_dynamic_shapes(
|
|
550
|
+
self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
|
|
551
|
+
verbose=self.verbose,
|
|
552
|
+
),
|
|
553
|
+
)
|
|
554
|
+
if self.verbose:
|
|
555
|
+
print(f"[method_to_onnx] dynamic_batch_for={self.dynamic_batch_for}")
|
|
556
|
+
print(f"[method_to_onnx] dynamic_shapes with batch={nds}")
|
|
557
|
+
else:
|
|
558
|
+
a, kw = self._inputs[-1]
|
|
559
|
+
nds = [self.dynamic_shapes]
|
|
560
|
+
if self.verbose:
|
|
561
|
+
print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}")
|
|
562
|
+
print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}")
|
|
563
|
+
print(f"[method_to_onnx] dynamic_shapes={string_type(nds)}")
|
|
564
|
+
if self._patch_kwargs is None:
|
|
565
|
+
to_onnx(
|
|
566
|
+
compiled_model,
|
|
567
|
+
args=a,
|
|
568
|
+
kwargs=kw,
|
|
569
|
+
dynamic_shapes=nds[-1],
|
|
570
|
+
**self._to_onnx_kwargs,
|
|
571
|
+
)
|
|
572
|
+
return
|
|
573
|
+
from ..torch_export_patches import torch_export_patches
|
|
574
|
+
|
|
575
|
+
with torch_export_patches(**self._patch_kwargs):
|
|
576
|
+
to_onnx(
|
|
577
|
+
compiled_model,
|
|
578
|
+
args=a,
|
|
579
|
+
kwargs=kw,
|
|
580
|
+
dynamic_shapes=nds[-1],
|
|
581
|
+
**self._to_onnx_kwargs,
|
|
582
|
+
)
|
|
583
|
+
|
|
584
|
+
@classmethod
|
|
585
|
+
def make_empty_cache_from_others(cls, examples: List[Any]) -> Any:
|
|
586
|
+
"""Builds an empty cache based on existing one."""
|
|
587
|
+
unique_types = {type(t) for t in examples}
|
|
588
|
+
assert (
|
|
589
|
+
len(unique_types) == 1
|
|
590
|
+
), f"Unable to guess an empty cache from {string_type(examples, with_shape=True)}"
|
|
591
|
+
unique_type = unique_types.pop()
|
|
592
|
+
if unique_type == torch.Tensor:
|
|
593
|
+
shapes = [t.shape for t in examples]
|
|
594
|
+
assert len(set(shapes)) > 1, f"Unable to guess an empty shape from shapes {shapes}"
|
|
595
|
+
ranks = {len(s) for s in shapes}
|
|
596
|
+
assert len(ranks) == 1, f"Ranks are different in {shapes}"
|
|
597
|
+
rank = ranks.pop()
|
|
598
|
+
new_shape = []
|
|
599
|
+
for i in range(rank):
|
|
600
|
+
dims = [t.shape[i] for t in examples]
|
|
601
|
+
if len(set(dims)) == 1:
|
|
602
|
+
new_shape.append(dims[0])
|
|
603
|
+
else:
|
|
604
|
+
# The empty shape
|
|
605
|
+
new_shape.append(0)
|
|
606
|
+
example = examples[0]
|
|
607
|
+
return torch.empty(tuple(new_shape), dtype=example.dtype, device=example.device)
|
|
608
|
+
assert (
|
|
609
|
+
unique_type.__name__ == "DynamicCache"
|
|
610
|
+
), f"This is not implemented for class {unique_type}"
|
|
611
|
+
caches = [CacheKeyValue(dc) for dc in examples]
|
|
612
|
+
caches_list = [dc.aslist() for dc in caches]
|
|
613
|
+
empty = [
|
|
614
|
+
cls.make_empty_cache_from_others([caches_list[i][k] for i in range(len(examples))])
|
|
615
|
+
for k in range(len(caches_list[0]))
|
|
616
|
+
]
|
|
617
|
+
empty_cache = CacheKeyValue(
|
|
618
|
+
empty, cls_layers=caches[0].cls_layers
|
|
619
|
+
).make_dynamic_cache()
|
|
620
|
+
return empty_cache
|
|
621
|
+
|
|
622
|
+
@classmethod
|
|
623
|
+
def add_empty_cache_if_needed(cls, inputs: List[Any]) -> List[Any]:
|
|
624
|
+
"""
|
|
625
|
+
Adds empty cache if needed as onnxruntime needs an empty cache,
|
|
626
|
+
not a missing cache. It only works if inputs are defined as a dictionary.
|
|
627
|
+
"""
|
|
628
|
+
if all(isinstance(t, tuple) for t in inputs) and all(
|
|
629
|
+
len(t) == 2 and isinstance(t[0], tuple) and isinstance(t[1], dict) and not t[0]
|
|
630
|
+
for t in inputs
|
|
631
|
+
):
|
|
632
|
+
dict_part = [t[1] for t in inputs]
|
|
633
|
+
res = cls.add_empty_cache_if_needed(dict_part)
|
|
634
|
+
return [(tuple(), d) for d in res]
|
|
635
|
+
if any(not isinstance(t, dict) for t in inputs):
|
|
636
|
+
return inputs
|
|
637
|
+
all_keys = set()
|
|
638
|
+
for input_set in inputs:
|
|
639
|
+
all_keys |= set(input_set)
|
|
640
|
+
# even though the inputs are defined as a dictionary, it is better
|
|
641
|
+
# to keep the same order
|
|
642
|
+
ordered = None
|
|
643
|
+
for input_set in inputs:
|
|
644
|
+
if set(input_set) == all_keys:
|
|
645
|
+
ordered = list(input_set)
|
|
646
|
+
break
|
|
647
|
+
new_inputs = []
|
|
648
|
+
for input_set in inputs:
|
|
649
|
+
if set(input_set) == all_keys:
|
|
650
|
+
new_inputs.append(input_set)
|
|
651
|
+
continue
|
|
652
|
+
missing = {k for k in all_keys if k not in input_set}
|
|
653
|
+
input_set_copy = input_set.copy()
|
|
654
|
+
for miss in missing:
|
|
655
|
+
input_set_copy[miss] = cls.make_empty_cache_from_others(
|
|
656
|
+
[sub[miss] for sub in inputs if miss in sub]
|
|
657
|
+
)
|
|
658
|
+
new_inputs.append({k: input_set_copy[k] for k in ordered}) # type: ignore[union-attr]
|
|
659
|
+
return new_inputs
|
|
660
|
+
|
|
661
|
+
@classmethod
|
|
662
|
+
def _expand_batch_dimension(cls, obj: Any, expand_for: Sequence[Union[int, str]]) -> Any:
|
|
663
|
+
expand_for_args = {i for i in expand_for if isinstance(i, int)}
|
|
664
|
+
expand_for_kwargs = {i for i in expand_for if isinstance(i, str)}
|
|
665
|
+
if isinstance(obj, tuple):
|
|
666
|
+
return tuple(
|
|
667
|
+
o if i not in expand_for_args else cls._expand_batch_dimension_input(o, i)
|
|
668
|
+
for i, o in enumerate(obj)
|
|
669
|
+
)
|
|
670
|
+
assert isinstance(obj, dict), f"Unexpected type {type(obj)}"
|
|
671
|
+
return {
|
|
672
|
+
k: v if k not in expand_for_kwargs else cls._expand_batch_dimension_input(v, k)
|
|
673
|
+
for k, v in obj.items()
|
|
674
|
+
}
|
|
675
|
+
|
|
676
|
+
@classmethod
|
|
677
|
+
def _expand_batch_dimension_input(cls, obj: Any, msg: Union[str, int]) -> Any:
|
|
678
|
+
if isinstance(obj, torch.Tensor):
|
|
679
|
+
assert obj.shape[0] == 1, (
|
|
680
|
+
f"Are you sure to expoand input {msg!r}, "
|
|
681
|
+
f"batch size is not 1 and shape={obj.shape}"
|
|
682
|
+
)
|
|
683
|
+
sizes = [2, *obj.shape[1:]]
|
|
684
|
+
return obj.expand(*sizes)
|
|
685
|
+
if isinstance(obj, list):
|
|
686
|
+
return [
|
|
687
|
+
cls._expand_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(obj)
|
|
688
|
+
]
|
|
689
|
+
if obj.__class__.__name__ == "DynamicCache":
|
|
690
|
+
dc = CacheKeyValue(obj)
|
|
691
|
+
flat = dc.aslist()
|
|
692
|
+
flat = cls._expand_batch_dimension_input(flat, msg)
|
|
693
|
+
return CacheKeyValue(flat, cls_layers=dc.cls_layers).make_dynamic_cache()
|
|
694
|
+
# This might end up in an infinite loop if no registration is done.
|
|
695
|
+
flat, _spec = torch.utils._pytree.tree_flatten(obj)
|
|
696
|
+
assert (
|
|
697
|
+
not isinstance(flat, list) or len(flat) != 1 or type(flat[0]) is not type(obj)
|
|
698
|
+
), f"class {type(obj)} was is not registered for serialization."
|
|
699
|
+
flat = cls._expand_batch_dimension_input(flat, msg)
|
|
700
|
+
return torch.utils._pytree.tree_unflatten(flat, _spec)
|
|
701
|
+
|
|
702
|
+
@classmethod
|
|
703
|
+
def _dynamic_batch_dimension(
|
|
704
|
+
cls, ds: Union[Tuple[Any, ...], Dict[str, Any]], dynamic_for: Sequence[Union[int, str]]
|
|
705
|
+
) -> Union[Tuple[Any, ...], Dict[str, Any]]:
|
|
706
|
+
if isinstance(ds, tuple):
|
|
707
|
+
return tuple(
|
|
708
|
+
(v if i not in dynamic_for else cls._dynamic_batch_dimension_input(v, i))
|
|
709
|
+
for i, v in enumerate(ds)
|
|
710
|
+
)
|
|
711
|
+
return {
|
|
712
|
+
k: (v if k not in dynamic_for else cls._dynamic_batch_dimension_input(v, k))
|
|
713
|
+
for k, v in ds.items()
|
|
714
|
+
}
|
|
715
|
+
|
|
716
|
+
@classmethod
|
|
717
|
+
def _dynamic_batch_dimension_input(cls, ds: Any, msg: Union[str, int]) -> Any:
|
|
718
|
+
if isinstance(ds, dict) and all(isinstance(k, int) for k in ds):
|
|
719
|
+
ds[0] = "batch"
|
|
720
|
+
return {k: v for k, v in sorted(ds.items())} # noqa: C416
|
|
721
|
+
if isinstance(ds, list):
|
|
722
|
+
return [
|
|
723
|
+
cls._dynamic_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(ds)
|
|
724
|
+
]
|
|
725
|
+
raise NotImplementedError(f"cannot make first dimension dynamic for batch for {ds}")
|
|
726
|
+
|
|
727
|
+
def check_discrepancies(
|
|
728
|
+
self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0
|
|
729
|
+
) -> List[Dict[str, Union[str, int, float]]]:
|
|
730
|
+
"""
|
|
731
|
+
Computes the discrepancies between the saved inputs and outputs
|
|
732
|
+
with the saved onnx model.
|
|
733
|
+
|
|
734
|
+
:param atol: absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16
|
|
735
|
+
:param rtol: relative tolerance
|
|
736
|
+
:param hist: thresholds, the function determines the number of discrepancies
|
|
737
|
+
above that threshold.
|
|
738
|
+
:param verbose: verbosity
|
|
739
|
+
:return: results, a list of dictionaries, ready to be consumed by a dataframe
|
|
740
|
+
"""
|
|
741
|
+
assert self._export_done, "The onnx export was not done."
|
|
742
|
+
assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
|
|
743
|
+
assert os.path.exists(
|
|
744
|
+
self._output_file
|
|
745
|
+
), f"output file {self._output_file!r} not found"
|
|
746
|
+
filename = self._to_onnx_kwargs["filename"]
|
|
747
|
+
assert isinstance(filename, str) and os.path.exists(
|
|
748
|
+
filename
|
|
749
|
+
), f"onnx file {filename!r} not found"
|
|
750
|
+
classes = [
|
|
751
|
+
cls
|
|
752
|
+
for cls in self._serialization_classes
|
|
753
|
+
if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
|
|
754
|
+
]
|
|
755
|
+
if verbose:
|
|
756
|
+
print(f"[method_to_onnx.check_discrepancies] register classes {classes}")
|
|
757
|
+
print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}")
|
|
758
|
+
with torch.serialization.safe_globals(classes):
|
|
759
|
+
inputs = torch.load(self._input_file)
|
|
760
|
+
if verbose:
|
|
761
|
+
print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}")
|
|
762
|
+
with torch.serialization.safe_globals(classes):
|
|
763
|
+
outputs = torch.load(self._output_file)
|
|
764
|
+
assert len(inputs) == len(outputs), (
|
|
765
|
+
f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, "
|
|
766
|
+
f"inputs={string_type(inputs, with_shape=True)}, "
|
|
767
|
+
f"outputs={string_type(outputs, with_shape=True)}"
|
|
768
|
+
)
|
|
769
|
+
if verbose:
|
|
770
|
+
print(f"[method_to_onnx.check_discrepancies] create onnx session {filename!r}")
|
|
771
|
+
sess = OnnxruntimeEvaluator(filename, whole=True)
|
|
772
|
+
input_names = sess.input_names
|
|
773
|
+
if verbose:
|
|
774
|
+
print(f"[method_to_onnx.check_discrepancies] input_names={input_names}")
|
|
775
|
+
print(
|
|
776
|
+
f"[method_to_onnx.check_discrepancies] onnx_shapes="
|
|
777
|
+
f"{', '.join(pretty_onnx(i) for i in sess.input_types)}"
|
|
778
|
+
)
|
|
779
|
+
data = []
|
|
780
|
+
for i, (input, (output, latency)) in enumerate(
|
|
781
|
+
zip(self.add_empty_cache_if_needed(inputs), outputs)
|
|
782
|
+
):
|
|
783
|
+
if verbose:
|
|
784
|
+
if verbose > 1:
|
|
785
|
+
print(
|
|
786
|
+
f"[method_to_onnx.check_discrepancies] process input {i}: "
|
|
787
|
+
f"{string_type(input, with_shape=True)}"
|
|
788
|
+
)
|
|
789
|
+
print(
|
|
790
|
+
f"[method_to_onnx.check_discrepancies] expects: "
|
|
791
|
+
f"{string_type(output, with_shape=True)}"
|
|
792
|
+
)
|
|
793
|
+
else:
|
|
794
|
+
print(
|
|
795
|
+
f"[method_to_onnx.check_discrepancies] process input {i} "
|
|
796
|
+
f"#args={len(input[0])} #kwargs={len(input[1])}"
|
|
797
|
+
)
|
|
798
|
+
|
|
799
|
+
flat_inputs = flatten_object(input, drop_keys=True)
|
|
800
|
+
if verbose > 1:
|
|
801
|
+
print(
|
|
802
|
+
f"[method_to_onnx.check_discrepancies] "
|
|
803
|
+
f"input={string_type(input, with_shape=True)}"
|
|
804
|
+
)
|
|
805
|
+
print(
|
|
806
|
+
f"[method_to_onnx.check_discrepancies] "
|
|
807
|
+
f"flat_inputs={string_type(flat_inputs, with_shape=True)}"
|
|
808
|
+
)
|
|
809
|
+
if len(flat_inputs) < len(input_names):
|
|
810
|
+
# not implemented yet, it is caused by a missing cache,
|
|
811
|
+
# which requires an empty cache instead
|
|
812
|
+
data.append(dict(index=i, duration_torch=latency, n_inputs=len(flat_inputs)))
|
|
813
|
+
continue
|
|
814
|
+
assert len(flat_inputs) == len(input_names), (
|
|
815
|
+
f"Length mismatch, expecting {len(input_names)} onnx inputs and got "
|
|
816
|
+
f"{len(flat_inputs)} flat torch inputs"
|
|
817
|
+
)
|
|
818
|
+
feeds = make_feeds(input_names, flat_inputs)
|
|
819
|
+
if verbose > 1:
|
|
820
|
+
print(
|
|
821
|
+
f"[method_to_onnx.check_discrepancies] "
|
|
822
|
+
f"feeds={string_type(feeds, with_shape=True)}"
|
|
823
|
+
)
|
|
824
|
+
begin = time.perf_counter()
|
|
825
|
+
ort_outputs = sess.run(None, feeds)
|
|
826
|
+
duration = time.perf_counter() - begin
|
|
827
|
+
diff = max_diff(output, ort_outputs, hist=hist)
|
|
828
|
+
if "rep" in diff and isinstance(diff["rep"], dict):
|
|
829
|
+
diff.update(diff["rep"])
|
|
830
|
+
del diff["rep"]
|
|
831
|
+
diff["SUCCESS"] = (
|
|
832
|
+
isinstance(diff["abs"], float)
|
|
833
|
+
and isinstance(diff["rel"], float)
|
|
834
|
+
and diff["abs"] < atol
|
|
835
|
+
and diff["rel"] < rtol
|
|
836
|
+
)
|
|
837
|
+
diff.update(
|
|
838
|
+
dict(
|
|
839
|
+
index=i,
|
|
840
|
+
duration_torch=latency,
|
|
841
|
+
ort_duration=duration,
|
|
842
|
+
n_inputs=len(flat_inputs),
|
|
843
|
+
)
|
|
844
|
+
)
|
|
845
|
+
if verbose > 1:
|
|
846
|
+
print(
|
|
847
|
+
f"[method_to_onnx.check_discrepancies] ort output "
|
|
848
|
+
f"{string_type(ort_outputs, with_shape=True)}"
|
|
849
|
+
)
|
|
850
|
+
print(f"[method_to_onnx.check_discrepancies] diff {string_diff(diff)}")
|
|
851
|
+
data.append(diff)
|
|
852
|
+
if verbose:
|
|
853
|
+
print("[method_to_onnx.check_discrepancies] done")
|
|
854
|
+
return data
|
|
855
|
+
|
|
856
|
+
@classmethod
|
|
857
|
+
def _apply_known_shape_pattern(
|
|
858
|
+
cls, shape: Dict[int, Any], pattern: Dict[int, str]
|
|
859
|
+
) -> Dict[int, Any]:
|
|
860
|
+
return {k: pattern.get(k, v) for k, v in shape.items()}
|
|
861
|
+
|
|
862
|
+
@classmethod
|
|
863
|
+
def get_dynamic_shape_patterns(cls) -> Dict[str, Any]:
|
|
864
|
+
"""
|
|
865
|
+
Returns the known patterns for the dynamic shapes.
|
|
866
|
+
|
|
867
|
+
.. runpython::
|
|
868
|
+
:showcode:
|
|
869
|
+
|
|
870
|
+
import pprint
|
|
871
|
+
from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx
|
|
872
|
+
pprint.pprint(WrapperToExportMethodToOnnx.get_dynamic_shape_patterns())
|
|
873
|
+
"""
|
|
874
|
+
return {
|
|
875
|
+
"LLM.text": {
|
|
876
|
+
"cache_position": {0: "seqlength"},
|
|
877
|
+
"past_key_values": {0: "batch", 2: "pastlength"},
|
|
878
|
+
"input_ids": {0: "batch", 1: "seqlength"},
|
|
879
|
+
"attention_mask": {0: "batch", 1: "totallength"}, # pastlength+seqlength
|
|
880
|
+
}
|
|
881
|
+
}
|
|
882
|
+
|
|
883
|
+
@classmethod
def rename_dynamic_shapes(cls, ds: Dict[str, Any], verbose: int = 0) -> Dict[str, Any]:
    """
    Renames the dynamic shapes with names.

    Tries to rename any dynamic dimension before export.
    It is not very clever, it just tries to recognize a known
    configuration based on input names.
    Dimension names in dynamic shapes are renamed if *ds* has
    the same number of named arguments as the one of the patterns
    returned by function :meth:`get_dynamic_shape_patterns
    <onnx_diagnostic.export.api.WrapperToExportMethodToOnnx.get_dynamic_shape_patterns>`.
    """

    def _is_axis_dict(obj: Any) -> bool:
        # A single tensor shape is a dict mapping integer axes to dimensions.
        return isinstance(obj, dict) and all(isinstance(axis, int) for axis in obj)

    for pattern_name, pattern_shape in cls.get_dynamic_shape_patterns().items():
        # A pattern applies only when all of its input names appear in ds.
        if len(set(ds) & set(pattern_shape)) != len(pattern_shape):
            continue
        if verbose:
            print(
                f"[method_to_onnx.rename_dynamic_shapes] "
                f"apply pattern shapes {pattern_name!r}"
            )
        renamed: Dict[str, Any] = {}
        for name, value in ds.items():
            if name not in pattern_shape:
                renamed[name] = value
            elif _is_axis_dict(value):
                # A shape.
                renamed[name] = cls._apply_known_shape_pattern(value, pattern_shape[name])
            elif isinstance(value, list):
                # A cache: rename every element which is a shape, keep the rest.
                renamed[name] = [
                    (
                        cls._apply_known_shape_pattern(item, pattern_shape[name])
                        if _is_axis_dict(item)
                        else item
                    )
                    for item in value
                ]
            # NOTE(review): a value matched by the pattern but of any other
            # type is silently dropped here — mirrors existing behavior,
            # confirm this is intentional.
        return renamed

    # No known pattern matched: leave the dynamic shapes unchanged.
    return ds
|
|
928
|
+
|
|
929
|
+
|
|
930
|
+
def method_to_onnx(
    mod: "torch.nn.Module",
    method_name: str = "forward",
    input_names: Optional[Sequence[str]] = None,
    target_opset: Optional[Union[int, Dict[str, int]]] = None,
    verbose: int = 0,
    filename: Optional[str] = None,
    output_names: Optional[List[str]] = None,
    output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    exporter: str = "onnx-dynamo",
    exporter_kwargs: Optional[Dict[str, Any]] = None,
    save_ep: Optional[str] = None,
    optimize: bool = True,
    optimizer_for_ort: bool = True,
    use_control_flow_dispatcher: bool = False,
    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
    inline: bool = True,
    convert_after_n_calls: int = 2,
    patch_kwargs: Optional[Dict[str, Any]] = None,
    skip_kwargs_names: Optional[Set[str]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
    expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
) -> Callable:
    """
    Exports one method into ONNX for a module into ONNX.
    It returns a new method which must be called by the user
    at least twice with different values for the dynamic dimension
    between triggering the conversion into ONNX.

    :param mod: module holding the method to export into ONNX
    :param method_name: name of the method to export (``"forward"`` by default)
    :param input_names: input names for the onnx model (optional)
    :param target_opset: opset to target, if not specified, each converter
        keeps its default value
    :param verbose: verbosity level
    :param filename: output filename, mandatory, the onnx model is saved on disk
    :param output_names: to change the output of the onnx model
    :param output_dynamic_shapes: to overwrite the dynamic shapes names
    :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
    :param exporter_kwargs: additional parameters sent to the exporter
    :param save_ep: saves the exported program
    :param optimize: optimizes the model
    :param optimizer_for_ort: optimizes the model for onnxruntime
    :param use_control_flow_dispatcher: use the dispatcher created to supported
        custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
    :param onnx_plugs: the code was modified to replace some parts with onnx translation
    :param inline: inline local functions
    :param convert_after_n_calls: converts the model after this number of calls.
    :param patch_kwargs: patch arguments
    :param skip_kwargs_names: use default values for these parameters part of
        the signature of the method to export
    :param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
    :param dynamic_batch_for: LLM are usually called with a batch size equal to 1,
        but the export may benefit from having a dynamic batch size,
        this parameter forces the input specified in this set to have the first dimension
        be dynamic
    :param expand_batch_for: LLM are usually called with a batch size equal to 1,
        but the export may benefit from having another value for the batch size,
        this parameter forces the input specified in this set to be expanded
        to 2 if the batch size is one
    :return: the output of the selected exporter, usually a structure including
        an onnx model

    See :ref:`l-plot-tiny-llm-export-method-generate` for an example.
    """
    # Pure pass-through: the wrapper instance is itself the callable
    # that intercepts calls, records inputs and triggers the export.
    wrapped_model = WrapperToExportMethodToOnnx(
        mod=mod,
        method_name=method_name,
        input_names=input_names,
        target_opset=target_opset,
        verbose=verbose,
        filename=filename,
        output_names=output_names,
        output_dynamic_shapes=output_dynamic_shapes,
        exporter=exporter,
        exporter_kwargs=exporter_kwargs,
        save_ep=save_ep,
        optimize=optimize,
        optimizer_for_ort=optimizer_for_ort,
        use_control_flow_dispatcher=use_control_flow_dispatcher,
        onnx_plugs=onnx_plugs,
        inline=inline,
        convert_after_n_calls=convert_after_n_calls,
        patch_kwargs=patch_kwargs,
        skip_kwargs_names=skip_kwargs_names,
        dynamic_shapes=dynamic_shapes,
        dynamic_batch_for=dynamic_batch_for,
        expand_batch_for=expand_batch_for,
    )
    return wrapped_model
|