onnx-diagnostic 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/ci_models/export_phi4_mm.py +1 -1
  3. onnx_diagnostic/doc.py +258 -8
  4. onnx_diagnostic/export/api.py +755 -5
  5. onnx_diagnostic/export/dynamic_shapes.py +61 -4
  6. onnx_diagnostic/export/shape_helper.py +1 -8
  7. onnx_diagnostic/helpers/cache_helper.py +98 -21
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
  9. onnx_diagnostic/helpers/helper.py +36 -6
  10. onnx_diagnostic/helpers/onnx_helper.py +7 -0
  11. onnx_diagnostic/helpers/ort_session.py +5 -0
  12. onnx_diagnostic/helpers/rt_helper.py +14 -1
  13. onnx_diagnostic/helpers/torch_helper.py +22 -9
  14. onnx_diagnostic/tasks/image_text_to_text.py +8 -5
  15. onnx_diagnostic/tasks/text_generation.py +17 -17
  16. onnx_diagnostic/torch_export_patches/eval/__init__.py +1 -1
  17. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +62 -38
  18. onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
  19. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
  20. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  21. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -9
  22. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +42 -30
  23. onnx_diagnostic/torch_models/validate.py +48 -0
  24. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/METADATA +3 -1
  25. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/RECORD +28 -28
  26. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/WHEEL +0 -0
  27. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/licenses/LICENSE.txt +0 -0
  28. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,18 @@
1
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
1
+ import inspect
2
+ import os
3
+ import textwrap
4
+ import time
5
+ from collections.abc import Mapping, Iterable
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
2
7
  import torch
8
+ from .dynamic_shapes import ModelInputs
3
9
  from .onnx_plug import EagerDirectReplacementWithOnnx
10
+ from ..helpers import flatten_object, max_diff, string_diff, string_type
11
+ from ..helpers.cache_helper import CacheKeyValue
12
+ from ..helpers.torch_helper import torch_deepcopy
13
+ from ..helpers.rt_helper import make_feeds
14
+ from ..helpers.onnx_helper import pretty_onnx
15
+ from ..reference import OnnxruntimeEvaluator
4
16
 
5
17
 
