onnx-diagnostic 0.8.8__py3-none-any.whl → 0.8.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,18 @@
  import inspect
  import os
  import textwrap
+ import time
+ from collections.abc import Mapping, Iterable
  from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
  import torch
  from .dynamic_shapes import ModelInputs
  from .onnx_plug import EagerDirectReplacementWithOnnx
- from ..helpers import string_type
+ from ..helpers import flatten_object, max_diff, string_diff, string_type
+ from ..helpers.cache_helper import CacheKeyValue
+ from ..helpers.torch_helper import torch_deepcopy
+ from ..helpers.rt_helper import make_feeds
+ from ..helpers.onnx_helper import pretty_onnx
+ from ..reference import OnnxruntimeEvaluator


  def get_main_dispatcher(
@@ -314,10 +321,11 @@ def to_onnx(
  raise ValueError(f"Unknown exporter={exporter!r}")


- class _WrapperToExportMethodToOnnx(torch.nn.Module):
+ class WrapperToExportMethodToOnnx(torch.nn.Module):
  """
  Wraps an existing model in order to spy on inputs.
- This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`.
+ This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`,
+ see :ref:`l-plot-tiny-llm-export-method-generate` for an example.
  """

  def __init__(
@@ -342,6 +350,8 @@ class _WrapperToExportMethodToOnnx(torch.nn.Module):
  patch_kwargs: Optional[Dict[str, Any]] = None,
  skip_kwargs_names: Optional[Set[str]] = None,
  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+ dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
+ expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
  ):
  super().__init__()
  self._model_to_call = mod
@@ -351,13 +361,17 @@ class _WrapperToExportMethodToOnnx(torch.nn.Module):
  if method_name == "forward"
  else getattr(mod, method_name)
  )
+ self._signature = inspect.signature(self._method_call)
  self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
+ self._outputs: List[Any] = []
  self._convert_after_n_calls = convert_after_n_calls
  self._patch_kwargs = patch_kwargs
  self._method_src = None
  self.verbose = verbose
  self.skip_kwargs_names = skip_kwargs_names
  self.dynamic_shapes = dynamic_shapes
+ self.expand_batch_for = expand_batch_for
+ self.dynamic_batch_for = dynamic_batch_for
  self._to_onnx_kwargs = dict(
  input_names=input_names,
  target_opset=target_opset,
@@ -375,42 +389,117 @@ class _WrapperToExportMethodToOnnx(torch.nn.Module):
  inline=inline,
  )
  self._export_done = False
+ self._serialization_classes: Set[type] = set()

  def __str__(self) -> str:
+ "usual"
  return self.__repr__()

  def __repr__(self) -> str:
+ "usual"
  return (
  f"{self.__class__.__name__}({self._model_to_call.__class__.__name__}."
  f"{self._method_name})"
  )

+ def _collect_classes(self, obj):
+ if obj is None or isinstance(obj, torch.Tensor):
+ return
+ cls = type(obj)
+ if cls.__module__ not in ("builtins",):
+ self._serialization_classes.add(cls)
+ if hasattr(obj, "__dict__"):
+ for v in vars(obj).values():
+ self._collect_classes(v)
+ return
+ if isinstance(obj, Mapping):
+ for v in obj.values():
+ self._collect_classes(v)
+ return
+ if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
+ for v in obj:
+ self._collect_classes(v)
+ return
+
+ def _reorder_kwargs(self, kwargs):
+ new_kwargs = {k: kwargs[k] for k in self._signature.parameters if k in kwargs}
+ for k, v in kwargs.items():
+ if k not in new_kwargs:
+ new_kwargs[k] = v
+ return new_kwargs
+
  def forward(self, *args, **kwargs):
  if not self._export_done:
- self._inputs.append(
- (
- args,
- (
- kwargs
- if not kwargs or not self.skip_kwargs_names
- else {
- k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names
- }
- ),
- )
+ inp_args = args
+ # filters out unwanted inputs: int, float, bool and None values
+ # are considered constants by the exporter, so they are removed
+ # from the named arguments.
+ inp_kwargs = (
+ kwargs
+ if not kwargs
+ else {
+ k: v
+ for k, v in kwargs.items()
+ if v is not None
+ and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
+ and not isinstance(v, (bool, int, float))
+ }
  )
+ inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
+ # reorders the parameters following the method signature.
+ inp_kwargs = self._reorder_kwargs(inp_kwargs)
+ # stores the inputs
+ self._inputs.append((inp_args, inp_kwargs))
+
  if self.verbose:
  print(
  f"[method_to_onnx] input[{len(self._inputs)-1}]: "
  f"{string_type(self._inputs[-1], with_shape=True)}"
  )
+
  if len(self._inputs) >= self._convert_after_n_calls:
+ # conversion starts after _convert_after_n_calls calls to the forward method
+ name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
+ input_file = f"{name}.inputs.pt"
+ self._input_file = input_file
+ if self.verbose:
+ print(
+ f"[method_to_onnx] save {len(self._inputs)} inputs in {input_file!r}"
+ )
+ torch.save(self._inputs, input_file)
  self._convert_method_to_onnx()
- del self._inputs[:]
  self._export_done = True
- return self._method_call(*args, **kwargs)
+
+ # calls the inner method (no change here)
+ begin = time.perf_counter()
+ res = self._method_call(*args, **kwargs)
+ duration = time.perf_counter() - begin
+ self._collect_classes([args, kwargs, res])
+ if self._inputs:
+ # stores the outputs if discrepancies need to be checked
+ self._outputs.append((torch_deepcopy(res), duration))
+ assert len(self._inputs) == len(self._outputs), (
+ f"Number of inputs {len(self._inputs)} and "
+ f"outputs {len(self._outputs)} are different."
+ )
+ if self._export_done:
+ name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
+ output_file = f"{name}.outputs.pt"
+ if self.verbose:
+ print(
+ f"[method_to_onnx] save {len(self._outputs)} "
+ f"outputs in {output_file!r}"
+ )
+ torch.save(self._outputs, output_file)
+ self._output_file = output_file
+ del self._inputs[:]
+ del self._outputs[:]
+ return res

  def _convert_method_to_onnx(self):
+ for args, kwargs in self._inputs:
+ self._serialization_classes |= {type(a) for a in args}
+ self._serialization_classes |= {type(a) for a in kwargs.values()}

  def make_method(self):
  inner_sig = inspect.signature(self._method_call)
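
Illustration (not part of the diff): a minimal sketch of how the recorded calls persisted by the new forward() could be inspected afterwards; the file names follow the "<filename without extension>.inputs.pt" / ".outputs.pt" pattern used above, and "model" is a hypothetical base name.

    import torch

    # the wrapper stores a list of (args, kwargs) tuples and a list of (result, duration) tuples
    inputs = torch.load("model.inputs.pt", weights_only=False)
    outputs = torch.load("model.outputs.pt", weights_only=False)
    for (args, kwargs), (res, duration) in zip(inputs, outputs):
        print(len(args), sorted(kwargs), type(res).__name__, f"{duration:.3f}s")
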
@@ -450,9 +539,24 @@ class _WrapperToExportMethodToOnnx(torch.nn.Module):
  if self.verbose:
  print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}")
  a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds)
+ if self.dynamic_batch_for:
+ nds = (
+ self._dynamic_batch_dimension(nds[0], self.dynamic_batch_for),
+ self.rename_dynamic_shapes(
+ self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
+ verbose=self.verbose,
+ ),
+ )
+ if self.verbose:
+ print(f"[method_to_onnx] dynamic_batch_for={self.dynamic_batch_for}")
+ print(f"[method_to_onnx] dynamic_shapes with batch={nds}")
  else:
  a, kw = self._inputs[-1]
  nds = [self.dynamic_shapes]
+ if self.expand_batch_for:
+ # extends the inputs to artificially create a batch dimension != 1.
+ a = self._expand_batch_dimension(a, self.expand_batch_for)
+ kw = self._expand_batch_dimension(kw, self.expand_batch_for)
  if self.verbose:
  print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}")
  print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}")
@@ -477,6 +581,365 @@ class _WrapperToExportMethodToOnnx(torch.nn.Module):
  **self._to_onnx_kwargs,
  )

+ @classmethod
+ def make_empty_cache_from_others(cls, examples: List[Any]) -> Any:
+ """Builds an empty cache based on existing ones."""
+ unique_types = {type(t) for t in examples}
+ assert (
+ len(unique_types) == 1
+ ), f"Unable to guess an empty cache from {string_type(examples, with_shape=True)}"
+ unique_type = unique_types.pop()
+ if unique_type == torch.Tensor:
+ shapes = [t.shape for t in examples]
+ assert len(set(shapes)) > 1, f"Unable to guess an empty shape from shapes {shapes}"
+ ranks = {len(s) for s in shapes}
+ assert len(ranks) == 1, f"Ranks are different in {shapes}"
+ rank = ranks.pop()
+ new_shape = []
+ for i in range(rank):
+ dims = [t.shape[i] for t in examples]
+ if len(set(dims)) == 1:
+ new_shape.append(dims[0])
+ else:
+ # The empty shape
+ new_shape.append(0)
+ example = examples[0]
+ return torch.empty(tuple(new_shape), dtype=example.dtype, device=example.device)
+ assert (
+ unique_type.__name__ == "DynamicCache"
+ ), f"This is not implemented for class {unique_type}"
+ caches = [CacheKeyValue(dc) for dc in examples]
+ caches_list = [dc.aslist() for dc in caches]
+ empty = [
+ cls.make_empty_cache_from_others([caches_list[i][k] for i in range(len(examples))])
+ for k in range(len(caches_list[0]))
+ ]
+ empty_cache = CacheKeyValue(
+ empty, cls_layers=caches[0].cls_layers
+ ).make_dynamic_cache()
+ return empty_cache
+
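
Illustration (not part of the diff): for plain tensors, make_empty_cache_from_others keeps the dimensions shared by all examples and sets the varying ones to 0; a worked sketch of that rule.

    import torch

    examples = [torch.zeros(1, 8, 3, 64), torch.zeros(1, 8, 7, 64)]
    rank = len(examples[0].shape)
    new_shape = tuple(
        examples[0].shape[i] if len({t.shape[i] for t in examples}) == 1 else 0
        for i in range(rank)
    )
    empty = torch.empty(new_shape, dtype=examples[0].dtype, device=examples[0].device)
    print(empty.shape)  # torch.Size([1, 8, 0, 64])
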
+ @classmethod
+ def add_empty_cache_if_needed(cls, inputs: List[Any]) -> List[Any]:
+ """
+ Adds an empty cache if needed, as onnxruntime needs an empty cache,
+ not a missing one. It only works if the inputs are defined as a dictionary.
+ """
+ if all(isinstance(t, tuple) for t in inputs) and all(
+ len(t) == 2 and isinstance(t[0], tuple) and isinstance(t[1], dict) and not t[0]
+ for t in inputs
+ ):
+ dict_part = [t[1] for t in inputs]
+ res = cls.add_empty_cache_if_needed(dict_part)
+ return [(tuple(), d) for d in res]
+ if any(not isinstance(t, dict) for t in inputs):
+ return inputs
+ all_keys = set()
+ for input_set in inputs:
+ all_keys |= set(input_set)
+ # even though the inputs are defined as a dictionary, it is better
+ # to keep the same order
+ ordered = None
+ for input_set in inputs:
+ if set(input_set) == all_keys:
+ ordered = list(input_set)
+ break
+ new_inputs = []
+ for input_set in inputs:
+ if set(input_set) == all_keys:
+ new_inputs.append(input_set)
+ continue
+ missing = {k for k in all_keys if k not in input_set}
+ input_set_copy = input_set.copy()
+ for miss in missing:
+ input_set_copy[miss] = cls.make_empty_cache_from_others(
+ [sub[miss] for sub in inputs if miss in sub]
+ )
+ new_inputs.append({k: input_set_copy[k] for k in ordered}) # type: ignore[union-attr]
+ return new_inputs
+
+ @classmethod
+ def _expand_batch_dimension(cls, obj: Any, expand_for: Sequence[Union[int, str]]) -> Any:
+ expand_for_args = {i for i in expand_for if isinstance(i, int)}
+ expand_for_kwargs = {i for i in expand_for if isinstance(i, str)}
+ if isinstance(obj, tuple):
+ return tuple(
+ o if i not in expand_for_args else cls._expand_batch_dimension_input(o, i)
+ for i, o in enumerate(obj)
+ )
+ assert isinstance(obj, dict), f"Unexpected type {type(obj)}"
+ return {
+ k: v if k not in expand_for_kwargs else cls._expand_batch_dimension_input(v, k)
+ for k, v in obj.items()
+ }
+
+ @classmethod
+ def _expand_batch_dimension_input(cls, obj: Any, msg: Union[str, int]) -> Any:
+ if isinstance(obj, torch.Tensor):
+ assert obj.shape[0] == 1, (
+ f"Are you sure to expand input {msg!r}, "
+ f"batch size is not 1 and shape={obj.shape}"
+ )
+ sizes = [2, *obj.shape[1:]]
+ return obj.expand(*sizes)
+ if isinstance(obj, list):
+ return [
+ cls._expand_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(obj)
+ ]
+ if obj.__class__.__name__ == "DynamicCache":
+ dc = CacheKeyValue(obj)
+ flat = dc.aslist()
+ flat = cls._expand_batch_dimension_input(flat, msg)
+ return CacheKeyValue(flat, cls_layers=dc.cls_layers).make_dynamic_cache()
+ # This might end up in an infinite loop if no registration is done.
+ flat, _spec = torch.utils._pytree.tree_flatten(obj)
+ assert (
+ not isinstance(flat, list) or len(flat) != 1 or type(flat[0]) is not type(obj)
+ ), f"class {type(obj)} is not registered for serialization."
+ flat = cls._expand_batch_dimension_input(flat, msg)
+ return torch.utils._pytree.tree_unflatten(flat, _spec)
+
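
Illustration (not part of the diff): _expand_batch_dimension_input relies on torch.Tensor.expand, which only works when the expanded dimension has size 1, hence the assertion above.

    import torch

    t = torch.arange(6).reshape(1, 2, 3)   # batch size 1
    t2 = t.expand(2, *t.shape[1:])         # batch size 2, no data copy
    print(t2.shape)                        # torch.Size([2, 2, 3])
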
+ @classmethod
+ def _dynamic_batch_dimension(
+ cls, ds: Union[Tuple[Any, ...], Dict[str, Any]], dynamic_for: Sequence[Union[int, str]]
+ ) -> Union[Tuple[Any, ...], Dict[str, Any]]:
+ if isinstance(ds, tuple):
+ return tuple(
+ (v if i not in dynamic_for else cls._dynamic_batch_dimension_input(v, i))
+ for i, v in enumerate(ds)
+ )
+ return {
+ k: (v if k not in dynamic_for else cls._dynamic_batch_dimension_input(v, k))
+ for k, v in ds.items()
+ }
+
+ @classmethod
+ def _dynamic_batch_dimension_input(cls, ds: Any, msg: Union[str, int]) -> Any:
+ if isinstance(ds, dict) and all(isinstance(k, int) for k in ds):
+ ds[0] = "batch"
+ return {k: v for k, v in sorted(ds.items())} # noqa: C416
+ if isinstance(ds, list):
+ return [
+ cls._dynamic_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(ds)
+ ]
+ raise NotImplementedError(f"cannot make first dimension dynamic for batch for {ds}")
+
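
Illustration (not part of the diff): _dynamic_batch_dimension_input only rewrites dimension 0 of a guessed shape, the other dimensions are left untouched.

    ds = {1: "seq_length"}            # guessed dynamic shape of one input
    ds[0] = "batch"                   # the first dimension becomes dynamic
    print(dict(sorted(ds.items())))   # {0: 'batch', 1: 'seq_length'}
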
+ def check_discrepancies(
+ self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0
+ ) -> List[Dict[str, Union[str, int, float]]]:
+ """
+ Computes the discrepancies between the saved inputs and outputs
+ and the saved onnx model.
+
+ :param atol: absolute tolerance, recommended values: 1e-4 for float, 1e-2 for float16
+ :param rtol: relative tolerance
+ :param hist: thresholds, the function determines the number of discrepancies
+ above each threshold.
+ :param verbose: verbosity
+ :return: results, a list of dictionaries, ready to be consumed by a dataframe
+ """
+ assert (
+ self._export_done
+ ), f"The onnx export was not done, only {len(self._inputs)} inputs were stored."
+ assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
+ assert os.path.exists(
+ self._output_file
+ ), f"output file {self._output_file!r} not found"
+ filename = self._to_onnx_kwargs["filename"]
+ assert isinstance(filename, str) and os.path.exists(
+ filename
+ ), f"onnx file {filename!r} not found"
+ classes = [
+ cls
+ for cls in self._serialization_classes
+ if cls
+ not in {
+ int,
+ float,
+ bool,
+ str,
+ torch.Tensor,
+ list,
+ set,
+ dict,
+ torch.device,
+ torch.dtype,
+ }
+ ]
+ if verbose:
+ print(f"[method_to_onnx.check_discrepancies] register classes {classes}")
+ print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}")
+ with torch.serialization.safe_globals(classes):
+ inputs = torch.load(self._input_file, weights_only=False)
+ if verbose:
+ print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}")
+ with torch.serialization.safe_globals(classes):
+ outputs = torch.load(self._output_file, weights_only=False)
+ assert len(inputs) == len(outputs), (
+ f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, "
+ f"inputs={string_type(inputs, with_shape=True)}, "
+ f"outputs={string_type(outputs, with_shape=True)}"
+ )
+ if verbose:
+ print(f"[method_to_onnx.check_discrepancies] create onnx session {filename!r}")
+ sess = OnnxruntimeEvaluator(filename, whole=True)
+ input_names = sess.input_names
+ if verbose:
+ print(f"[method_to_onnx.check_discrepancies] input_names={input_names}")
+ print(
+ f"[method_to_onnx.check_discrepancies] onnx_shapes="
+ f"{', '.join(pretty_onnx(i) for i in sess.input_types)}"
+ )
+ data = []
+ for i, (input, (output, latency)) in enumerate(
+ zip(self.add_empty_cache_if_needed(inputs), outputs)
+ ):
+ if verbose:
+ if verbose > 1:
+ print(
+ f"[method_to_onnx.check_discrepancies] process input {i}: "
+ f"{string_type(input, with_shape=True)}"
+ )
+ print(
+ f"[method_to_onnx.check_discrepancies] expects: "
+ f"{string_type(output, with_shape=True)}"
+ )
+ else:
+ print(
+ f"[method_to_onnx.check_discrepancies] process input {i} "
+ f"#args={len(input[0])} #kwargs={len(input[1])}"
+ )
+
+ flat_inputs = flatten_object(input, drop_keys=True)
+ if verbose > 1:
+ print(
+ f"[method_to_onnx.check_discrepancies] "
+ f"input={string_type(input, with_shape=True)}"
+ )
+ print(
+ f"[method_to_onnx.check_discrepancies] "
+ f"flat_inputs={string_type(flat_inputs, with_shape=True)}"
+ )
+ if len(flat_inputs) < len(input_names):
+ # not implemented yet, it is caused by a missing cache,
+ # which requires an empty cache instead
+ data.append(dict(index=i, duration_torch=latency, n_inputs=len(flat_inputs)))
+ continue
+ assert len(flat_inputs) == len(input_names), (
+ f"Length mismatch, expecting {len(input_names)} onnx inputs and got "
+ f"{len(flat_inputs)} flat torch inputs"
+ )
+ feeds = make_feeds(input_names, flat_inputs)
+ if verbose > 1:
+ print(
+ f"[method_to_onnx.check_discrepancies] "
+ f"feeds={string_type(feeds, with_shape=True)}"
+ )
+ begin = time.perf_counter()
+ ort_outputs = sess.run(None, feeds)
+ duration = time.perf_counter() - begin
+ diff = max_diff(output, ort_outputs, hist=hist)
+ if "rep" in diff and isinstance(diff["rep"], dict):
+ diff.update(diff["rep"])
+ del diff["rep"]
+ diff["SUCCESS"] = (
+ isinstance(diff["abs"], float)
+ and isinstance(diff["rel"], float)
+ and diff["abs"] < atol
+ and diff["rel"] < rtol
+ )
+ diff.update(
+ dict(
+ index=i,
+ duration_torch=latency,
+ ort_duration=duration,
+ n_inputs=len(flat_inputs),
+ )
+ )
+ if verbose > 1:
+ print(
+ f"[method_to_onnx.check_discrepancies] ort output "
+ f"{string_type(ort_outputs, with_shape=True)}"
+ )
+ print(f"[method_to_onnx.check_discrepancies] diff {string_diff(diff)}")
+ data.append(diff)
+ if verbose:
+ print("[method_to_onnx.check_discrepancies] done")
+ return data
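
Illustration (not part of the diff): the returned list of dictionaries is meant to be loaded into a dataframe; a hedged sketch assuming `wrapped` is a WrapperToExportMethodToOnnx instance whose export already ran.

    import pandas

    data = wrapped.check_discrepancies(atol=1e-4, rtol=0.1, verbose=1)
    # columns such as "index", "n_inputs", "abs", "rel", "SUCCESS",
    # "duration_torch" and "ort_duration" come from max_diff and the code above
    df = pandas.DataFrame(data)
    print(df)
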
+
+ @classmethod
+ def _apply_known_shape_pattern(
+ cls, shape: Dict[int, Any], pattern: Dict[int, str]
+ ) -> Dict[int, Any]:
+ return {k: pattern.get(k, v) for k, v in shape.items()}
+
+ @classmethod
+ def get_dynamic_shape_patterns(cls) -> Dict[str, Any]:
+ """
+ Returns the known patterns for the dynamic shapes.
+
+ .. runpython::
+ :showcode:
+
+ import pprint
+ from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx
+ pprint.pprint(WrapperToExportMethodToOnnx.get_dynamic_shape_patterns())
+ """
+ return {
+ "LLM.text": {
+ "cache_position": {0: "seqlength"},
+ "past_key_values": {0: "batch", 2: "pastlength"},
+ "input_ids": {0: "batch", 1: "seqlength"},
+ "attention_mask": {0: "batch", 1: "totallength"}, # pastlength+seqlength
+ }
+ }
+
+ @classmethod
+ def rename_dynamic_shapes(cls, ds: Dict[str, Any], verbose: int = 0) -> Dict[str, Any]:
+ """
+ Renames the dynamic shapes with meaningful names.
+ Tries to rename any dynamic dimension
+ before export. It is not very clever, it just tries
+ to recognize a known configuration based on input names.
+ Dimension names in dynamic shapes are renamed if *ds* contains
+ all the named arguments of one of the patterns
+ returned by function :meth:`get_dynamic_shape_patterns
+ <onnx_diagnostic.export.api.WrapperToExportMethodToOnnx.get_dynamic_shape_patterns>`.
+ """
+ is_shape = lambda s: isinstance(s, dict) and all( # noqa: E731
+ isinstance(_, int) for _ in s
+ )
+ llm_patterns = cls.get_dynamic_shape_patterns()
+ for pattern_name, pattern_shape in llm_patterns.items():
+ if len(set(ds) & set(pattern_shape)) == len(pattern_shape):
+ if verbose:
+ print(
+ f"[method_to_onnx.rename_dynamic_shapes] "
+ f"apply pattern shapes {pattern_name!r}"
+ )
+ new_ds = {}
+ for k, v in ds.items():
+ if k not in pattern_shape:
+ new_ds[k] = v
+ continue
+ if is_shape(v):
+ # A shape
+ new_ds[k] = cls._apply_known_shape_pattern(v, pattern_shape[k])
+ elif isinstance(v, list):
+ # A cache
+ new_ds[k] = [
+ (
+ cls._apply_known_shape_pattern(s, pattern_shape[k])
+ if is_shape(s)
+ else s
+ )
+ for s in v
+ ]
+ return new_ds
+
+ # unchanged
+ return ds
+

  def method_to_onnx(
  mod: "torch.nn.Module",
@@ -499,6 +962,8 @@ def method_to_onnx(
  patch_kwargs: Optional[Dict[str, Any]] = None,
  skip_kwargs_names: Optional[Set[str]] = None,
  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+ dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
+ expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
  ) -> Callable:
  """
  Exports one method of a module into ONNX.
@@ -528,12 +993,20 @@ def method_to_onnx(
  :param skip_kwargs_names: use default values for these parameters part of
  the signature of the method to export
  :param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
+ :param dynamic_batch_for: LLMs are usually called with a batch size equal to 1,
+ but the export may benefit from a dynamic batch size;
+ this parameter forces the inputs specified in this set to have a dynamic
+ first dimension
+ :param expand_batch_for: LLMs are usually called with a batch size equal to 1,
+ but the export may benefit from another value for the batch size;
+ this parameter forces the inputs specified in this set to be expanded
+ to a batch size of 2 if the batch size is 1
  :return: the output of the selected exporter, usually a structure including
  an onnx model

  See :ref:`l-plot-tiny-llm-export-method-generate` for an example.
  """
- wrapped_model = _WrapperToExportMethodToOnnx(
+ wrapped_model = WrapperToExportMethodToOnnx(
  mod=mod,
  method_name=method_name,
  input_names=input_names,
@@ -554,5 +1027,7 @@
  patch_kwargs=patch_kwargs,
  skip_kwargs_names=skip_kwargs_names,
  dynamic_shapes=dynamic_shapes,
+ dynamic_batch_for=dynamic_batch_for,
+ expand_batch_for=expand_batch_for,
  )
  return wrapped_model
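
Illustration (not part of the diff): a hedged sketch of the two new parameters documented above; apart from dynamic_batch_for and expand_batch_for, the argument names used here (model, filename, ...) are assumptions about the rest of the signature, which this diff does not show.

    from onnx_diagnostic.export.api import method_to_onnx

    wrapped = method_to_onnx(
        model,                           # hypothetical torch.nn.Module
        method_name="forward",
        filename="model.forward.onnx",   # assumed parameter feeding _to_onnx_kwargs["filename"]
        # recorded calls typically use batch size 1: make dimension 0 of these inputs dynamic
        dynamic_batch_for=["input_ids", "attention_mask", "past_key_values"],
        # and expand that dimension to 2 so the exporter does not specialize on batch == 1
        expand_batch_for=["input_ids", "attention_mask", "past_key_values"],
    )
    # the returned callable replaces or wraps the original method; after
    # convert_after_n_calls recorded calls it exports the method to ONNX
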
@@ -329,7 +329,7 @@ class CoupleInputsDynamicShapes:
  if type(inputs) in (tuple, list, dict):
  # Type must be strict, some custom classes can inherit from those.
  assert type(inputs) is type(ds), (
- f"Input type and dynamic shape type mush match but "
+ f"Input type and dynamic shapes type must match but "
  f"type(inputs)={type(inputs)}, type(ds)={type(ds)}, "
  f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
  )
@@ -967,6 +967,8 @@ class ModelInputs:
  """
  Guesses the dynamic shapes for that module from two executions.
  If there is only one execution, then that would be static dimensions.
+ If the model signature is available, the kwargs are reordered following
+ the signature order, otherwise they follow the order given in the inputs.

  :param auto: if auto is True, use ``torch.export.Dim.AUTO`` for any
  dimension if the number of inputs is one,
@@ -1026,11 +1028,24 @@
  msg=lambda name=name: f" failing input {name!r}",
  )
  # reordering
- if kwargs is not None and self.forward_ordered_parameter_names:
- kwargs1 = {
- p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
- }
- kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
+ if kwargs:
+ if self.forward_ordered_parameter_names:
+ kwargs1 = {
+ p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
+ }
+ kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
+ else:
+ # We reorder them the same way the inputs were given.
+ use = None
+ params = set(kwargs)
+ for _args, kws in self.inputs:
+ if set(kws) == params:
+ use = kws
+ break
+ if use:
+ ordered = list(use)
+ kwargs = {k: kwargs[k] for k in ordered}
+
  return tuple(args), kwargs

  def move_to_kwargs(
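
Illustration (not part of the diff): the fallback reordering added to ModelInputs simply reuses the key order of a previously recorded call that has the same set of keyword arguments.

    kwargs = {"attention_mask": 1, "input_ids": 0}        # order produced while guessing shapes
    recorded = [{"input_ids": 0, "attention_mask": 1}]    # order of the original call
    for kws in recorded:
        if set(kws) == set(kwargs):
            kwargs = {k: kwargs[k] for k in kws}
            break
    print(list(kwargs))   # ['input_ids', 'attention_mask']
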
@@ -47,7 +47,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
  make_dynamic_cache,
  make_encoder_decoder_cache,
  make_mamba_cache,
- make_sliding_window_cache,
  make_static_cache,
  )
  from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
@@ -77,13 +76,6 @@
  ]
  ),
  ),
- make_sliding_window_cache(
- [
- (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
- (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
- (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
- ]
- ),
  make_static_cache(
  [
  (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),