onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +124 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +2 -1
  5. onnx_diagnostic/export/shape_helper.py +47 -70
  6. onnx_diagnostic/ext_test_case.py +11 -0
  7. onnx_diagnostic/helpers/cache_helper.py +38 -7
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
  9. onnx_diagnostic/helpers/helper.py +27 -33
  10. onnx_diagnostic/helpers/log_helper.py +109 -5
  11. onnx_diagnostic/helpers/memory_peak.py +2 -0
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +132 -2
  14. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  15. onnx_diagnostic/helpers/ort_session.py +4 -0
  16. onnx_diagnostic/helpers/rt_helper.py +393 -43
  17. onnx_diagnostic/helpers/torch_helper.py +20 -1
  18. onnx_diagnostic/tasks/__init__.py +7 -0
  19. onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
  20. onnx_diagnostic/tasks/feature_extraction.py +2 -8
  21. onnx_diagnostic/tasks/image_text_to_text.py +10 -8
  22. onnx_diagnostic/tasks/summarization.py +2 -8
  23. onnx_diagnostic/tasks/text2text_generation.py +3 -8
  24. onnx_diagnostic/tasks/text_generation.py +86 -65
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
  26. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  27. onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
  28. onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
  32. onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
  33. onnx_diagnostic/torch_models/validate.py +626 -228
  34. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
  35. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
  36. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
  37. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
@@ -1,22 +1,25 @@
1
- import gc
2
1
  import datetime
2
+ import gc
3
3
  import inspect
4
4
  import os
5
5
  import pprint
6
6
  import sys
7
- from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
8
7
  import time
8
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
9
9
  import numpy as np
10
10
  import onnx
11
11
  import torch
12
12
  from ..export import CoupleInputsDynamicShapes
13
13
  from ..helpers import max_diff, string_type, string_diff
14
14
  from ..helpers.helper import flatten_object
15
- from ..helpers.rt_helper import make_feeds, reorder_modelbuilder_cache_to_torch
15
+ from ..helpers.rt_helper import make_feeds
16
16
  from ..helpers.torch_helper import to_any, torch_deepcopy
17
17
  from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
18
18
  from ..tasks import random_input_kwargs
19
- from ..torch_export_patches import torch_export_patches
19
+ from ..torch_export_patches import (
20
+ torch_export_patches,
21
+ register_additional_serialization_functions,
22
+ )
20
23
  from ..torch_export_patches.patch_inputs import use_dyn_not_str
21
24
  from .hghub import get_untrained_model_with_inputs
22
25
  from .hghub.model_inputs import _preprocess_model_id