6
18
  def get_main_dispatcher(
@@ -70,6 +82,7 @@ def to_onnx(
70
82
  inline: bool = True,
71
83
  ) -> Any:
72
84
  """
85
+ Exports one model into ONNX.
73
86
  Common API for exporters. By default, the models are optimized to use the
74
87
  most efficient kernels implemented in :epkg:`onnxruntime`.
75
88
 
@@ -126,8 +139,12 @@ def to_onnx(
126
139
  from experimental_experiment.xbuilder import OptimizationOptions
127
140
 
128
141
  options = None
142
+ export_options = None
129
143
  if exporter_kwargs is not None:
130
144
  options = exporter_kwargs.pop("options", None)
145
+ export_options = exporter_kwargs.pop("export_options", None)
146
+ if export_options is None:
147
+ export_options = ExportOptions(save_ep=save_ep)
131
148
  if options is None and optimize:
132
149
  options = OptimizationOptions(
133
150
  patterns="default+onnxruntime" if optimizer_for_ort else "default"
@@ -138,7 +155,7 @@ def to_onnx(
138
155
  else None
139
156
  )
140
157
 
141
- return _to_onnx(
158
+ proto, opt_stats = _to_onnx(
142
159
  mod,
143
160
  args=args,
144
161
  kwargs=kwargs,
@@ -150,16 +167,52 @@ def to_onnx(
150
167
  dynamic_shapes=dynamic_shapes,
151
168
  large_model=True,
152
169
  output_dynamic_shapes=output_dynamic_shapes,
153
- export_options=ExportOptions(save_ep=save_ep),
170
+ export_options=export_options,
154
171
  options=options,
155
172
  inline=inline,
156
173
  dispatcher=main_dispatcher,
157
174
  optimize=optimize,
175
+ return_optimize_report=True,
158
176
  **(exporter_kwargs or {}),
159
177
  )
178
+ if opt_stats and filename and os.path.exists(filename):
179
+ import pandas
180
+
181
+ stat_filename = f"{os.path.splitext(filename)[0]}.opt.xlsx"
182
+ pattern_stats = []
183
+ for k, v in opt_stats.items():
184
+ if "time" in k:
185
+ pattern_stats.append(dict(level="main", pattern=k, time_in=v))
186
+ pattern_stats.extend(
187
+ [{**obs, "level": "detailed"} for obs in opt_stats["optimization"]]
188
+ )
189
+ df = pandas.DataFrame(pattern_stats)
190
+ df.to_excel(stat_filename, index=False)
191
+ cols = [
192
+ c
193
+ for c in [
194
+ "level",
195
+ "pattern",
196
+ "time_in",
197
+ "iteration",
198
+ "inlined",
199
+ "removed",
200
+ "added",
201
+ "instances",
202
+ "changed",
203
+ "scale",
204
+ ]
205
+ if c in df.columns
206
+ ]
207
+ agg = {k: "sum" for k in cols if k not in ("level", "pattern")}
208
+ agg.update(dict(iteration="max", instances="mean"))
209
+ agg = {k: v for k, v in agg.items() if k in df.columns}
210
+ stat_filename = f"{os.path.splitext(filename)[0]}.opt.agg.xlsx"
211
+ df[cols].groupby(["level", "pattern"]).agg(agg).to_excel(stat_filename)
212
+
213
+ return proto
160
214
 
161
215
  if exporter in ("dynamo", "onnx-dynamo"):
162
- import os
163
216
  from ..helpers import flatten_object
164
217
  import onnxscript.rewriter.ort_fusions as ort_fusions
165
218
 
@@ -226,7 +279,6 @@ def to_onnx(
226
279
  return epo
227
280
 
228
281
  if exporter == "modelbuilder":
229
- import os
230
282
  from ..helpers import flatten_object, string_type
231
283
  from ..helpers.model_builder_helper import create_model_builder, save_model_builder
232
284
 
@@ -267,3 +319,701 @@ def to_onnx(
267
319
  return onx
268
320
 
269
321
  raise ValueError(f"Unknown exporter={exporter!r}")
322
+
323
+
324
+ class WrapperToExportMethodToOnnx(torch.nn.Module):
325
+ """
326
+ Wraps an existing models in order to spy on inputs.
327
+ This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`
328
+ or :ref:`l-plot-tiny-llm-export-method-generate` for an example.
329
+ """
330
+
331
+ def __init__(
332
+ self,
333
+ mod: "torch.nn.Module",
334
+ method_name: str = "forward",
335
+ input_names: Optional[Sequence[str]] = None,
336
+ target_opset: Optional[Union[int, Dict[str, int]]] = None,
337
+ verbose: int = 0,
338
+ filename: Optional[str] = None,
339
+ output_names: Optional[List[str]] = None,
340
+ output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
341
+ exporter: str = "onnx-dynamo",
342
+ exporter_kwargs: Optional[Dict[str, Any]] = None,
343
+ save_ep: Optional[str] = None,
344
+ optimize: bool = True,
345
+ optimizer_for_ort: bool = True,
346
+ use_control_flow_dispatcher: bool = False,
347
+ onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
348
+ inline: bool = True,
349
+ convert_after_n_calls: int = 2,
350
+ patch_kwargs: Optional[Dict[str, Any]] = None,
351
+ skip_kwargs_names: Optional[Set[str]] = None,
352
+ dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
353
+ dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
354
+ expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
355
+ ):
356
+ super().__init__()
357
+ self._model_to_call = mod
358
+ self._method_name = method_name
359
+ self._method_call = (
360
+ self._model_to_call.forward
361
+ if method_name == "forward"
362
+ else getattr(mod, method_name)
363
+ )
364
+ self._signature = inspect.signature(self._method_call)
365
+ self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
366
+ self._outputs: List[Any] = []
367
+ self._convert_after_n_calls = convert_after_n_calls
368
+ self._patch_kwargs = patch_kwargs
369
+ self._method_src = None
370
+ self.verbose = verbose
371
+ self.skip_kwargs_names = skip_kwargs_names
372
+ self.dynamic_shapes = dynamic_shapes
373
+ self.expand_batch_for = expand_batch_for
374
+ self.dynamic_batch_for = dynamic_batch_for
375
+ self._to_onnx_kwargs = dict(
376
+ input_names=input_names,
377
+ target_opset=target_opset,
378
+ verbose=verbose,
379
+ filename=filename,
380
+ output_names=output_names,
381
+ output_dynamic_shapes=output_dynamic_shapes,
382
+ exporter=exporter,
383
+ exporter_kwargs=exporter_kwargs,
384
+ save_ep=save_ep,
385
+ optimize=optimize,
386
+ optimizer_for_ort=optimizer_for_ort,
387
+ use_control_flow_dispatcher=use_control_flow_dispatcher,
388
+ onnx_plugs=onnx_plugs,
389
+ inline=inline,
390
+ )
391
+ self._export_done = False
392
+ self._serialization_classes: Set[type] = set()
393
+
394
+ def __str__(self) -> str:
395
+ "usual"
396
+ return self.__repr__()
397
+
398
+ def __repr__(self) -> str:
399
+ "usual"
400
+ return (
401
+ f"{self.__class__.__name__}({self._model_to_call.__class__.__name__}."
402
+ f"{self._method_name})"
403
+ )
404
+
405
+ def _collect_classes(self, obj):
406
+ if obj is None or isinstance(obj, torch.Tensor):
407
+ return
408
+ cls = type(obj)
409
+ if cls.__module__ not in ("builtins",):
410
+ self._serialization_classes.add(cls)
411
+ if hasattr(obj, "__dict__"):
412
+ for v in vars(obj).values():
413
+ self._collect_classes(v)
414
+ return
415
+ if isinstance(obj, Mapping):
416
+ for v in obj.values():
417
+ self._collect_classes(v)
418
+ return
419
+ if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
420
+ for v in obj:
421
+ self._collect_classes(v)
422
+ return
423
+
424
+ def _reorder_kwargs(self, kwargs):
425
+ new_kwargs = {k: kwargs[k] for k in self._signature.parameters if k in kwargs}
426
+ for k, v in kwargs.items():
427
+ if k not in new_kwargs:
428
+ new_kwargs[k] = v
429
+ return new_kwargs
430
+
431
+ def forward(self, *args, **kwargs):
432
+ if not self._export_done:
433
+ inp_args = args
434
+ # filters out the inputs not desired, int, float, bool, None
435
+ # are considered as constant for the exporter, they are removed
436
+ # from the named arguments.
437
+ inp_kwargs = (
438
+ kwargs
439
+ if not kwargs
440
+ else {
441
+ k: v
442
+ for k, v in kwargs.items()
443
+ if v is not None
444
+ and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
445
+ and not isinstance(v, (bool, int, float))
446
+ }
447
+ )
448
+ if self.expand_batch_for:
449
+ # extends the inputs to artificially create a batch dimension != 1.
450
+ inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
451
+ inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
452
+ inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
453
+ # reorders the parameter following the method signature.
454
+ inp_kwargs = self._reorder_kwargs(inp_kwargs)
455
+ # stores the inputs
456
+ self._inputs.append((inp_args, inp_kwargs))
457
+
458
+ if self.verbose:
459
+ print(
460
+ f"[method_to_onnx] input[{len(self._inputs)-1}]: "
461
+ f"{string_type(self._inputs[-1], with_shape=True)}"
462
+ )
463
+
464
+ if len(self._inputs) >= self._convert_after_n_calls:
465
+ # conversion starts after _convert_after_n_calls calls to the forward method
466
+ name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
467
+ input_file = f"{name}.inputs.pt"
468
+ self._input_file = input_file
469
+ if self.verbose:
470
+ print(
471
+ f"[method_to_onnx] save {len(self._inputs)} inputs in {input_file!r}"
472
+ )
473
+ torch.save(self._inputs, input_file)
474
+ self._convert_method_to_onnx()
475
+ self._export_done = True
476
+
477
+ # calls the inner method (no change here)
478
+ begin = time.perf_counter()
479
+ res = self._method_call(*args, **kwargs)
480
+ duration = time.perf_counter() - begin
481
+ self._collect_classes([args, kwargs, res])
482
+ if self._inputs:
483
+ # stores the outputs if discrepancies need to be checked
484
+ self._outputs.append((torch_deepcopy(res), duration))
485
+ assert len(self._inputs) == len(self._outputs), (
486
+ f"Number of inputs {len(self._inputs)} and "
487
+ f"outputs {len(self._outputs)} are different."
488
+ )
489
+ if self._export_done:
490
+ name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
491
+ output_file = f"{name}.outputs.pt"
492
+ if self.verbose:
493
+ print(
494
+ f"[method_to_onnx] save {len(self._outputs)} "
495
+ f"outputs in {output_file!r}"
496
+ )
497
+ torch.save(self._outputs, output_file)
498
+ self._output_file = output_file
499
+ del self._inputs[:]
500
+ del self._outputs[:]
501
+ return res
502
+
503
+ def _convert_method_to_onnx(self):
504
+ for args, kwargs in self._inputs:
505
+ self._serialization_classes |= {type(a) for a in args}
506
+ self._serialization_classes |= {type(a) for a in kwargs.values()}
507
+
508
+ def make_method(self):
509
+ inner_sig = inspect.signature(self._method_call)
510
+ params = [
511
+ p.replace(annotation=inspect._empty) for p in inner_sig.parameters.values()
512
+ ]
513
+ simple_sig = inspect.Signature(params, return_annotation=inspect._empty)
514
+ args = str(simple_sig)[1:-1]
515
+ calls_args = ", ".join(f"{p}={p}" for p in simple_sig.parameters)
516
+ src = textwrap.dedent(
517
+ f"""
518
+ def f(self, {args}):
519
+ return self._method_call({calls_args})
520
+ """
521
+ )
522
+ self._method_src = src
523
+ ns = {}
524
+ try:
525
+ exec(src, ns)
526
+ except NameError as e:
527
+ raise NameError(f"Unable to compile due to {e}\n{src}") from e
528
+ return ns["f"]
529
+
530
+ class WrapWithExactSignature(torch.nn.Module):
531
+ def __init__(self, parent):
532
+ super().__init__()
533
+ self._model_to_call = parent._model_to_call
534
+ self._method_call = parent._method_call
535
+
536
+ forward = make_method(self)
537
+
538
+ compiled_model = WrapWithExactSignature(self)
539
+
540
+ if self.dynamic_shapes is None:
541
+ mi = ModelInputs(compiled_model, self._inputs)
542
+ ds = mi.guess_dynamic_shapes()
543
+ if self.verbose:
544
+ print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}")
545
+ a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds)
546
+ if self.dynamic_batch_for:
547
+ nds = (
548
+ self._dynamic_batch_dimension(nds[0], self.dynamic_batch_for),
549
+ self.rename_dynamic_shapes(
550
+ self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
551
+ verbose=self.verbose,
552
+ ),
553
+ )
554
+ if self.verbose:
555
+ print(f"[method_to_onnx] dynamic_batch_for={self.dynamic_batch_for}")
556
+ print(f"[method_to_onnx] dynamic_shapes with batch={nds}")
557
+ else:
558
+ a, kw = self._inputs[-1]
559
+ nds = [self.dynamic_shapes]
560
+ if self.verbose:
561
+ print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}")
562
+ print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}")
563
+ print(f"[method_to_onnx] dynamic_shapes={string_type(nds)}")
564
+ if self._patch_kwargs is None:
565
+ to_onnx(
566
+ compiled_model,
567
+ args=a,
568
+ kwargs=kw,
569
+ dynamic_shapes=nds[-1],
570
+ **self._to_onnx_kwargs,
571
+ )
572
+ return
573
+ from ..torch_export_patches import torch_export_patches
574
+
575
+ with torch_export_patches(**self._patch_kwargs):
576
+ to_onnx(
577
+ compiled_model,
578
+ args=a,
579
+ kwargs=kw,
580
+ dynamic_shapes=nds[-1],
581
+ **self._to_onnx_kwargs,
582
+ )
583
+
584
+ @classmethod
585
+ def make_empty_cache_from_others(cls, examples: List[Any]) -> Any:
586
+ """Builds an empty cache based on existing one."""
587
+ unique_types = {type(t) for t in examples}
588
+ assert (
589
+ len(unique_types) == 1
590
+ ), f"Unable to guess an empty cache from {string_type(examples, with_shape=True)}"
591
+ unique_type = unique_types.pop()
592
+ if unique_type == torch.Tensor:
593
+ shapes = [t.shape for t in examples]
594
+ assert len(set(shapes)) > 1, f"Unable to guess an empty shape from shapes {shapes}"
595
+ ranks = {len(s) for s in shapes}
596
+ assert len(ranks) == 1, f"Ranks are different in {shapes}"
597
+ rank = ranks.pop()
598
+ new_shape = []
599
+ for i in range(rank):
600
+ dims = [t.shape[i] for t in examples]
601
+ if len(set(dims)) == 1:
602
+ new_shape.append(dims[0])
603
+ else:
604
+ # The empty shape
605
+ new_shape.append(0)
606
+ example = examples[0]
607
+ return torch.empty(tuple(new_shape), dtype=example.dtype, device=example.device)
608
+ assert (
609
+ unique_type.__name__ == "DynamicCache"
610
+ ), f"This is not implemented for class {unique_type}"
611
+ caches = [CacheKeyValue(dc) for dc in examples]
612
+ caches_list = [dc.aslist() for dc in caches]
613
+ empty = [
614
+ cls.make_empty_cache_from_others([caches_list[i][k] for i in range(len(examples))])
615
+ for k in range(len(caches_list[0]))
616
+ ]
617
+ empty_cache = CacheKeyValue(
618
+ empty, cls_layers=caches[0].cls_layers
619
+ ).make_dynamic_cache()
620
+ return empty_cache
621
+
622
+ @classmethod
623
+ def add_empty_cache_if_needed(cls, inputs: List[Any]) -> List[Any]:
624
+ """
625
+ Adds empty cache if needed as onnxruntime needs an empty cache,
626
+ not a missing cache. It only works if inputs are defined as a dictionary.
627
+ """
628
+ if all(isinstance(t, tuple) for t in inputs) and all(
629
+ len(t) == 2 and isinstance(t[0], tuple) and isinstance(t[1], dict) and not t[0]
630
+ for t in inputs
631
+ ):
632
+ dict_part = [t[1] for t in inputs]
633
+ res = cls.add_empty_cache_if_needed(dict_part)
634
+ return [(tuple(), d) for d in res]
635
+ if any(not isinstance(t, dict) for t in inputs):
636
+ return inputs
637
+ all_keys = set()
638
+ for input_set in inputs:
639
+ all_keys |= set(input_set)
640
+ # even though the inputs are defined as a dictionary, it is better
641
+ # to keep the same order
642
+ ordered = None
643
+ for input_set in inputs:
644
+ if set(input_set) == all_keys:
645
+ ordered = list(input_set)
646
+ break
647
+ new_inputs = []
648
+ for input_set in inputs:
649
+ if set(input_set) == all_keys:
650
+ new_inputs.append(input_set)
651
+ continue
652
+ missing = {k for k in all_keys if k not in input_set}
653
+ input_set_copy = input_set.copy()
654
+ for miss in missing:
655
+ input_set_copy[miss] = cls.make_empty_cache_from_others(
656
+ [sub[miss] for sub in inputs if miss in sub]
657
+ )
658
+ new_inputs.append({k: input_set_copy[k] for k in ordered}) # type: ignore[union-attr]
659
+ return new_inputs
660
+
661
+ @classmethod
662
+ def _expand_batch_dimension(cls, obj: Any, expand_for: Sequence[Union[int, str]]) -> Any:
663
+ expand_for_args = {i for i in expand_for if isinstance(i, int)}
664
+ expand_for_kwargs = {i for i in expand_for if isinstance(i, str)}
665
+ if isinstance(obj, tuple):
666
+ return tuple(
667
+ o if i not in expand_for_args else cls._expand_batch_dimension_input(o, i)
668
+ for i, o in enumerate(obj)
669
+ )
670
+ assert isinstance(obj, dict), f"Unexpected type {type(obj)}"
671
+ return {
672
+ k: v if k not in expand_for_kwargs else cls._expand_batch_dimension_input(v, k)
673
+ for k, v in obj.items()
674
+ }
675
+
676
+ @classmethod
677
+ def _expand_batch_dimension_input(cls, obj: Any, msg: Union[str, int]) -> Any:
678
+ if isinstance(obj, torch.Tensor):
679
+ assert obj.shape[0] == 1, (
680
+ f"Are you sure to expoand input {msg!r}, "
681
+ f"batch size is not 1 and shape={obj.shape}"
682
+ )
683
+ sizes = [2, *obj.shape[1:]]
684
+ return obj.expand(*sizes)
685
+ if isinstance(obj, list):
686
+ return [
687
+ cls._expand_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(obj)
688
+ ]
689
+ if obj.__class__.__name__ == "DynamicCache":
690
+ dc = CacheKeyValue(obj)
691
+ flat = dc.aslist()
692
+ flat = cls._expand_batch_dimension_input(flat, msg)
693
+ return CacheKeyValue(flat, cls_layers=dc.cls_layers).make_dynamic_cache()
694
+ # This might end up in an infinite loop if no registration is done.
695
+ flat, _spec = torch.utils._pytree.tree_flatten(obj)
696
+ assert (
697
+ not isinstance(flat, list) or len(flat) != 1 or type(flat[0]) is not type(obj)
698
+ ), f"class {type(obj)} was is not registered for serialization."
699
+ flat = cls._expand_batch_dimension_input(flat, msg)
700
+ return torch.utils._pytree.tree_unflatten(flat, _spec)
701
+
702
+ @classmethod
703
+ def _dynamic_batch_dimension(
704
+ cls, ds: Union[Tuple[Any, ...], Dict[str, Any]], dynamic_for: Sequence[Union[int, str]]
705
+ ) -> Union[Tuple[Any, ...], Dict[str, Any]]:
706
+ if isinstance(ds, tuple):
707
+ return tuple(
708
+ (v if i not in dynamic_for else cls._dynamic_batch_dimension_input(v, i))
709
+ for i, v in enumerate(ds)
710
+ )
711
+ return {
712
+ k: (v if k not in dynamic_for else cls._dynamic_batch_dimension_input(v, k))
713
+ for k, v in ds.items()
714
+ }
715
+
716
+ @classmethod
717
+ def _dynamic_batch_dimension_input(cls, ds: Any, msg: Union[str, int]) -> Any:
718
+ if isinstance(ds, dict) and all(isinstance(k, int) for k in ds):
719
+ ds[0] = "batch"
720
+ return {k: v for k, v in sorted(ds.items())} # noqa: C416
721
+ if isinstance(ds, list):
722
+ return [
723
+ cls._dynamic_batch_dimension_input(o, f"{msg}[{i}]") for i, o in enumerate(ds)
724
+ ]
725
+ raise NotImplementedError(f"cannot make first dimension dynamic for batch for {ds}")
726
+
727
+ def check_discrepancies(
728
+ self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0
729
+ ) -> List[Dict[str, Union[str, int, float]]]:
730
+ """
731
+ Computes the discrepancies between the saved inputs and outputs
732
+ with the saved onnx model.
733
+
734
+ :param atol: absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16
735
+ :param rtol: relative tolerance
736
+ :param hist: thresholds, the function determines the number of discrepancies
737
+ above that threshold.
738
+ :param verbose: verbosity
739
+ :return: results, a list of dictionaries, ready to be consumed by a dataframe
740
+ """
741
+ assert self._export_done, "The onnx export was not done."
742
+ assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
743
+ assert os.path.exists(
744
+ self._output_file
745
+ ), f"output file {self._output_file!r} not found"
746
+ filename = self._to_onnx_kwargs["filename"]
747
+ assert isinstance(filename, str) and os.path.exists(
748
+ filename
749
+ ), f"onnx file {filename!r} not found"
750
+ classes = [
751
+ cls
752
+ for cls in self._serialization_classes
753
+ if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
754
+ ]
755
+ if verbose:
756
+ print(f"[method_to_onnx.check_discrepancies] register classes {classes}")
757
+ print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}")
758
+ with torch.serialization.safe_globals(classes):
759
+ inputs = torch.load(self._input_file)
760
+ if verbose:
761
+ print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}")
762
+ with torch.serialization.safe_globals(classes):
763
+ outputs = torch.load(self._output_file)
764
+ assert len(inputs) == len(outputs), (
765
+ f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, "
766
+ f"inputs={string_type(inputs, with_shape=True)}, "
767
+ f"outputs={string_type(outputs, with_shape=True)}"
768
+ )
769
+ if verbose:
770
+ print(f"[method_to_onnx.check_discrepancies] create onnx session {filename!r}")
771
+ sess = OnnxruntimeEvaluator(filename, whole=True)
772
+ input_names = sess.input_names
773
+ if verbose:
774
+ print(f"[method_to_onnx.check_discrepancies] input_names={input_names}")
775
+ print(
776
+ f"[method_to_onnx.check_discrepancies] onnx_shapes="
777
+ f"{', '.join(pretty_onnx(i) for i in sess.input_types)}"
778
+ )
779
+ data = []
780
+ for i, (input, (output, latency)) in enumerate(
781
+ zip(self.add_empty_cache_if_needed(inputs), outputs)
782
+ ):
783
+ if verbose:
784
+ if verbose > 1:
785
+ print(
786
+ f"[method_to_onnx.check_discrepancies] process input {i}: "
787
+ f"{string_type(input, with_shape=True)}"
788
+ )
789
+ print(
790
+ f"[method_to_onnx.check_discrepancies] expects: "
791
+ f"{string_type(output, with_shape=True)}"
792
+ )
793
+ else:
794
+ print(
795
+ f"[method_to_onnx.check_discrepancies] process input {i} "
796
+ f"#args={len(input[0])} #kwargs={len(input[1])}"
797
+ )
798
+
799
+ flat_inputs = flatten_object(input, drop_keys=True)
800
+ if verbose > 1:
801
+ print(
802
+ f"[method_to_onnx.check_discrepancies] "
803
+ f"input={string_type(input, with_shape=True)}"
804
+ )
805
+ print(
806
+ f"[method_to_onnx.check_discrepancies] "
807
+ f"flat_inputs={string_type(flat_inputs, with_shape=True)}"
808
+ )
809
+ if len(flat_inputs) < len(input_names):
810
+ # not implemented yet, it is caused by a missing cache,
811
+ # which requires an empty cache instead
812
+ data.append(dict(index=i, duration_torch=latency, n_inputs=len(flat_inputs)))
813
+ continue
814
+ assert len(flat_inputs) == len(input_names), (
815
+ f"Length mismatch, expecting {len(input_names)} onnx inputs and got "
816
+ f"{len(flat_inputs)} flat torch inputs"
817
+ )
818
+ feeds = make_feeds(input_names, flat_inputs)
819
+ if verbose > 1:
820
+ print(
821
+ f"[method_to_onnx.check_discrepancies] "
822
+ f"feeds={string_type(feeds, with_shape=True)}"
823
+ )
824
+ begin = time.perf_counter()
825
+ ort_outputs = sess.run(None, feeds)
826
+ duration = time.perf_counter() - begin
827
+ diff = max_diff(output, ort_outputs, hist=hist)
828
+ if "rep" in diff and isinstance(diff["rep"], dict):
829
+ diff.update(diff["rep"])
830
+ del diff["rep"]
831
+ diff["SUCCESS"] = (
832
+ isinstance(diff["abs"], float)
833
+ and isinstance(diff["rel"], float)
834
+ and diff["abs"] < atol
835
+ and diff["rel"] < rtol
836
+ )
837
+ diff.update(
838
+ dict(
839
+ index=i,
840
+ duration_torch=latency,
841
+ ort_duration=duration,
842
+ n_inputs=len(flat_inputs),
843
+ )
844
+ )
845
+ if verbose > 1:
846
+ print(
847
+ f"[method_to_onnx.check_discrepancies] ort output "
848
+ f"{string_type(ort_outputs, with_shape=True)}"
849
+ )
850
+ print(f"[method_to_onnx.check_discrepancies] diff {string_diff(diff)}")
851
+ data.append(diff)
852
+ if verbose:
853
+ print("[method_to_onnx.check_discrepancies] done")
854
+ return data
855
+
856
+ @classmethod
857
+ def _apply_known_shape_pattern(
858
+ cls, shape: Dict[int, Any], pattern: Dict[int, str]
859
+ ) -> Dict[int, Any]:
860
+ return {k: pattern.get(k, v) for k, v in shape.items()}
861
+
862
+ @classmethod
863
+ def get_dynamic_shape_patterns(cls) -> Dict[str, Any]:
864
+ """
865
+ Returns the known patterns for the dynamic shapes.
866
+
867
+ .. runpython::
868
+ :showcode:
869
+
870
+ import pprint
871
+ from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx
872
+ pprint.pprint(WrapperToExportMethodToOnnx.get_dynamic_shape_patterns())
873
+ """
874
+ return {
875
+ "LLM.text": {
876
+ "cache_position": {0: "seqlength"},
877
+ "past_key_values": {0: "batch", 2: "pastlength"},
878
+ "input_ids": {0: "batch", 1: "seqlength"},
879
+ "attention_mask": {0: "batch", 1: "totallength"}, # pastlength+seqlength
880
+ }
881
+ }
882
+
883
+ @classmethod
884
+ def rename_dynamic_shapes(cls, ds: Dict[str, Any], verbose: int = 0) -> Dict[str, Any]:
885
+ """
886
+ Renames the dynamic shapes with names.
887
+ Tries to rename any dynamic dimnesion dimension
888
+ before export. It is not very clever, it just tries
889
+ to recognize a known configuration based on input names.
890
+ Dimension names in dynamic shapes are renamed if *ds* has
891
+ the same number of named arguments as the one of the patterns
892
+ returned by function :meth:`get_dynamic_shape_patterns
893
+ <onnx_diagnostic.export.api.WrapperToExportMethodToOnnx.get_dynamic_shape_patterns>`.
894
+ """
895
+ is_shape = lambda s: isinstance(s, dict) and all( # noqa: E731
896
+ isinstance(_, int) for _ in s
897
+ )
898
+ llm_patterns = cls.get_dynamic_shape_patterns()
899
+ for pattern_name, pattern_shape in llm_patterns.items():
900
+ if len(set(ds) & set(pattern_shape)) == len(pattern_shape):
901
+ if verbose:
902
+ print(
903
+ f"[method_to_onnx.rename_dynamic_shapes] "
904
+ f"apply pattern shapes {pattern_name!r}"
905
+ )
906
+ new_ds = {}
907
+ for k, v in ds.items():
908
+ if k not in pattern_shape:
909
+ new_ds[k] = v
910
+ continue
911
+ if is_shape(v):
912
+ # A shape
913
+ new_ds[k] = cls._apply_known_shape_pattern(v, pattern_shape[k])
914
+ elif isinstance(v, list):
915
+ # A cache
916
+ new_ds[k] = [
917
+ (
918
+ cls._apply_known_shape_pattern(s, pattern_shape[k])
919
+ if is_shape(s)
920
+ else s
921
+ )
922
+ for s in v
923
+ ]
924
+ return new_ds
925
+
926
+ # unchanged
927
+ return ds
928
+
929
+
930
def method_to_onnx(
    mod: "torch.nn.Module",
    method_name: str = "forward",
    input_names: Optional[Sequence[str]] = None,
    target_opset: Optional[Union[int, Dict[str, int]]] = None,
    verbose: int = 0,
    filename: Optional[str] = None,
    output_names: Optional[List[str]] = None,
    output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    exporter: str = "onnx-dynamo",
    exporter_kwargs: Optional[Dict[str, Any]] = None,
    save_ep: Optional[str] = None,
    optimize: bool = True,
    optimizer_for_ort: bool = True,
    use_control_flow_dispatcher: bool = False,
    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
    inline: bool = True,
    convert_after_n_calls: int = 2,
    patch_kwargs: Optional[Dict[str, Any]] = None,
    skip_kwargs_names: Optional[Set[str]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    dynamic_batch_for: Optional[Sequence[Union[int, str]]] = None,
    expand_batch_for: Optional[Sequence[Union[int, str]]] = None,
) -> Callable:
    """
    Exports one method of a module into ONNX.
    It returns a wrapper which must be called by the user
    at least twice with different values for the dynamic dimensions
    before the conversion into ONNX is triggered.

    :param mod: module owning the method to export into ONNX
    :param method_name: name of the method to export
    :param input_names: input names for the onnx model (optional)
    :param target_opset: opset to target, if not specified, each converter
        keeps its default value
    :param verbose: verbosity level
    :param filename: output filename, mandatory, the onnx model is saved on disk
    :param output_names: to change the output of the onnx model
    :param output_dynamic_shapes: to overwrite the dynamic shapes names
    :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
    :param exporter_kwargs: additional parameters sent to the exporter
    :param save_ep: saves the exported program
    :param optimize: optimizes the model
    :param optimizer_for_ort: optimizes the model for onnxruntime
    :param use_control_flow_dispatcher: use the dispatcher created to support
        custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
    :param onnx_plugs: the code was modified to replace some parts with onnx translation
    :param inline: inline local functions
    :param convert_after_n_calls: converts the model after this number of calls.
    :param patch_kwargs: patch arguments
    :param skip_kwargs_names: use default values for these parameters part of
        the signature of the method to export
    :param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
    :param dynamic_batch_for: LLM are usually called with a batch size equal to 1,
        but the export may benefit from having a dynamic batch size,
        this parameter forces the input specified in this set to have the first dimension
        be dynamic
    :param expand_batch_for: LLM are usually called with a batch size equal to 1,
        but the export may benefit from having another value for the batch size,
        this parameter forces the input specified in this set to be expanded
        to 2 if the batch size is one
    :return: the output of the selected exporter, usually a structure including
        an onnx model

    See :ref:`l-plot-tiny-llm-export-method-generate` for an example.
    """
    # The wrapper records calls, then triggers the export on its own.
    return WrapperToExportMethodToOnnx(
        mod=mod,
        method_name=method_name,
        input_names=input_names,
        target_opset=target_opset,
        verbose=verbose,
        filename=filename,
        output_names=output_names,
        output_dynamic_shapes=output_dynamic_shapes,
        exporter=exporter,
        exporter_kwargs=exporter_kwargs,
        save_ep=save_ep,
        optimize=optimize,
        optimizer_for_ort=optimizer_for_ort,
        use_control_flow_dispatcher=use_control_flow_dispatcher,
        onnx_plugs=onnx_plugs,
        inline=inline,
        convert_after_n_calls=convert_after_n_calls,
        patch_kwargs=patch_kwargs,
        skip_kwargs_names=skip_kwargs_names,
        dynamic_shapes=dynamic_shapes,
        dynamic_batch_for=dynamic_batch_for,
        expand_batch_for=expand_batch_for,
    )