onnx-diagnostic 0.8.8__py3-none-any.whl → 0.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/doc.py +258 -8
- onnx_diagnostic/export/api.py +478 -17
- onnx_diagnostic/export/dynamic_shapes.py +21 -6
- onnx_diagnostic/export/shape_helper.py +0 -8
- onnx_diagnostic/helpers/cache_helper.py +98 -13
- onnx_diagnostic/helpers/helper.py +6 -5
- onnx_diagnostic/helpers/onnx_helper.py +7 -0
- onnx_diagnostic/helpers/rt_helper.py +14 -1
- onnx_diagnostic/helpers/torch_helper.py +22 -9
- onnx_diagnostic/tasks/image_text_to_text.py +4 -1
- onnx_diagnostic/tasks/text_generation.py +17 -17
- onnx_diagnostic/torch_export_patches/eval/__init__.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +62 -38
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -9
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +42 -30
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.9.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.9.dist-info}/RECORD +21 -21
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.9.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.9.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.9.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/api.py
CHANGED
@@ -1,11 +1,18 @@
 import inspect
 import os
 import textwrap
+import time
+from collections.abc import Mapping, Iterable
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
 import torch
 from .dynamic_shapes import ModelInputs
 from .onnx_plug import EagerDirectReplacementWithOnnx
-from ..helpers import string_type
+from ..helpers import flatten_object, max_diff, string_diff, string_type
+from ..helpers.cache_helper import CacheKeyValue
+from ..helpers.torch_helper import torch_deepcopy
+from ..helpers.rt_helper import make_feeds
+from ..helpers.onnx_helper import pretty_onnx
+from ..reference import OnnxruntimeEvaluator


 def get_main_dispatcher(
@@ -314,10 +321,11 @@ def to_onnx(
         raise ValueError(f"Unknown exporter={exporter!r}")


-class _WrapperToExportMethodToOnnx(torch.nn.Module):
+class WrapperToExportMethodToOnnx(torch.nn.Module):
     """
     Wraps an existing models in order to spy on inputs.
-    This is used by :func:`onnx_diagnostic.export.api.method_to_onnx
+    This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`
+    or :ref:`l-plot-tiny-llm-export-method-generate` for an example.
     """

     def __init__(
@@ -342,6 +350,8 @@ class _WrapperToExportMethodToOnnx(torch.nn.Module):
         patch_kwargs: Optional[Dict[str, Any]] = None,
         skip_kwargs_names: Optional[Set[str]] = None,
         dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+        dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
+        expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
     ):
         super().__init__()
         self._model_to_call = mod
@@ -351,13 +361,17 @@ class _WrapperToExportMethodToOnnx(torch.nn.Module):
             if method_name == "forward"
             else getattr(mod, method_name)
         )
+        self._signature = inspect.signature(self._method_call)
         self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
+        self._outputs: List[Any] = []
         self._convert_after_n_calls = convert_after_n_calls
         self._patch_kwargs = patch_kwargs
         self._method_src = None
         self.verbose = verbose
         self.skip_kwargs_names = skip_kwargs_names
         self.dynamic_shapes = dynamic_shapes
+        self.expand_batch_for = expand_batch_for
+        self.dynamic_batch_for = dynamic_batch_for
         self._to_onnx_kwargs = dict(
             input_names=input_names,
             target_opset=target_opset,
@@ -375,42 +389,121 @@ class _WrapperToExportMethodToOnnx(torch.nn.Module):
             inline=inline,
         )
         self._export_done = False
+        self._serialization_classes: Set[type] = set()

     def __str__(self) -> str:
+        "usual"
         return self.__repr__()

     def __repr__(self) -> str:
+        "usual"
         return (
             f"{self.__class__.__name__}({self._model_to_call.__class__.__name__}."
             f"{self._method_name})"
         )

+    def _collect_classes(self, obj):
+        if obj is None or isinstance(obj, torch.Tensor):
+            return
+        cls = type(obj)
+        if cls.__module__ not in ("builtins",):
+            self._serialization_classes.add(cls)
+        if hasattr(obj, "__dict__"):
+            for v in vars(obj).values():
+                self._collect_classes(v)
+            return
+        if isinstance(obj, Mapping):
+            for v in obj.values():
+                self._collect_classes(v)
+            return
+        if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
+            for v in obj:
+                self._collect_classes(v)
+            return
+
+    def _reorder_kwargs(self, kwargs):
+        new_kwargs = {k: kwargs[k] for k in self._signature.parameters if k in kwargs}
+        for k, v in kwargs.items():
+            if k not in new_kwargs:
+                new_kwargs[k] = v
+        return new_kwargs
+
     def forward(self, *args, **kwargs):
         if not self._export_done:
-            … (11 removed lines not preserved in this view)
+            inp_args = args
+            # filters out the inputs not desired, int, float, bool, None
+            # are considered as constant for the exporter, they are removed
+            # from the named arguments.
+            inp_kwargs = (
+                kwargs
+                if not kwargs
+                else {
+                    k: v
+                    for k, v in kwargs.items()
+                    if v is not None
+                    and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
+                    and not isinstance(v, (bool, int, float))
+                }
             )
+            if self.expand_batch_for:
+                # extends the inputs to artificially create a batch dimension != 1.
+                inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
+                inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
+            inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
+            # reorders the parameter following the method signature.
+            inp_kwargs = self._reorder_kwargs(inp_kwargs)
+            # stores the inputs
+            self._inputs.append((inp_args, inp_kwargs))
+
             if self.verbose:
                 print(
                     f"[method_to_onnx] input[{len(self._inputs)-1}]: "
                     f"{string_type(self._inputs[-1], with_shape=True)}"
                 )
+
             if len(self._inputs) >= self._convert_after_n_calls:
+                # conversion starts after _convert_after_n_calls calls to the forward method
+                name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
+                input_file = f"{name}.inputs.pt"
+                self._input_file = input_file
+                if self.verbose:
+                    print(
+                        f"[method_to_onnx] save {len(self._inputs)} inputs in {input_file!r}"
+                    )
+                torch.save(self._inputs, input_file)
                 self._convert_method_to_onnx()
-                del self._inputs[:]
                 self._export_done = True
-        … (1 removed line not preserved in this view)
+
+        # calls the inner method (no change here)
+        begin = time.perf_counter()
+        res = self._method_call(*args, **kwargs)
+        duration = time.perf_counter() - begin
+        self._collect_classes([args, kwargs, res])
+        if self._inputs:
+            # stores the outputs if discrepancies need to be checked
+            self._outputs.append((torch_deepcopy(res), duration))
+            assert len(self._inputs) == len(self._outputs), (
+                f"Number of inputs {len(self._inputs)} and "
+                f"outputs {len(self._outputs)} are different."
+            )
+            if self._export_done:
+                name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
+                output_file = f"{name}.outputs.pt"
+                if self.verbose:
+                    print(
+                        f"[method_to_onnx] save {len(self._outputs)} "
+                        f"outputs in {output_file!r}"
+                    )
+                torch.save(self._outputs, output_file)
+                self._output_file = output_file
+                del self._inputs[:]
+                del self._outputs[:]
+        return res

     def _convert_method_to_onnx(self):
+        for args, kwargs in self._inputs:
+            self._serialization_classes |= {type(a) for a in args}
+            self._serialization_classes |= {type(a) for a in kwargs.values()}

     def make_method(self):
         inner_sig = inspect.signature(self._method_call)
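The spied calls are persisted next to the ONNX file: the collected `(args, kwargs)` pairs go to `<filename>.inputs.pt` and the matching eager outputs, together with their latency, to `<filename>.outputs.pt`. A minimal sketch of reading those files back, assuming the model was exported to `model.onnx` and that the saved objects only involve a cache class such as `transformers.cache_utils.DynamicCache` (the file names and the class list are illustrative, not taken from the diff):

```python
import torch
from transformers.cache_utils import DynamicCache  # assumed: the cache class stored in the inputs

# forward() saves the spied calls next to the ONNX file: <name>.inputs.pt / <name>.outputs.pt
with torch.serialization.safe_globals([DynamicCache]):
    inputs = torch.load("model.inputs.pt")    # list of (args, kwargs) pairs
    outputs = torch.load("model.outputs.pt")  # list of (eager result, duration in seconds)

for (args, kwargs), (res, duration) in zip(inputs, outputs):
    print(len(args), sorted(kwargs), type(res).__name__, f"{duration:.3f}s")
```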
@@ -450,6 +543,17 @@ class _WrapperToExportMethodToOnnx(torch.nn.Module):
             if self.verbose:
                 print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}")
             a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds)
+            if self.dynamic_batch_for:
+                nds = (
+                    self._dynamic_batch_dimension(nds[0], self.dynamic_batch_for),
+                    self.rename_dynamic_shapes(
+                        self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
+                        verbose=self.verbose,
+                    ),
+                )
+                if self.verbose:
+                    print(f"[method_to_onnx] dynamic_batch_for={self.dynamic_batch_for}")
+                    print(f"[method_to_onnx] dynamic_shapes with batch={nds}")
         else:
             a, kw = self._inputs[-1]
             nds = [self.dynamic_shapes]
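The effect of `dynamic_batch_for` on the guessed shapes can be pictured with a small, made-up example (the dimension names below are illustrative): every input listed in `dynamic_batch_for` gets its first dimension forced to a dynamic `"batch"` dimension, whether or not the recorded calls ever used a batch size other than 1.

```python
# guessed from calls that all used batch size 1: dimension 0 stayed static
guessed = {"input_ids": {1: "seq_length"}, "attention_mask": {1: "seq_length"}}

# after applying _dynamic_batch_dimension(guessed, ["input_ids", "attention_mask"]),
# dimension 0 becomes the dynamic "batch" dimension on both inputs
forced = {
    "input_ids": {0: "batch", 1: "seq_length"},
    "attention_mask": {0: "batch", 1: "seq_length"},
}
```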
@@ -477,6 +581,351 @@ class _WrapperToExportMethodToOnnx(torch.nn.Module):
             **self._to_onnx_kwargs,
         )

+    @classmethod
+    def make_empty_cache_from_others(cls, examples: List[Any]) -> Any:
+        """Builds an empty cache based on existing one."""
+        unique_types = {type(t) for t in examples}
+        assert (
+            len(unique_types) == 1
+        ), f"Unable to guess an empty cache from {string_type(examples, with_shape=True)}"
+        unique_type = unique_types.pop()
+        if unique_type == torch.Tensor:
+            shapes = [t.shape for t in examples]
+            assert len(set(shapes)) > 1, f"Unable to guess an empty shape from shapes {shapes}"
+            ranks = {len(s) for s in shapes}
+            assert len(ranks) == 1, f"Ranks are different in {shapes}"
+            rank = ranks.pop()
+            new_shape = []
+            for i in range(rank):
+                dims = [t.shape[i] for t in examples]
+                if len(set(dims)) == 1:
+                    new_shape.append(dims[0])
+                else:
+                    # The empty shape
+                    new_shape.append(0)
+            example = examples[0]
+            return torch.empty(tuple(new_shape), dtype=example.dtype, device=example.device)
+        assert (
+            unique_type.__name__ == "DynamicCache"
+        ), f"This is not implemented for class {unique_type}"
+        caches = [CacheKeyValue(dc) for dc in examples]
+        caches_list = [dc.aslist() for dc in caches]
+        empty = [
+            cls.make_empty_cache_from_others([caches_list[i][k] for i in range(len(examples))])
+            for k in range(len(caches_list[0]))
+        ]
+        empty_cache = CacheKeyValue(
+            empty, cls_layers=caches[0].cls_layers
+        ).make_dynamic_cache()
+        return empty_cache
+
+    @classmethod
+    def add_empty_cache_if_needed(cls, inputs: List[Any]) -> List[Any]:
+        """
+        Adds empty cache if needed as onnxruntime needs an empty cache,
+        not a missing cache. It only works if inputs are defined as a dictionary.
+        """
+        if all(isinstance(t, tuple) for t in inputs) and all(
+            len(t) == 2 and isinstance(t[0], tuple) and isinstance(t[1], dict) and not t[0]
+            for t in inputs
+        ):
+            dict_part = [t[1] for t in inputs]
+            res = cls.add_empty_cache_if_needed(dict_part)
+            return [(tuple(), d) for d in res]
+        if any(not isinstance(t, dict) for t in inputs):
+            return inputs
+        all_keys = set()
+        for input_set in inputs:
+            all_keys |= set(input_set)
+        # even though the inputs are defined as a dictionary, it is better
+        # to keep the same order
+        ordered = None
+        for input_set in inputs:
+            if set(input_set) == all_keys:
+                ordered = list(input_set)
+                break
+        new_inputs = []
+        for input_set in inputs:
+            if set(input_set) == all_keys:
+                new_inputs.append(input_set)
+                continue
+            missing = {k for k in all_keys if k not in input_set}
+            input_set_copy = input_set.copy()
+            for miss in missing:
+                input_set_copy[miss] = cls.make_empty_cache_from_others(
+                    [sub[miss] for sub in inputs if miss in sub]
+                )
+            new_inputs.append({k: input_set_copy[k] for k in ordered})  # type: ignore[union-attr]
+        return new_inputs
+
+    @classmethod
+    def _expand_batch_dimension(cls, obj: Any, expand_for: Sequence[Union[int, str]]) -> Any:
+        expand_for_args = {i for i in expand_for if isinstance(i, int)}
+        expand_for_kwargs = {i for i in expand_for if isinstance(i, str)}
+        if isinstance(obj, tuple):
+            return tuple(
+                o if i not in expand_for_args else cls._expand_batch_dimension_input(o, i)
+                for i, o in enumerate(obj)
+            )
+        assert isinstance(obj, dict), f"Unexpected type {type(obj)}"
+        return {
+            k: v if k not in expand_for_kwargs else cls._expand_batch_dimension_input(v, k)
+            for k, v in obj.items()
+        }
+
+    @classmethod
+    def _expand_batch_dimension_input(cls, obj: Any, msg: Union[str, int]) -> Any:
+        if isinstance(obj, torch.Tensor):
+            assert obj.shape[0] == 1, (
+                f"Are you sure to expoand input {msg!r}, "
+                f"batch size is not 1 and shape={obj.shape}"
+            )
+            sizes = [2, *obj.shape[1:]]
+            return obj.expand(*sizes)
+        if isinstance(obj, list):
+            return [
+                cls._expand_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(obj)
+            ]
+        if obj.__class__.__name__ == "DynamicCache":
+            dc = CacheKeyValue(obj)
+            flat = dc.aslist()
+            flat = cls._expand_batch_dimension_input(flat, msg)
+            return CacheKeyValue(flat, cls_layers=dc.cls_layers).make_dynamic_cache()
+        # This might end up in an infinite loop if no registration is done.
+        flat, _spec = torch.utils._pytree.tree_flatten(obj)
+        assert (
+            not isinstance(flat, list) or len(flat) != 1 or type(flat[0]) is not type(obj)
+        ), f"class {type(obj)} was is not registered for serialization."
+        flat = cls._expand_batch_dimension_input(flat, msg)
+        return torch.utils._pytree.tree_unflatten(flat, _spec)
+
+    @classmethod
+    def _dynamic_batch_dimension(
+        cls, ds: Union[Tuple[Any, ...], Dict[str, Any]], dynamic_for: Sequence[Union[int, str]]
+    ) -> Union[Tuple[Any, ...], Dict[str, Any]]:
+        if isinstance(ds, tuple):
+            return tuple(
+                (v if i not in dynamic_for else cls._dynamic_batch_dimension_input(v, i))
+                for i, v in enumerate(ds)
+            )
+        return {
+            k: (v if k not in dynamic_for else cls._dynamic_batch_dimension_input(v, k))
+            for k, v in ds.items()
+        }
+
+    @classmethod
+    def _dynamic_batch_dimension_input(cls, ds: Any, msg: Union[str, int]) -> Any:
+        if isinstance(ds, dict) and all(isinstance(k, int) for k in ds):
+            ds[0] = "batch"
+            return {k: v for k, v in sorted(ds.items())}  # noqa: C416
+        if isinstance(ds, list):
+            return [
+                cls._dynamic_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(ds)
+            ]
+        raise NotImplementedError(f"cannot make first dimension dynamic for batch for {ds}")
+
+    def check_discrepancies(
+        self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0
+    ) -> List[Dict[str, Union[str, int, float]]]:
+        """
+        Computes the discrepancies between the saved inputs and outputs
+        with the saved onnx model.
+
+        :param atol: absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16
+        :param rtol: relative tolerance
+        :param hist: thresholds, the function determines the number of discrepancies
+            above that threshold.
+        :param verbose: verbosity
+        :return: results, a list of dictionaries, ready to be consumed by a dataframe
+        """
+        assert self._export_done, "The onnx export was not done."
+        assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
+        assert os.path.exists(
+            self._output_file
+        ), f"output file {self._output_file!r} not found"
+        filename = self._to_onnx_kwargs["filename"]
+        assert isinstance(filename, str) and os.path.exists(
+            filename
+        ), f"onnx file {filename!r} not found"
+        classes = [
+            cls
+            for cls in self._serialization_classes
+            if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
+        ]
+        if verbose:
+            print(f"[method_to_onnx.check_discrepancies] register classes {classes}")
+            print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}")
+        with torch.serialization.safe_globals(classes):
+            inputs = torch.load(self._input_file)
+        if verbose:
+            print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}")
+        with torch.serialization.safe_globals(classes):
+            outputs = torch.load(self._output_file)
+        assert len(inputs) == len(outputs), (
+            f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, "
+            f"inputs={string_type(inputs, with_shape=True)}, "
+            f"outputs={string_type(outputs, with_shape=True)}"
+        )
+        if verbose:
+            print(f"[method_to_onnx.check_discrepancies] create onnx session {filename!r}")
+        sess = OnnxruntimeEvaluator(filename, whole=True)
+        input_names = sess.input_names
+        if verbose:
+            print(f"[method_to_onnx.check_discrepancies] input_names={input_names}")
+            print(
+                f"[method_to_onnx.check_discrepancies] onnx_shapes="
+                f"{', '.join(pretty_onnx(i) for i in sess.input_types)}"
+            )
+        data = []
+        for i, (input, (output, latency)) in enumerate(
+            zip(self.add_empty_cache_if_needed(inputs), outputs)
+        ):
+            if verbose:
+                if verbose > 1:
+                    print(
+                        f"[method_to_onnx.check_discrepancies] process input {i}: "
+                        f"{string_type(input, with_shape=True)}"
+                    )
+                    print(
+                        f"[method_to_onnx.check_discrepancies] expects: "
+                        f"{string_type(output, with_shape=True)}"
+                    )
+                else:
+                    print(
+                        f"[method_to_onnx.check_discrepancies] process input {i} "
+                        f"#args={len(input[0])} #kwargs={len(input[1])}"
+                    )
+
+            flat_inputs = flatten_object(input, drop_keys=True)
+            if verbose > 1:
+                print(
+                    f"[method_to_onnx.check_discrepancies] "
+                    f"input={string_type(input, with_shape=True)}"
+                )
+                print(
+                    f"[method_to_onnx.check_discrepancies] "
+                    f"flat_inputs={string_type(flat_inputs, with_shape=True)}"
+                )
+            if len(flat_inputs) < len(input_names):
+                # not implemented yet, it is caused by a missing cache,
+                # which requires an empty cache instead
+                data.append(dict(index=i, duration_torch=latency, n_inputs=len(flat_inputs)))
+                continue
+            assert len(flat_inputs) == len(input_names), (
+                f"Length mismatch, expecting {len(input_names)} onnx inputs and got "
+                f"{len(flat_inputs)} flat torch inputs"
+            )
+            feeds = make_feeds(input_names, flat_inputs)
+            if verbose > 1:
+                print(
+                    f"[method_to_onnx.check_discrepancies] "
+                    f"feeds={string_type(feeds, with_shape=True)}"
+                )
+            begin = time.perf_counter()
+            ort_outputs = sess.run(None, feeds)
+            duration = time.perf_counter() - begin
+            diff = max_diff(output, ort_outputs, hist=hist)
+            if "rep" in diff and isinstance(diff["rep"], dict):
+                diff.update(diff["rep"])
+                del diff["rep"]
+            diff["SUCCESS"] = (
+                isinstance(diff["abs"], float)
+                and isinstance(diff["rel"], float)
+                and diff["abs"] < atol
+                and diff["rel"] < rtol
+            )
+            diff.update(
+                dict(
+                    index=i,
+                    duration_torch=latency,
+                    ort_duration=duration,
+                    n_inputs=len(flat_inputs),
+                )
+            )
+            if verbose > 1:
+                print(
+                    f"[method_to_onnx.check_discrepancies] ort output "
+                    f"{string_type(ort_outputs, with_shape=True)}"
+                )
+                print(f"[method_to_onnx.check_discrepancies] diff {string_diff(diff)}")
+            data.append(diff)
+        if verbose:
+            print("[method_to_onnx.check_discrepancies] done")
+        return data
+
+    @classmethod
+    def _apply_known_shape_pattern(
+        cls, shape: Dict[int, Any], pattern: Dict[int, str]
+    ) -> Dict[int, Any]:
+        return {k: pattern.get(k, v) for k, v in shape.items()}
+
+    @classmethod
+    def get_dynamic_shape_patterns(cls) -> Dict[str, Any]:
+        """
+        Returns the known patterns for the dynamic shapes.
+
+        .. runpython::
+            :showcode:
+
+            import pprint
+            from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx
+            pprint.pprint(WrapperToExportMethodToOnnx.get_dynamic_shape_patterns())
+        """
+        return {
+            "LLM.text": {
+                "cache_position": {0: "seqlength"},
+                "past_key_values": {0: "batch", 2: "pastlength"},
+                "input_ids": {0: "batch", 1: "seqlength"},
+                "attention_mask": {0: "batch", 1: "totallength"},  # pastlength+seqlength
+            }
+        }
+
+    @classmethod
+    def rename_dynamic_shapes(cls, ds: Dict[str, Any], verbose: int = 0) -> Dict[str, Any]:
+        """
+        Renames the dynamic shapes with names.
+        Tries to rename any dynamic dimnesion dimension
+        before export. It is not very clever, it just tries
+        to recognize a known configuration based on input names.
+        Dimension names in dynamic shapes are renamed if *ds* has
+        the same number of named arguments as the one of the patterns
+        returned by function :meth:`get_dynamic_shape_patterns
+        <onnx_diagnostic.export.api.WrapperToExportMethodToOnnx.get_dynamic_shape_patterns>`.
+        """
+        is_shape = lambda s: isinstance(s, dict) and all(  # noqa: E731
+            isinstance(_, int) for _ in s
+        )
+        llm_patterns = cls.get_dynamic_shape_patterns()
+        for pattern_name, pattern_shape in llm_patterns.items():
+            if len(set(ds) & set(pattern_shape)) == len(pattern_shape):
+                if verbose:
+                    print(
+                        f"[method_to_onnx.rename_dynamic_shapes] "
+                        f"apply pattern shapes {pattern_name!r}"
+                    )
+                new_ds = {}
+                for k, v in ds.items():
+                    if k not in pattern_shape:
+                        new_ds[k] = v
+                        continue
+                    if is_shape(v):
+                        # A shape
+                        new_ds[k] = cls._apply_known_shape_pattern(v, pattern_shape[k])
+                    elif isinstance(v, list):
+                        # A cache
+                        new_ds[k] = [
+                            (
+                                cls._apply_known_shape_pattern(s, pattern_shape[k])
+                                if is_shape(s)
+                                else s
+                            )
+                            for s in v
+                        ]
+                return new_ds
+
+        # unchanged
+        return ds
+

 def method_to_onnx(
     mod: "torch.nn.Module",
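The pattern table above drives `rename_dynamic_shapes`: when the guessed dynamic shapes contain every key of a known pattern (currently the single `"LLM.text"` pattern), the anonymous dimension names are replaced by `batch`, `seqlength`, `pastlength` and `totallength`. A rough sketch of the expected effect, with made-up dimension names on the input side:

```python
# dynamic shapes as guessed from the recorded calls (dimension names are illustrative)
guessed = {
    "input_ids": {0: "d_0", 1: "d_1"},
    "attention_mask": {0: "d_0", 1: "d_2"},
    "cache_position": {0: "d_3"},
    "past_key_values": [{0: "d_0", 2: "d_4"}, {0: "d_0", 2: "d_4"}],
}

# WrapperToExportMethodToOnnx.rename_dynamic_shapes(guessed) should return
# something along these lines: only the dimensions named in the pattern change.
renamed = {
    "input_ids": {0: "batch", 1: "seqlength"},
    "attention_mask": {0: "batch", 1: "totallength"},  # pastlength + seqlength
    "cache_position": {0: "seqlength"},
    "past_key_values": [{0: "batch", 2: "pastlength"}, {0: "batch", 2: "pastlength"}],
}
```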
@@ -499,6 +948,8 @@ def method_to_onnx(
     patch_kwargs: Optional[Dict[str, Any]] = None,
     skip_kwargs_names: Optional[Set[str]] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
+    expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
 ) -> Callable:
     """
     Exports one method into ONNX for a module into ONNX.
@@ -528,12 +979,20 @@ def method_to_onnx(
     :param skip_kwargs_names: use default values for these parameters part of
         the signature of the method to export
     :param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
+    :param dynamic_batch_for: LLM are usually called with a batch size equal to 1,
+        but the export may benefit from having a dynamic batch size,
+        this parameter forces the input specified in this set to have the first dimension
+        be dynamic
+    :param expand_batch_for: LLM are usually called with a batch size equal to 1,
+        but the export may benefit from having another value for the batch size,
+        this parameter forces the input specified in this set to be expanded
+        to 2 if the batch size is one
     :return: the output of the selected exporter, usually a structure including
         an onnx model

    See :ref:`l-plot-tiny-llm-export-method-generate` for an example.
    """
-    wrapped_model = _WrapperToExportMethodToOnnx(
+    wrapped_model = WrapperToExportMethodToOnnx(
        mod=mod,
        method_name=method_name,
        input_names=input_names,
@@ -554,5 +1013,7 @@ def method_to_onnx(
         patch_kwargs=patch_kwargs,
         skip_kwargs_names=skip_kwargs_names,
         dynamic_shapes=dynamic_shapes,
+        dynamic_batch_for=dynamic_batch_for,
+        expand_batch_for=expand_batch_for,
     )
     return wrapped_model
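Putting the new pieces of `onnx_diagnostic/export/api.py` together, a rough usage sketch; the model id, the file name and the `filename`, `convert_after_n_calls` and `verbose` keywords are assumptions based on the hunks above, not a verified signature (the documented example is l-plot-tiny-llm-export-method-generate):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic.export.api import method_to_onnx

model_id = "arnir0/Tiny-LLM"  # assumed model id
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# wraps model.forward: the first calls are spied on, then the method is exported to ONNX
wrapped = method_to_onnx(
    model,
    method_name="forward",
    filename="tiny_llm.forward.onnx",  # assumed keyword, see _to_onnx_kwargs["filename"]
    convert_after_n_calls=2,           # assumed keyword, see _convert_after_n_calls
    dynamic_batch_for=["input_ids", "attention_mask"],
    expand_batch_for=["input_ids", "attention_mask"],
    verbose=1,
)

# calling the wrapper records inputs/outputs; the export happens on the second call
for prompt in ["Continue: it rains", "Continue: the sun is shining today"]:
    feeds = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        wrapped(**feeds)

# compares the saved eager outputs with onnxruntime on the saved inputs,
# one row per recorded call, ready for a dataframe
report = wrapped.check_discrepancies(atol=1e-4, rtol=0.1, verbose=1)
```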
onnx_diagnostic/export/dynamic_shapes.py
CHANGED

@@ -329,7 +329,7 @@ class CoupleInputsDynamicShapes:
         if type(inputs) in (tuple, list, dict):
             # Type must be strict, some custom classes can inherit from those.
             assert type(inputs) is type(ds), (
-                f"Input type and dynamic
+                f"Input type and dynamic shapes type mush match but "
                 f"type(inputs)={type(inputs)}, type(ds)={type(ds)}, "
                 f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
             )

@@ -967,6 +967,8 @@ class ModelInputs:
         """
         Guesses the dynamic shapes for that module from two execution.
         If there is only one execution, then that would be static dimensions.
+        If the model signature is available, the kwargs are reordered following
+        the signature order, otherwise it follows the order given in the inputs.

         :param auto: if auto is True, use ``torch.export.Dim.AUTO`` for any
             dimension if the number of inputs is one,

@@ -1026,11 +1028,24 @@ class ModelInputs:
                 msg=lambda name=name: f" failing input {name!r}",
             )
         # reordering
-        … (5 removed lines, beginning "if kwargs", not preserved in this view)
+        if kwargs:
+            if self.forward_ordered_parameter_names:
+                kwargs1 = {
+                    p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
+                }
+                kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
+            else:
+                # We reorder the same the way the input were given.
+                use = None
+                params = set(kwargs)
+                for _args, kws in self.inputs:
+                    if set(kws) == params:
+                        use = kws
+                        break
+                if use:
+                    ordered = list(use)
+                    kwargs = {k: kwargs[k] for k in ordered}
+
         return tuple(args), kwargs

     def move_to_kwargs(
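The reordering added to `ModelInputs.guess_dynamic_shapes` follows the forward signature when it is known. A small self-contained sketch of that logic (not the library call itself):

```python
import inspect
import torch


class Toy(torch.nn.Module):
    def forward(self, x, mask=None):
        return x if mask is None else x * mask


# kwargs captured in call order, not in signature order
kwargs = {"mask": torch.ones((1, 8)), "x": torch.rand((1, 8))}

# reorder them following the signature, as the new branch does with
# forward_ordered_parameter_names; unknown keys keep their original order
order = list(inspect.signature(Toy.forward).parameters)  # ['self', 'x', 'mask']
reordered = {k: kwargs[k] for k in order if k in kwargs}
reordered.update({k: v for k, v in kwargs.items() if k not in reordered})
print(list(reordered))  # ['x', 'mask']
```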
onnx_diagnostic/export/shape_helper.py
CHANGED

@@ -47,7 +47,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
             make_dynamic_cache,
             make_encoder_decoder_cache,
             make_mamba_cache,
-            make_sliding_window_cache,
             make_static_cache,
         )
         from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs

@@ -77,13 +76,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
                     ]
                 ),
             ),
-            make_sliding_window_cache(
-                [
-                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-                ]
-            ),
             make_static_cache(
                 [
                     (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
|