@@ -270,8 +273,8 @@ def _quiet_or_not_quiet(
270
273
  summary[f"time_{suffix}_latency_std"] = a.std()
271
274
  summary[f"time_{suffix}_latency_min"] = a.min()
272
275
  summary[f"time_{suffix}_latency_max"] = a.max()
273
- summary[f"time_{suffix}_latency_098"] = a[-i2]
274
- summary[f"time_{suffix}_latency_095"] = a[-i5]
276
+ summary[f"time_{suffix}_latency_098"] = a[-(max(i2, 1))]
277
+ summary[f"time_{suffix}_latency_095"] = a[-max(i5, 1)]
275
278
  summary[f"time_{suffix}_latency_005"] = a[i5]
276
279
  summary[f"time_{suffix}_latency_002"] = a[i2]
277
280
  summary[f"time_{suffix}_n"] = len(a)
@@ -320,125 +323,33 @@ def make_patch_kwargs(
320
323
  return patch_kwargs
321
324
 
322
325
 
323
- def validate_model(
324
- model_id: str,
325
- task: Optional[str] = None,
326
- do_run: bool = False,
327
- exporter: Optional[str] = None,
328
- do_same: bool = False,
329
- verbose: int = 0,
330
- dtype: Optional[Union[str, torch.dtype]] = None,
331
- device: Optional[Union[str, torch.device]] = None,
332
- same_as_pretrained: bool = False,
333
- use_pretrained: bool = False,
334
- optimization: Optional[str] = None,
335
- quiet: bool = False,
336
- patch: Union[bool, str, Dict[str, bool]] = False,
337
- rewrite: bool = False,
338
- stop_if_static: int = 1,
339
- dump_folder: Optional[str] = None,
340
- drop_inputs: Optional[List[str]] = None,
341
- ortfusiontype: Optional[str] = None,
342
- input_options: Optional[Dict[str, Any]] = None,
343
- model_options: Optional[Dict[str, Any]] = None,
344
- subfolder: Optional[str] = None,
345
- opset: Optional[int] = None,
346
- runtime: str = "onnxruntime",
347
- repeat: int = 1,
348
- warmup: int = 0,
349
- inputs2: int = 1,
350
- output_names: Optional[List[str]] = None,
351
- ort_logs: bool = False,
352
- quiet_input_sets: Optional[Set[str]] = None,
353
- ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
354
- """
355
- Validates a model.
356
- The function can also be called through the command line
357
- :ref:`l-cmd-validate`.
358
-
359
- :param model_id: model id to validate
360
- :param task: task used to generate the necessary inputs,
361
- can be left empty to use the default task for this model
362
- if it can be determined
363
- :param do_run: checks the model works with the defined inputs
364
- :param exporter: exporter the model using this exporter,
365
- available list: ``export-strict``, ``export-nostrict``, ...
366
- see below
367
- :param do_same: checks the discrepancies of the exported model
368
- :param verbose: verbosity level
369
- :param dtype: uses this dtype to check the model
370
- :param device: do the verification on this device
371
- :param same_as_pretrained: use a model equivalent to the trained,
372
- this is not always possible
373
- :param use_pretrained: use the trained model, not the untrained one
374
- :param optimization: optimization to apply to the exported model,
375
- depend on the the exporter
376
- :param quiet: if quiet, catches exception if any issue
377
- :param patch: applies patches (``patch_transformers=True, path_diffusers=True``)
378
- if True before exporting
379
- see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
380
- a string can be used to specify only one of them
381
- :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
382
- see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
383
- :param stop_if_static: stops if a dynamic dimension becomes static,
384
- see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
385
- :param dump_folder: dumps everything in a subfolder of this one
386
- :param drop_inputs: drops this list of inputs (given their names)
387
- :param ortfusiontype: runs ort fusion, the parameters defines the fusion type,
388
- it accepts multiple values separated by ``|``,
389
- see :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`
390
- :param input_options: additional options to define the dummy inputs
391
- used to export
392
- :param model_options: additional options when creating the model such as
393
- ``num_hidden_layers`` or ``attn_implementation``
394
- :param subfolder: version or subfolders to uses when retrieving a model id
395
- :param opset: onnx opset to use for the conversion
396
- :param runtime: onnx runtime to use to check about discrepancies,
397
- possible values ``onnxruntime``, ``torch``, ``orteval``,
398
- ``orteval10``, ``ref`` only if `do_run` is true
399
- :param repeat: number of time to measure the model
400
- :param warmup: warmup the model first
401
- :param inputs2: checks that other sets of inputs are running as well,
402
- this ensures that the model does support dynamism, the value is used
403
- as an increment to the first set of values (added to dimensions),
404
- or an empty cache for example
405
- :param output_names: output names the onnx exporter should use
406
- :param ort_logs: increases onnxruntime verbosity when creating the session
407
- :param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
408
- even if quiet is False
409
- :return: two dictionaries, one with some metrics,
410
- another one with whatever the function produces
411
-
412
- The following environment variables can be used to print out some
413
- information:
414
-
415
- * ``PRINT_CONFIG``: prints the model configuration
416
-
417
- The following exporters are available:
418
-
419
- * ``export-nostrict``: run :func:`torch.export.export` (..., strict=False)
420
- * ``onnx-dynamo``: run :func:`torch.onnx.export` (...),
421
- models can be optimized with ``optimization`` in ``("ir", "os_ort")``
422
- * ``modelbuilder``: use :epkg:`ModelBuilder` to builds the onnx model
423
- * ``custom``: custom exporter (see :epkg:`experimental-experiment`),
424
- models can be optimized with ``optimization`` in
425
- ``("default", "default+onnxruntime", "default+os_ort", "default+onnxruntime+os_ort")``
426
-
427
- The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
428
- exported model returns the same outputs as the original one, otherwise,
429
- :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
430
- if ``runtime == 'torch'`` or
431
- :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
432
- if ``runtime == 'orteval'`` or
433
- :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
434
- if ``runtime == 'ref'``,
435
- ``orteval10`` increases the verbosity.
436
-
437
- .. versionchanged:: 0.7.13
438
- *inputs2* not only means a second set of inputs but many
439
- such as ``input_empty_cache``
440
- which refers to a set of inputs using an empty cache.
441
- """
326
+ def _prepare_validation(
327
+ model_id,
328
+ subfolder,
329
+ same_as_pretrained,
330
+ use_pretrained,
331
+ patch,
332
+ rewrite,
333
+ do_run,
334
+ dtype,
335
+ device,
336
+ optimization,
337
+ quiet,
338
+ drop_inputs,
339
+ ortfusiontype,
340
+ stop_if_static,
341
+ exporter,
342
+ runtime,
343
+ inputs2,
344
+ input_options,
345
+ model_options,
346
+ exporter_options,
347
+ opset,
348
+ task,
349
+ verbose,
350
+ output_names,
351
+ dump_folder,
352
+ ):
442
353
  main_validation_begin = time.perf_counter()
443
354
  model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
444
355
  model_id,
@@ -470,6 +381,10 @@ def validate_model(
470
381
  version_exporter=exporter or "",
471
382
  version_runtime=runtime,
472
383
  version_inputs2=inputs2,
384
+ version_input_options=str(input_options),
385
+ version_drop_input=str(drop_inputs),
386
+ version_model_options=str(model_options),
387
+ version_exporter_options=str(exporter_options),
473
388
  time_preprocess_model_id=time_preprocess_model_id,
474
389
  )
475
390
  )
@@ -520,6 +435,32 @@ def validate_model(
520
435
  summary["model_id"] = model_id
521
436
  summary["model_subfolder"] = subfolder or ""
522
437
 
438
+ return (
439
+ summary,
440
+ model_id,
441
+ subfolder,
442
+ same_as_pretrained,
443
+ use_pretrained,
444
+ dump_folder,
445
+ folder_name,
446
+ patch_kwargs,
447
+ )
448
+
449
+
450
+ def _get_untrained_model_with_inputs(
451
+ summary,
452
+ model_id,
453
+ verbose,
454
+ task,
455
+ use_pretrained,
456
+ same_as_pretrained,
457
+ input_options,
458
+ model_options,
459
+ subfolder,
460
+ inputs2,
461
+ quiet,
462
+ dump_folder,
463
+ ):
523
464
  iop = input_options or {}
524
465
  mop = model_options or {}
525
466
  data = _quiet_or_not_quiet(
@@ -544,8 +485,6 @@ def validate_model(
544
485
  ),
545
486
  )
546
487
 
547
- second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]
548
-
549
488
  if dump_folder:
550
489
  with open(os.path.join(dump_folder, "model_config.txt"), "w") as f:
551
490
  f.write(f"model_id: {model_id}\n------\n")
@@ -562,30 +501,45 @@ def validate_model(
562
501
  f.write(f"model_id: {model_id}\n------\n")
563
502
  f.write(pprint.pformat(dump_info))
564
503
 
565
- if exporter == "modelbuilder":
566
- # Models used with ModelBuilder do not like batch size > 1.
567
- # Let's change that.
568
- for k in ["inputs", "inputs2"]:
569
- if k not in data:
570
- continue
571
- if verbose:
572
- print(f"[validate_model] set batch=1 for data[{k!r}]")
573
- print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}")
574
- cpl = CoupleInputsDynamicShapes(
575
- tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
576
- )
577
- if patch_kwargs.get("patch", False):
578
- with torch_export_patches(**patch_kwargs): # type: ignore[arg-type]
579
- data[k] = cpl.change_dynamic_dimensions(
580
- desired_values=dict(batch=1), only_desired=True
581
- )
582
- else:
583
- data[k] = cpl.change_dynamic_dimensions(
584
- desired_values=dict(batch=1), only_desired=True
585
- )
586
- if verbose:
587
- print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
504
+ return data, iop, mop
505
+
588
506
 
507
+ def _update_data_for_modelbuilder(data, verbose):
508
+ # Models used with ModelBuilder do not like batch size > 1.
509
+ # Let's change that.
510
+ for k in ["inputs", "inputs2"]:
511
+ if k not in data:
512
+ continue
513
+ if verbose:
514
+ print(f"[validate_model] set batch=1 for data[{k!r}]")
515
+ print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}")
516
+ cpl = CoupleInputsDynamicShapes(
517
+ tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
518
+ )
519
+ with register_additional_serialization_functions(patch_transformers=True): # type: ignore[arg-type]
520
+ data[k] = cpl.change_dynamic_dimensions(
521
+ desired_values=dict(batch=1), only_desired=True
522
+ )
523
+ if verbose:
524
+ print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
525
+
526
+
527
+ def _update_inputs_outputs(
528
+ data,
529
+ summary,
530
+ exporter,
531
+ iop,
532
+ mop,
533
+ dump_folder,
534
+ opset,
535
+ device,
536
+ dtype,
537
+ rewrite,
538
+ drop_inputs,
539
+ verbose,
540
+ second_input_keys,
541
+ model_id,
542
+ ):
589
543
  # modelbuilder needs different treatments sometimes, so
590
544
  # we mark it for later usage.
591
545
  # for example, it has different past_kv ordering than
@@ -672,7 +626,7 @@ def validate_model(
672
626
  for k in ["task", "size", "n_weights"]:
673
627
  summary[f"model_{k.replace('_','')}"] = data[k]
674
628
  summary["second_input_keys"] = ",".join(second_input_keys)
675
- summary["model_inputs_options"] = str(input_options or "")
629
+ summary["model_inputs_options"] = str(iop or "")
676
630
  summary["model_inputs"] = string_type(data["inputs"], with_shape=True)
677
631
  summary["model_shapes"] = string_type(data["dynamic_shapes"])
678
632
  summary["model_class"] = data["model"].__class__.__name__
@@ -689,6 +643,8 @@ def validate_model(
689
643
  ).replace(" ", "")
690
644
  summary["model_id"] = model_id
691
645
 
646
+
647
+ def _verbose_validate(data, second_input_keys, verbose):
692
648
  if verbose:
693
649
  print("[validate_model] --")
694
650
  print(f"[validate_model] task={data['task']}")
@@ -701,33 +657,30 @@ def validate_model(
701
657
  print(f"[validate_model] second_input_keys={second_input_keys}")
702
658
  print("[validate_model] --")
703
659
 
704
- if do_run:
705
- validation_begin = time.perf_counter()
706
-
707
- _validate_do_run_model(
708
- data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet
709
- )
710
- if second_input_keys:
711
- for k in second_input_keys:
712
- _validate_do_run_model(
713
- data,
714
- summary,
715
- k,
716
- f"run2{k[6:]}",
717
- f"run_expected2{k[6:]}",
718
- verbose,
719
- 1,
720
- 0,
721
- quiet,
722
- )
723
-
724
- summary["time_total_validation_torch"] = time.perf_counter() - validation_begin
725
660
 
661
+ def _call_exporter(
662
+ data,
663
+ summary,
664
+ exporter,
665
+ patch_kwargs,
666
+ stop_if_static,
667
+ verbose,
668
+ dump_folder,
669
+ quiet,
670
+ optimization,
671
+ do_run,
672
+ output_names,
673
+ exporter_options,
674
+ ):
726
675
  if exporter:
727
- print(
728
- f"[validate_model] -- export the model with {exporter!r}, "
729
- f"optimization={optimization!r}"
730
- )
676
+ expop = exporter_options or {}
677
+ if verbose:
678
+ print(
679
+ f"[validate_model] -- export the model with {exporter!r}, "
680
+ f"optimization={optimization!r}"
681
+ )
682
+ if expop:
683
+ print(f"[validate_model] -- exporter options {expop}")
731
684
  exporter_begin = time.perf_counter()
732
685
  if patch_kwargs:
733
686
  if verbose:
@@ -757,6 +710,7 @@ def validate_model(
757
710
  do_run=do_run,
758
711
  dump_folder=dump_folder,
759
712
  output_names=output_names,
713
+ exporter_options=expop,
760
714
  )
761
715
  else:
762
716
  data["inputs_export"] = data["inputs"]
@@ -770,11 +724,14 @@ def validate_model(
770
724
  do_run=do_run,
771
725
  dump_folder=dump_folder,
772
726
  output_names=output_names,
727
+ exporter_options=expop,
773
728
  )
774
729
 
775
730
  summary.update(summary_export)
776
731
  summary["time_total_exporter"] = time.perf_counter() - exporter_begin
777
732
 
733
+
734
+ def _dump_onnx_model(data, summary, dump_folder, verbose, exporter, folder_name):
778
735
  dump_stats = None
779
736
  if dump_folder:
780
737
  if "exported_program" in data:
@@ -839,26 +796,392 @@ def validate_model(
839
796
  ):
840
797
  if verbose:
841
798
  print("[validate_model] -- done (final)")
842
- if dump_stats:
843
- with open(dump_stats, "w") as f:
844
- for k, v in sorted(summary.items()):
845
- f.write(f":{k}:{v};\n")
799
+ return False, dump_stats
800
+ return True, dump_stats
801
+
802
+
803
+ def validate_model(
804
+ model_id: str,
805
+ task: Optional[str] = None,
806
+ do_run: bool = False,
807
+ exporter: Optional[str] = None,
808
+ do_same: bool = False,
809
+ verbose: int = 0,
810
+ dtype: Optional[Union[str, torch.dtype]] = None,
811
+ device: Optional[Union[str, torch.device]] = None,
812
+ same_as_pretrained: bool = False,
813
+ use_pretrained: bool = False,
814
+ optimization: Optional[str] = None,
815
+ quiet: bool = False,
816
+ patch: Union[bool, str, Dict[str, bool]] = False,
817
+ rewrite: bool = False,
818
+ stop_if_static: int = 1,
819
+ dump_folder: Optional[str] = None,
820
+ drop_inputs: Optional[List[str]] = None,
821
+ ortfusiontype: Optional[str] = None,
822
+ input_options: Optional[Dict[str, Any]] = None,
823
+ model_options: Optional[Dict[str, Any]] = None,
824
+ exporter_options: Optional[Dict[str, Any]] = None,
825
+ subfolder: Optional[str] = None,
826
+ opset: Optional[int] = None,
827
+ runtime: str = "onnxruntime",
828
+ repeat: int = 1,
829
+ warmup: int = 0,
830
+ inputs2: int = 1,
831
+ output_names: Optional[List[str]] = None,
832
+ ort_logs: bool = False,
833
+ quiet_input_sets: Optional[Set[str]] = None,
834
+ ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
835
+ """
836
+ Validates a model.
837
+ The function can also be called through the command line
838
+ :ref:`l-cmd-validate`.
839
+
840
+ :param model_id: model id to validate
841
+ :param task: task used to generate the necessary inputs,
842
+ can be left empty to use the default task for this model
843
+ if it can be determined
844
+ :param do_run: checks the model works with the defined inputs
845
+ :param exporter: exports the model using this exporter,
846
+ available list: ``export-strict``, ``export-nostrict``, ...
847
+ see below
848
+ :param do_same: checks the discrepancies of the exported model
849
+ :param verbose: verbosity level
850
+ :param dtype: uses this dtype to check the model
851
+ :param device: do the verification on this device
852
+ :param same_as_pretrained: use a model equivalent to the trained,
853
+ this is not always possible
854
+ :param use_pretrained: use the trained model, not the untrained one
855
+ :param optimization: optimization to apply to the exported model,
856
+ depends on the exporter
857
+ :param quiet: if quiet, catches exception if any issue
858
+ :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
859
+ if True before exporting
860
+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
861
+ a string can be used to specify only one of them
862
+ :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
863
+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
864
+ :param stop_if_static: stops if a dynamic dimension becomes static,
865
+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
866
+ :param dump_folder: dumps everything in a subfolder of this one
867
+ :param drop_inputs: drops this list of inputs (given their names)
868
+ :param ortfusiontype: runs ort fusion, the parameter defines the fusion type,
869
+ it accepts multiple values separated by ``|``,
870
+ see :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`
871
+ :param input_options: additional options to define the dummy inputs
872
+ used to export
873
+ :param model_options: additional options when creating the model such as
874
+ ``num_hidden_layers`` or ``attn_implementation``
875
+ :param exporter_options: additional options when exporting the model such as
876
+ ``report=True`` or ``verify=True``
877
+ :param subfolder: version or subfolder to use when retrieving a model id
878
+ :param opset: onnx opset to use for the conversion
879
+ :param runtime: onnx runtime to use to check about discrepancies,
880
+ possible values ``onnxruntime``, ``torch``, ``orteval``,
881
+ ``orteval10``, ``ref`` only if `do_run` is true
882
+ :param repeat: number of times to measure the model
883
+ :param warmup: warmup the model first
884
+ :param inputs2: checks that other sets of inputs are running as well,
885
+ this ensures that the model does support dynamism, the value is used
886
+ as an increment to the first set of values (added to dimensions),
887
+ or an empty cache for example
888
+ :param output_names: output names the onnx exporter should use
889
+ :param ort_logs: increases onnxruntime verbosity when creating the session
890
+ :param quiet_input_sets: avoid raising an exception if the inputs belong to that set
891
+ even if quiet is False
892
+ :return: two dictionaries, one with some metrics,
893
+ another one with whatever the function produces
894
+
895
+ The following environment variables can be used to print out some
896
+ information:
897
+
898
+ * ``PRINT_CONFIG``: prints the model configuration
899
+
900
+ The following exporters are available:
901
+
902
+ * ``export-nostrict``: run :func:`torch.export.export` (..., strict=False)
903
+ * ``onnx-dynamo``: run :func:`torch.onnx.export` (...),
904
+ models can be optimized with ``optimization`` in ``("ir", "os_ort")``
905
+ * ``modelbuilder``: use :epkg:`ModelBuilder` to build the onnx model
906
+ * ``custom``: custom exporter (see :epkg:`experimental-experiment`),
907
+ models can be optimized with ``optimization`` in
908
+ ``("default", "default+onnxruntime", "default+os_ort", "default+onnxruntime+os_ort")``
909
+
910
+ The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
911
+ exported model returns the same outputs as the original one, otherwise,
912
+ :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
913
+ if ``runtime == 'torch'`` or
914
+ :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
915
+ if ``runtime == 'orteval'`` or
916
+ :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
917
+ if ``runtime == 'ref'``,
918
+ ``orteval10`` increases the verbosity.
919
+
920
+ .. versionchanged:: 0.7.13
921
+ *inputs2* not only means a second set of inputs but many
922
+ such as ``input_empty_cache``
923
+ which refers to a set of inputs using an empty cache.
924
+ """
925
+ main_validation_begin = time.perf_counter()
926
+ cont, summary, data, dump_stats, second_input_keys = _validate_model_step1(
927
+ model_id=model_id,
928
+ do_same=do_same,
929
+ do_run=do_run,
930
+ patch=patch,
931
+ rewrite=rewrite,
932
+ dtype=dtype,
933
+ device=device,
934
+ optimization=optimization,
935
+ quiet=quiet,
936
+ drop_inputs=drop_inputs,
937
+ ortfusiontype=ortfusiontype,
938
+ stop_if_static=stop_if_static,
939
+ exporter=exporter,
940
+ verbose=verbose,
941
+ task=task,
942
+ runtime=runtime,
943
+ inputs2=inputs2,
944
+ input_options=input_options,
945
+ model_options=model_options,
946
+ exporter_options=exporter_options,
947
+ opset=opset,
948
+ output_names=output_names,
949
+ repeat=repeat,
950
+ warmup=warmup,
951
+ dump_folder=dump_folder,
952
+ subfolder=subfolder,
953
+ use_pretrained=use_pretrained,
954
+ same_as_pretrained=same_as_pretrained,
955
+ )
956
+ if dump_folder:
957
+ with open(dump_stats, "w") as f:
958
+ for k, v in sorted(summary.items()):
959
+ f.write(f":{k}:{v};\n")
960
+ if not cont:
846
961
  return summary, data
962
+ data, summary = _clean_data_remove_model_and_proto(data, summary)
963
+ _validate_model_step2(
964
+ summary=summary,
965
+ data=data,
966
+ do_run=do_run,
967
+ quiet=quiet,
968
+ verbose=verbose,
969
+ runtime=runtime,
970
+ repeat=repeat,
971
+ warmup=warmup,
972
+ second_input_keys=second_input_keys,
973
+ ort_logs=ort_logs,
974
+ quiet_input_sets=quiet_input_sets,
975
+ ortfusiontype=ortfusiontype,
976
+ model_id=model_id,
977
+ )
978
+
979
+ summary["time_total"] = time.perf_counter() - main_validation_begin
980
+
981
+ if verbose:
982
+ print("[validate_model] -- done (final)")
983
+ with open(dump_stats, "w") as f:
984
+ for k, v in sorted(summary.items()):
985
+ f.write(f":{k}:{v};\n")
986
+ return summary, data
987
+
988
+
989
+ def _clean_data_remove_model_and_proto(data, summary):
990
+ assert isinstance(data, dict) and isinstance(data, dict)
991
+ data = _clean_data_remove_model_and_proto_(data)
992
+ summary = _clean_data_remove_model_and_proto_(summary)
993
+ gc.collect()
994
+ return data, summary
995
+
996
+
997
+ def _clean_data_remove_model_and_proto_(obj):
998
+ if type(obj) is dict:
999
+ # do not use isinstance otherwise CausalLMOutputWithPast becomes a dictionary
1000
+ return {k: _clean_data_remove_model_and_proto_(v) for k, v in obj.items()}
1001
+ if isinstance(obj, list):
1002
+ return [_clean_data_remove_model_and_proto_(v) for v in obj]
1003
+ if isinstance(obj, tuple):
1004
+ return tuple(_clean_data_remove_model_and_proto_(v) for v in obj)
1005
+ if isinstance(obj, set):
1006
+ return {_clean_data_remove_model_and_proto_(v) for v in obj}
1007
+ if isinstance(obj, (torch.nn.Module, onnx.ModelProto)):
1008
+ return None
1009
+ return obj
1010
+
1011
+
1012
+ def _validate_model_step1(
1013
+ model_id,
1014
+ do_same,
1015
+ do_run,
1016
+ patch,
1017
+ rewrite,
1018
+ dtype,
1019
+ device,
1020
+ optimization,
1021
+ quiet,
1022
+ drop_inputs,
1023
+ ortfusiontype,
1024
+ stop_if_static,
1025
+ exporter,
1026
+ verbose,
1027
+ task,
1028
+ runtime,
1029
+ inputs2,
1030
+ input_options,
1031
+ model_options,
1032
+ exporter_options,
1033
+ opset,
1034
+ output_names,
1035
+ repeat,
1036
+ warmup,
1037
+ dump_folder,
1038
+ subfolder,
1039
+ use_pretrained,
1040
+ same_as_pretrained,
1041
+ ):
1042
+ assert not do_same or do_run, (
1043
+ f"Discrepancies cannot be measured if the model is not run, "
1044
+ f"do_run={do_run}, do_same={do_same}"
1045
+ )
1046
+ (
1047
+ summary,
1048
+ model_id,
1049
+ subfolder,
1050
+ same_as_pretrained,
1051
+ use_pretrained,
1052
+ dump_folder,
1053
+ folder_name,
1054
+ patch_kwargs,
1055
+ ) = _prepare_validation(
1056
+ model_id=model_id,
1057
+ subfolder=subfolder,
1058
+ same_as_pretrained=same_as_pretrained,
1059
+ use_pretrained=use_pretrained,
1060
+ patch=patch,
1061
+ rewrite=rewrite,
1062
+ do_run=do_run,
1063
+ dtype=dtype,
1064
+ device=device,
1065
+ optimization=optimization,
1066
+ quiet=quiet,
1067
+ drop_inputs=drop_inputs,
1068
+ ortfusiontype=ortfusiontype,
1069
+ stop_if_static=stop_if_static,
1070
+ exporter=exporter,
1071
+ runtime=runtime,
1072
+ inputs2=inputs2,
1073
+ input_options=input_options,
1074
+ model_options=model_options,
1075
+ exporter_options=exporter_options,
1076
+ opset=opset,
1077
+ task=task,
1078
+ verbose=verbose,
1079
+ output_names=output_names,
1080
+ dump_folder=dump_folder,
1081
+ )
1082
+
1083
+ data, iop, mop = _get_untrained_model_with_inputs(
1084
+ summary=summary,
1085
+ model_id=model_id,
1086
+ verbose=verbose,
1087
+ task=task,
1088
+ use_pretrained=use_pretrained,
1089
+ same_as_pretrained=same_as_pretrained,
1090
+ input_options=input_options,
1091
+ model_options=model_options,
1092
+ subfolder=subfolder,
1093
+ inputs2=inputs2,
1094
+ quiet=quiet,
1095
+ dump_folder=dump_folder,
1096
+ )
1097
+
1098
+ second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]
1099
+ if exporter == "modelbuilder":
1100
+ _update_data_for_modelbuilder(data, verbose)
1101
+
1102
+ _update_inputs_outputs(
1103
+ data=data,
1104
+ summary=summary,
1105
+ exporter=exporter,
1106
+ iop=iop,
1107
+ mop=mop,
1108
+ dump_folder=dump_folder,
1109
+ opset=opset,
1110
+ device=device,
1111
+ dtype=dtype,
1112
+ rewrite=rewrite,
1113
+ drop_inputs=drop_inputs,
1114
+ verbose=verbose,
1115
+ second_input_keys=second_input_keys,
1116
+ model_id=model_id,
1117
+ )
1118
+
1119
+ _verbose_validate(data, second_input_keys, verbose)
847
1120
 
848
1121
  if do_run:
849
- # Let's move the model to CPU to make sure it frees GPU memory.
850
- if verbose:
851
- # It does not really work for the time being and the model
852
- # gets loaded twice, one by torch, one by onnxruntime
853
- print("[validation_model] -- delete the model")
854
- for key in ["model", "onnx_program", "config"]:
855
- if key in data:
856
- del data[key]
857
- if device is not None and "cuda" in str(device).lower():
858
- torch.cuda.empty_cache()
859
- gc.collect()
860
- print("[validation_model] -- done")
1122
+ validation_begin = time.perf_counter()
861
1123
 
1124
+ _validate_do_run_model(
1125
+ data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet
1126
+ )
1127
+ if second_input_keys:
1128
+ for k in second_input_keys:
1129
+ _validate_do_run_model(
1130
+ data,
1131
+ summary,
1132
+ k,
1133
+ f"run2{k[6:]}",
1134
+ f"run_expected2{k[6:]}",
1135
+ verbose,
1136
+ 1,
1137
+ 0,
1138
+ quiet,
1139
+ )
1140
+
1141
+ summary["time_total_validation_torch"] = time.perf_counter() - validation_begin
1142
+
1143
+ _call_exporter(
1144
+ data=data,
1145
+ summary=summary,
1146
+ exporter=exporter,
1147
+ patch_kwargs=patch_kwargs,
1148
+ stop_if_static=stop_if_static,
1149
+ verbose=verbose,
1150
+ dump_folder=dump_folder,
1151
+ quiet=quiet,
1152
+ optimization=optimization,
1153
+ do_run=do_run,
1154
+ output_names=output_names,
1155
+ exporter_options=exporter_options,
1156
+ )
1157
+
1158
+ cont, dump_stats = _dump_onnx_model(
1159
+ data=data,
1160
+ summary=summary,
1161
+ dump_folder=dump_folder,
1162
+ verbose=verbose,
1163
+ exporter=exporter,
1164
+ folder_name=folder_name,
1165
+ )
1166
+ return cont, summary, data, dump_stats, second_input_keys
1167
+
1168
+
1169
+ def _validate_model_step2(
1170
+ summary,
1171
+ data,
1172
+ do_run,
1173
+ quiet,
1174
+ verbose,
1175
+ runtime,
1176
+ repeat,
1177
+ warmup,
1178
+ second_input_keys,
1179
+ ort_logs,
1180
+ quiet_input_sets,
1181
+ ortfusiontype,
1182
+ model_id,
1183
+ ):
1184
+ if do_run:
862
1185
  validation_begin = time.perf_counter()
863
1186
  summary_valid, data = validate_onnx_model(
864
1187
  data=data,
@@ -937,16 +1260,6 @@ def validate_model(
937
1260
  summary.update(summary_valid)
938
1261
 
939
1262
  _compute_final_statistics(summary)
940
- summary["time_total"] = time.perf_counter() - main_validation_begin
941
-
942
- if verbose:
943
- print("[validate_model] -- done (final)")
944
- if dump_stats:
945
- # Dumps again the statistics.
946
- with open(dump_stats, "w") as f:
947
- for k, v in sorted(summary.items()):
948
- f.write(f":{k}:{v};\n")
949
- return summary, data
950
1263
 
951
1264
 
952
1265
  def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
@@ -1030,7 +1343,7 @@ def _validate_do_run_model(
1030
1343
 
1031
1344
  summary[expected_tag] = string_type(expected, with_shape=True)
1032
1345
  if verbose:
1033
- print(f"[validate_model] done ([{tag}])")
1346
+ print(f"[validate_model] done ([{tag}]) - {string_type(expected, with_shape=True)}")
1034
1347
  data[expected_tag] = expected
1035
1348
  assert hash_inputs == string_type(data[key], with_shape=True), (
1036
1349
  f"The model execution did modified the inputs:\n"
@@ -1040,7 +1353,6 @@ def _validate_do_run_model(
1040
1353
 
1041
1354
 
1042
1355
  def _validate_do_run_exported_program(data, summary, verbose, quiet):
1043
-
1044
1356
  # We run the model a second time to check the patch did not
1045
1357
  # introduce any discrepancies
1046
1358
  if verbose:
@@ -1065,7 +1377,13 @@ def _validate_do_run_exported_program(data, summary, verbose, quiet):
1065
1377
  if "ERR_run_patched" in summary:
1066
1378
  return summary, data
1067
1379
 
1068
- disc = max_diff(data["run_expected"], expected)
1380
+ verbose_diff = int(os.environ.get("MAXDIFF", "0"))
1381
+ if verbose_diff >= 10:
1382
+ print("[_validate_do_run_exported_program] with inputs_export")
1383
+ disc = max_diff(data["run_expected"], expected, verbose=verbose_diff)
1384
+ assert not verbose_diff or (
1385
+ not np.isnan(disc["abs"]) and not np.isinf(disc["abs"])
1386
+ ), f"something went wrong disc={disc}"
1069
1387
  for k, v in disc.items():
1070
1388
  summary[f"disc_patched_{k}"] = str(v)
1071
1389
  if verbose:
@@ -1107,6 +1425,7 @@ def call_exporter(
1107
1425
  do_run: bool = False,
1108
1426
  dump_folder: Optional[str] = None,
1109
1427
  output_names: Optional[List[str]] = None,
1428
+ exporter_options: Optional[Dict[str, Any]] = None,
1110
1429
  ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
1111
1430
  """
1112
1431
  Calls an exporter on a model;
@@ -1120,6 +1439,7 @@ def call_exporter(
1120
1439
  :param do_run: runs and compute discrepancies
1121
1440
  :param dump_folder: to dump additional information
1122
1441
  :param output_names: list of output names to use with the onnx exporter
1442
+ :param exporter_options: exporter options
1123
1443
  :return: two dictionaries, one with some metrics,
1124
1444
  another one with whatever the function produces
1125
1445
  """
@@ -1135,6 +1455,7 @@ def call_exporter(
1135
1455
  verbose=verbose,
1136
1456
  optimization=optimization,
1137
1457
  do_run=do_run,
1458
+ exporter_options=exporter_options,
1138
1459
  )
1139
1460
  _restore_torch_export_export(summary)
1140
1461
  return summary, data
@@ -1147,6 +1468,7 @@ def call_exporter(
1147
1468
  verbose=verbose,
1148
1469
  optimization=optimization,
1149
1470
  output_names=output_names,
1471
+ exporter_options=exporter_options,
1150
1472
  )
1151
1473
  _restore_torch_export_export(summary)
1152
1474
  return summary, data
@@ -1160,6 +1482,7 @@ def call_exporter(
1160
1482
  optimization=optimization,
1161
1483
  dump_folder=dump_folder,
1162
1484
  output_names=output_names,
1485
+ exporter_options=exporter_options,
1163
1486
  )
1164
1487
  _restore_torch_export_export(summary)
1165
1488
  return summary, data
@@ -1172,6 +1495,7 @@ def call_exporter(
1172
1495
  verbose=verbose,
1173
1496
  optimization=optimization,
1174
1497
  output_names=output_names,
1498
+ exporter_options=exporter_options,
1175
1499
  )
1176
1500
  _restore_torch_export_export(summary)
1177
1501
  return summary, data
@@ -1191,6 +1515,7 @@ def call_torch_export_export(
1191
1515
  verbose: int = 0,
1192
1516
  optimization: Optional[str] = None,
1193
1517
  do_run: bool = False,
1518
+ exporter_options: Optional[Dict[str, Any]] = None,
1194
1519
  ):
1195
1520
  """
1196
1521
  Exports a model with :func:`torch.export.export`.
@@ -1203,9 +1528,11 @@ def call_torch_export_export(
1203
1528
  :param verbose: verbosity
1204
1529
  :param optimization: optimization to do
1205
1530
  :param do_run: runs and compute discrepancies
1531
+ :param exporter_options: additional options given to the exporter
1206
1532
  :return: two dictionaries, one with some metrics,
1207
1533
  another one with whatever the function produces
1208
1534
  """
1535
+ exporter_options = exporter_options or {}
1209
1536
  assert exporter in {
1210
1537
  "export",
1211
1538
  "export-strict",
@@ -1214,8 +1541,12 @@ def call_torch_export_export(
1214
1541
  assert not optimization, f"No optimization is implemented for exporter={exporter!r}"
1215
1542
  assert "model" in data, f"model is missing from data: {sorted(data)}"
1216
1543
  assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
1544
+ assert ("-strict" not in exporter) or ("strict" not in exporter_options), (
1545
+ f"Options strict cannot be specified in the exporter name {exporter!r} "
1546
+ f"and in the options {exporter_options}"
1547
+ )
1217
1548
  summary: Dict[str, Union[str, int, float]] = {}
1218
- strict = "-strict" in exporter
1549
+ strict = "-strict" in exporter or exporter_options.pop("strict", False)
1219
1550
  args, kwargs = split_args_kwargs(data["inputs_export"])
1220
1551
  ds = data.get("dynamic_shapes", None)
1221
1552
 
@@ -1225,6 +1556,7 @@ def call_torch_export_export(
1225
1556
  summary["export_args"] = string_type(args, with_shape=True)
1226
1557
  summary["export_kwargs"] = string_type(kwargs, with_shape=True)
1227
1558
  summary["export_dynamic_shapes"] = string_type(ds)
1559
+ summary["export_options"] = str(exporter_options)
1228
1560
 
1229
1561
  # There is an issue with DynamicShape [[],[]] becomes []
1230
1562
  dse = use_dyn_not_str(ds)
@@ -1251,7 +1583,9 @@ def call_torch_export_export(
1251
1583
  data,
1252
1584
  (
1253
1585
  lambda m=model, args=args, kws=kwargs, dse=dse, s=strict: (
1254
- torch.export.export(m, args, kwargs=kws, dynamic_shapes=dse, strict=s)
1586
+ torch.export.export(
1587
+ m, args, kwargs=kws, dynamic_shapes=dse, strict=s, **exporter_options
1588
+ )
1255
1589
  )
1256
1590
  ),
1257
1591
  )
@@ -1294,7 +1628,14 @@ def call_torch_export_export(
1294
1628
  if "ERR_export_export" in summary:
1295
1629
  return summary, data
1296
1630
 
1297
- disc = max_diff(data["run_expected"], expected)
1631
+ verbose_diff = int(os.environ.get("MAXDIFF", "0"))
1632
+ if verbose_diff >= 10:
1633
+ print("[call_torch_export_export] with inputs_export")
1634
+ disc = max_diff(data["run_expected"], expected, verbose=verbose_diff)
1635
+ assert not verbose_diff or (
1636
+ not np.isnan(disc["abs"]) and not np.isinf(disc["abs"])
1637
+ ), f"something went wrong disc={disc}"
1638
+
1298
1639
  for k, v in disc.items():
1299
1640
  summary[f"disc_exported_{k}"] = str(v)
1300
1641
  if verbose:
@@ -1465,6 +1806,9 @@ def validate_onnx_model(
1465
1806
  if verbose:
1466
1807
  print(f"[validate_onnx_model] -- keys={keys}")
1467
1808
  for k_input, k_expected, suffix in keys:
1809
+ if k_input == "inputs_prompt":
1810
+ # this must use onnx_generate
1811
+ continue
1468
1812
  # make_feeds
1469
1813
  assert k_input in data, f"Unable to find {k_input!r} in {sorted(data)}"
1470
1814
  assert k_expected in data, f"Unable to find {k_expected!r} in {sorted(data)}"
@@ -1478,7 +1822,7 @@ def validate_onnx_model(
1478
1822
  data[k_input],
1479
1823
  use_numpy=True,
1480
1824
  check_flatten=False,
1481
- is_modelbuilder=data["exporter"] == "modelbuilder",
1825
+ is_modelbuilder=data["exporter"] == "modelbuilder", # to remove position_ids
1482
1826
  )
1483
1827
  if verbose:
1484
1828
  print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
@@ -1501,13 +1845,6 @@ def validate_onnx_model(
1501
1845
  repeat=repeat,
1502
1846
  warmup=warmup,
1503
1847
  )
1504
- # NOTE: modelbuilder has different order on past_kv outputs
1505
- if data["exporter"] == "modelbuilder":
1506
- logits = got[:1]
1507
- past_key_values = got[1:]
1508
- reorder_past_key_values = reorder_modelbuilder_cache_to_torch(past_key_values)
1509
- got = logits + reorder_past_key_values
1510
-
1511
1848
  if f"ERR_{_mk(f'time_onnx_ort_run{suffix}')}" in summary:
1512
1849
  return summary, data
1513
1850
 
@@ -1518,7 +1855,16 @@ def validate_onnx_model(
1518
1855
  print(f"[validate_onnx_model] got={string_type(got, with_shape=True)}")
1519
1856
 
1520
1857
  # compute discrepancies
1521
- disc = max_diff(data[k_expected], got, flatten=True)
1858
+ verbose_diff = int(os.environ.get("MAXDIFF", "0"))
1859
+ if verbose_diff >= 10:
1860
+ print(
1861
+ f"[validate_onnx_model] k_input={k_input!r}, "
1862
+ f"k_expected={k_expected!r}, suffix={suffix!r}"
1863
+ )
1864
+ disc = max_diff(data[k_expected], got, flatten=True, verbose=verbose_diff)
1865
+ assert not verbose_diff or (
1866
+ not np.isnan(disc["abs"]) and not np.isinf(disc["abs"])
1867
+ ), f"something went wrong disc={disc}"
1522
1868
  if verbose:
1523
1869
  print(f"[validate_onnx_model] discrepancies={string_diff(disc)}")
1524
1870
  for k, v in disc.items():
@@ -1533,6 +1879,7 @@ def call_torch_export_onnx(
1533
1879
  verbose: int = 0,
1534
1880
  optimization: Optional[str] = None,
1535
1881
  output_names: Optional[List[str]] = None,
1882
+ exporter_options: Optional[Dict[str, Any]] = None,
1536
1883
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1537
1884
  """
1538
1885
  Exports a model into onnx.
@@ -1545,10 +1892,12 @@ def call_torch_export_onnx(
1545
1892
  :param verbose: verbosity
1546
1893
  :param optimization: optimization to do
1547
1894
  :param output_names: output names to use
1895
+ :param exporter_options: additional options to give the exporter
1548
1896
  :return: two dictionaries, one with some metrics,
1549
1897
  another one with whatever the function produces
1550
1898
  """
1551
1899
  available = {None, "", "ir", "os_ort", "ir+default"}
1900
+ exporter_options = exporter_options or {}
1552
1901
  assert (
1553
1902
  optimization in available
1554
1903
  ), f"unexpected value for optimization={optimization}, available={available}"
@@ -1576,6 +1925,7 @@ def call_torch_export_onnx(
1576
1925
  summary["export_dynamo"] = dynamo
1577
1926
  summary["export_args"] = string_type(args, with_shape=True)
1578
1927
  summary["export_kwargs"] = string_type(kwargs, with_shape=True)
1928
+ summary["export_exporter"] = str(exporter_options)
1579
1929
  opset = data.get("model_opset", None)
1580
1930
  if opset:
1581
1931
  summary["export_opset"] = opset
@@ -1603,6 +1953,11 @@ def call_torch_export_onnx(
1603
1953
  export_export_kwargs["output_names"] = output_names
1604
1954
  if opset:
1605
1955
  export_export_kwargs["opset_version"] = opset
1956
+ assert not (set(export_export_kwargs) & set(exporter_options)), (
1957
+ f"Some options were defined twice, "
1958
+ f"{set(export_export_kwargs) & set(exporter_options)}, "
1959
+ f"you should remove them from exporter_options={exporter_options}"
1960
+ )
1606
1961
  if verbose:
1607
1962
  print(
1608
1963
  f"[call_torch_export_onnx] export_export_kwargs="
@@ -1622,6 +1977,7 @@ def call_torch_export_onnx(
1622
1977
  args,
1623
1978
  kwargs=kws,
1624
1979
  **ekws,
1980
+ **exporter_options,
1625
1981
  )
1626
1982
  )
1627
1983
  ),
@@ -1694,6 +2050,7 @@ def call_torch_export_model_builder(
1694
2050
  verbose: int = 0,
1695
2051
  optimization: Optional[str] = None,
1696
2052
  output_names: Optional[List[str]] = None,
2053
+ exporter_options: Optional[Dict[str, Any]] = None,
1697
2054
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1698
2055
  """
1699
2056
  Exports a model into onnx with :epkg:`ModelBuilder`.
@@ -1705,11 +2062,13 @@ def call_torch_export_model_builder(
1705
2062
  :param verbose: verbosity
1706
2063
  :param optimization: optimization to do
1707
2064
  :param output_names: list of output names to use
2065
+ :param exporter_options: additional options to give the exporter
1708
2066
  :return: two dictionaries, one with some metrics,
1709
2067
  another one with whatever the function produces
1710
2068
  """
1711
2069
  from ..helpers.model_builder_helper import create_model_builder, save_model_builder
1712
2070
 
2071
+ exporter_options = exporter_options or {}
1713
2072
  assert optimization in (
1714
2073
  None,
1715
2074
  "",
@@ -1737,7 +2096,12 @@ def call_torch_export_model_builder(
1737
2096
  ], p=precision, pr=provider, cd=cache_dir: (
1738
2097
  save_model_builder(
1739
2098
  create_model_builder(
1740
- c, m, precision=p, execution_provider=pr, cache_dir=cd
2099
+ c,
2100
+ m,
2101
+ precision=p,
2102
+ execution_provider=pr,
2103
+ cache_dir=cd,
2104
+ **exporter_options,
1741
2105
  )
1742
2106
  )
1743
2107
  )
@@ -1854,6 +2218,7 @@ def call_torch_export_custom(
1854
2218
  optimization: Optional[str] = None,
1855
2219
  dump_folder: Optional[str] = None,
1856
2220
  output_names: Optional[List[str]] = None,
2221
+ exporter_options: Optional[Dict[str, Any]] = None,
1857
2222
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1858
2223
  """
1859
2224
  Exports a model into onnx.
@@ -1867,9 +2232,11 @@ def call_torch_export_custom(
1867
2232
  :param optimization: optimization to do
1868
2233
  :param dump_folder: to store additional information
1869
2234
  :param output_names: list of output names to use
2235
+ :param exporter_options: additional exporter options
1870
2236
  :return: two dictionaries, one with some metrics,
1871
2237
  another one with whatever the function produces
1872
2238
  """
2239
+ exporter_options = exporter_options or {}
1873
2240
  available = {
1874
2241
  "",
1875
2242
  "default",
@@ -1905,11 +2272,20 @@ def call_torch_export_custom(
1905
2272
  assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
1906
2273
  assert "model" in data, f"model is missing from data: {sorted(data)}"
1907
2274
  assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
2275
+ assert ("-strict" not in exporter) or ("strict" not in exporter_options), (
2276
+ f"Options strict cannot be specified in the exporter name {exporter!r} "
2277
+ f"and in the options {exporter_options}"
2278
+ )
2279
+ assert ("-fake" not in exporter) or ("fake" not in exporter_options), (
2280
+ f"Options strict cannot be specified in the exporter name {exporter!r} "
2281
+ f"and in the options {exporter_options}"
2282
+ )
1908
2283
  summary: Dict[str, Union[str, int, float]] = {}
1909
- strict = "-strict" in exporter
2284
+ strict = "-strict" in exporter or exporter_options.pop("strict", False)
1910
2285
  args, kwargs = split_args_kwargs(data["inputs_export"])
1911
2286
  ds = data.get("dynamic_shapes", None)
1912
- if "-fake" in exporter:
2287
+ fake = "-fake" in exporter or exporter_options.pop("fake", False)
2288
+ if fake:
1913
2289
  from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
1914
2290
 
1915
2291
  if verbose:
@@ -1932,8 +2308,10 @@ def call_torch_export_custom(
1932
2308
  summary["export_exporter"] = exporter
1933
2309
  summary["export_optimization"] = optimization or ""
1934
2310
  summary["export_strict"] = strict
2311
+ summary["export_fake"] = fake
1935
2312
  summary["export_args"] = string_type(args, with_shape=True)
1936
2313
  summary["export_kwargs"] = string_type(kwargs, with_shape=True)
2314
+ summary["export_options"] = str(exporter_options)
1937
2315
 
1938
2316
  from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
1939
2317
  from experimental_experiment.xbuilder import OptimizationOptions
@@ -1941,17 +2319,35 @@ def call_torch_export_custom(
1941
2319
  spl = optimization.split("+") if optimization else []
1942
2320
  os_ort = "os_ort" in spl
1943
2321
  optimization = "+".join(_ for _ in spl if _ != "os_ort")
1944
-
1945
- export_options = ExportOptions(
1946
- strict=strict,
1947
- decomposition_table=(
2322
+ inline = "-noinline" not in exporter or exporter_options.pop("inline", True)
2323
+ decomposition_table = (
2324
+ exporter_options.pop("decomposition_table")
2325
+ if "decomposition_table" in exporter_options
2326
+ else (
1948
2327
  "default"
1949
2328
  if ("-default" in exporter or "-dec" in exporter)
1950
2329
  else ("all" if ("-all" in exporter or "-decall" in exporter) else None)
1951
- ),
2330
+ )
2331
+ )
2332
+ large_model = bool(exporter_options.pop("large_model", True))
2333
+ return_optimize_report = bool(exporter_options.pop("return_optimize_report", True))
2334
+ export_modules_as_functions = bool(
2335
+ exporter_options.pop("export_modules_as_functions", False)
2336
+ )
2337
+ external_threshold = int(exporter_options.pop("external_threshold", 1024))
2338
+ summary["export_decomposition_table"] = str(decomposition_table)
2339
+ summary["export_inline"] = str(inline)
2340
+ summary["export_large_model"] = str(large_model)
2341
+ summary["export_return_optimize_report"] = str(return_optimize_report)
2342
+ summary["export_export_modules_as_functions"] = str(export_modules_as_functions)
2343
+ summary["export_external_threshold"] = str(external_threshold)
2344
+
2345
+ export_options = ExportOptions(
2346
+ strict=strict,
2347
+ decomposition_table=decomposition_table,
1952
2348
  save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
2349
+ **exporter_options,
1953
2350
  )
1954
- inline = "-noinline" not in exporter
1955
2351
  options = OptimizationOptions(patterns=optimization) if optimization else None
1956
2352
  model = data["model"]
1957
2353
  kws = dict(
@@ -1959,10 +2355,12 @@ def call_torch_export_custom(
1959
2355
  export_options=export_options,
1960
2356
  options=options,
1961
2357
  optimize=bool(optimization),
1962
- large_model=True,
1963
- return_optimize_report=True,
1964
2358
  verbose=max(verbose - 2, 0),
1965
2359
  inline=inline,
2360
+ large_model=large_model,
2361
+ return_optimize_report=return_optimize_report,
2362
+ export_modules_as_functions=export_modules_as_functions,
2363
+ external_threshold=external_threshold,
1966
2364
  )
1967
2365
  if opset:
1968
2366
  kws["target_opset"] = opset