onnx-diagnostic 0.8.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,11 @@
1
- import gc
2
1
  import datetime
2
+ import gc
3
3
  import inspect
4
4
  import os
5
5
  import pprint
6
6
  import sys
7
- from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
8
7
  import time
8
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
9
9
  import numpy as np
10
10
  import onnx
11
11
  import torch
@@ -273,8 +273,8 @@ def _quiet_or_not_quiet(
273
273
  summary[f"time_{suffix}_latency_std"] = a.std()
274
274
  summary[f"time_{suffix}_latency_min"] = a.min()
275
275
  summary[f"time_{suffix}_latency_max"] = a.max()
276
- summary[f"time_{suffix}_latency_098"] = a[-i2]
277
- summary[f"time_{suffix}_latency_095"] = a[-i5]
276
+ summary[f"time_{suffix}_latency_098"] = a[-(max(i2, 1))]
277
+ summary[f"time_{suffix}_latency_095"] = a[-max(i5, 1)]
278
278
  summary[f"time_{suffix}_latency_005"] = a[i5]
279
279
  summary[f"time_{suffix}_latency_002"] = a[i2]
280
280
  summary[f"time_{suffix}_n"] = len(a)
@@ -323,125 +323,33 @@ def make_patch_kwargs(
323
323
  return patch_kwargs
324
324
 
325
325
 
326
- def validate_model(
327
- model_id: str,
328
- task: Optional[str] = None,
329
- do_run: bool = False,
330
- exporter: Optional[str] = None,
331
- do_same: bool = False,
332
- verbose: int = 0,
333
- dtype: Optional[Union[str, torch.dtype]] = None,
334
- device: Optional[Union[str, torch.device]] = None,
335
- same_as_pretrained: bool = False,
336
- use_pretrained: bool = False,
337
- optimization: Optional[str] = None,
338
- quiet: bool = False,
339
- patch: Union[bool, str, Dict[str, bool]] = False,
340
- rewrite: bool = False,
341
- stop_if_static: int = 1,
342
- dump_folder: Optional[str] = None,
343
- drop_inputs: Optional[List[str]] = None,
344
- ortfusiontype: Optional[str] = None,
345
- input_options: Optional[Dict[str, Any]] = None,
346
- model_options: Optional[Dict[str, Any]] = None,
347
- subfolder: Optional[str] = None,
348
- opset: Optional[int] = None,
349
- runtime: str = "onnxruntime",
350
- repeat: int = 1,
351
- warmup: int = 0,
352
- inputs2: int = 1,
353
- output_names: Optional[List[str]] = None,
354
- ort_logs: bool = False,
355
- quiet_input_sets: Optional[Set[str]] = None,
356
- ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
357
- """
358
- Validates a model.
359
- The function can also be called through the command line
360
- :ref:`l-cmd-validate`.
361
-
362
- :param model_id: model id to validate
363
- :param task: task used to generate the necessary inputs,
364
- can be left empty to use the default task for this model
365
- if it can be determined
366
- :param do_run: checks the model works with the defined inputs
367
- :param exporter: exporter the model using this exporter,
368
- available list: ``export-strict``, ``export-nostrict``, ...
369
- see below
370
- :param do_same: checks the discrepancies of the exported model
371
- :param verbose: verbosity level
372
- :param dtype: uses this dtype to check the model
373
- :param device: do the verification on this device
374
- :param same_as_pretrained: use a model equivalent to the trained,
375
- this is not always possible
376
- :param use_pretrained: use the trained model, not the untrained one
377
- :param optimization: optimization to apply to the exported model,
378
- depend on the the exporter
379
- :param quiet: if quiet, catches exception if any issue
380
- :param patch: applies patches (``patch_transformers=True, path_diffusers=True``)
381
- if True before exporting
382
- see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
383
- a string can be used to specify only one of them
384
- :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
385
- see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
386
- :param stop_if_static: stops if a dynamic dimension becomes static,
387
- see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
388
- :param dump_folder: dumps everything in a subfolder of this one
389
- :param drop_inputs: drops this list of inputs (given their names)
390
- :param ortfusiontype: runs ort fusion, the parameters defines the fusion type,
391
- it accepts multiple values separated by ``|``,
392
- see :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`
393
- :param input_options: additional options to define the dummy inputs
394
- used to export
395
- :param model_options: additional options when creating the model such as
396
- ``num_hidden_layers`` or ``attn_implementation``
397
- :param subfolder: version or subfolders to uses when retrieving a model id
398
- :param opset: onnx opset to use for the conversion
399
- :param runtime: onnx runtime to use to check about discrepancies,
400
- possible values ``onnxruntime``, ``torch``, ``orteval``,
401
- ``orteval10``, ``ref`` only if `do_run` is true
402
- :param repeat: number of time to measure the model
403
- :param warmup: warmup the model first
404
- :param inputs2: checks that other sets of inputs are running as well,
405
- this ensures that the model does support dynamism, the value is used
406
- as an increment to the first set of values (added to dimensions),
407
- or an empty cache for example
408
- :param output_names: output names the onnx exporter should use
409
- :param ort_logs: increases onnxruntime verbosity when creating the session
410
- :param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
411
- even if quiet is False
412
- :return: two dictionaries, one with some metrics,
413
- another one with whatever the function produces
414
-
415
- The following environment variables can be used to print out some
416
- information:
417
-
418
- * ``PRINT_CONFIG``: prints the model configuration
419
-
420
- The following exporters are available:
421
-
422
- * ``export-nostrict``: run :func:`torch.export.export` (..., strict=False)
423
- * ``onnx-dynamo``: run :func:`torch.onnx.export` (...),
424
- models can be optimized with ``optimization`` in ``("ir", "os_ort")``
425
- * ``modelbuilder``: use :epkg:`ModelBuilder` to builds the onnx model
426
- * ``custom``: custom exporter (see :epkg:`experimental-experiment`),
427
- models can be optimized with ``optimization`` in
428
- ``("default", "default+onnxruntime", "default+os_ort", "default+onnxruntime+os_ort")``
429
-
430
- The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
431
- exported model returns the same outputs as the original one, otherwise,
432
- :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
433
- if ``runtime == 'torch'`` or
434
- :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
435
- if ``runtime == 'orteval'`` or
436
- :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
437
- if ``runtime == 'ref'``,
438
- ``orteval10`` increases the verbosity.
439
-
440
- .. versionchanged:: 0.7.13
441
- *inputs2* not only means a second set of inputs but many
442
- such as ``input_empty_cache``
443
- which refers to a set of inputs using an empty cache.
444
- """
326
+ def _prepare_validation(
327
+ model_id,
328
+ subfolder,
329
+ same_as_pretrained,
330
+ use_pretrained,
331
+ patch,
332
+ rewrite,
333
+ do_run,
334
+ dtype,
335
+ device,
336
+ optimization,
337
+ quiet,
338
+ drop_inputs,
339
+ ortfusiontype,
340
+ stop_if_static,
341
+ exporter,
342
+ runtime,
343
+ inputs2,
344
+ input_options,
345
+ model_options,
346
+ exporter_options,
347
+ opset,
348
+ task,
349
+ verbose,
350
+ output_names,
351
+ dump_folder,
352
+ ):
445
353
  main_validation_begin = time.perf_counter()
446
354
  model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
447
355
  model_id,
@@ -473,6 +381,10 @@ def validate_model(
473
381
  version_exporter=exporter or "",
474
382
  version_runtime=runtime,
475
383
  version_inputs2=inputs2,
384
+ version_input_options=str(input_options),
385
+ version_drop_input=str(drop_inputs),
386
+ version_model_options=str(model_options),
387
+ version_exporter_options=str(exporter_options),
476
388
  time_preprocess_model_id=time_preprocess_model_id,
477
389
  )
478
390
  )
@@ -523,6 +435,32 @@ def validate_model(
523
435
  summary["model_id"] = model_id
524
436
  summary["model_subfolder"] = subfolder or ""
525
437
 
438
+ return (
439
+ summary,
440
+ model_id,
441
+ subfolder,
442
+ same_as_pretrained,
443
+ use_pretrained,
444
+ dump_folder,
445
+ folder_name,
446
+ patch_kwargs,
447
+ )
448
+
449
+
450
+ def _get_untrained_model_with_inputs(
451
+ summary,
452
+ model_id,
453
+ verbose,
454
+ task,
455
+ use_pretrained,
456
+ same_as_pretrained,
457
+ input_options,
458
+ model_options,
459
+ subfolder,
460
+ inputs2,
461
+ quiet,
462
+ dump_folder,
463
+ ):
526
464
  iop = input_options or {}
527
465
  mop = model_options or {}
528
466
  data = _quiet_or_not_quiet(
@@ -547,8 +485,6 @@ def validate_model(
547
485
  ),
548
486
  )
549
487
 
550
- second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]
551
-
552
488
  if dump_folder:
553
489
  with open(os.path.join(dump_folder, "model_config.txt"), "w") as f:
554
490
  f.write(f"model_id: {model_id}\n------\n")
@@ -565,25 +501,45 @@ def validate_model(
565
501
  f.write(f"model_id: {model_id}\n------\n")
566
502
  f.write(pprint.pformat(dump_info))
567
503
 
568
- if exporter == "modelbuilder":
569
- # Models used with ModelBuilder do not like batch size > 1.
570
- # Let's change that.
571
- for k in ["inputs", "inputs2"]:
572
- if k not in data:
573
- continue
574
- if verbose:
575
- print(f"[validate_model] set batch=1 for data[{k!r}]")
576
- print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}")
577
- cpl = CoupleInputsDynamicShapes(
578
- tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
579
- )
580
- with register_additional_serialization_functions(patch_transformers=True): # type: ignore[arg-type]
581
- data[k] = cpl.change_dynamic_dimensions(
582
- desired_values=dict(batch=1), only_desired=True
583
- )
584
- if verbose:
585
- print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
504
+ return data, iop, mop
505
+
586
506
 
507
+ def _update_data_for_modelbuilder(data, verbose):
508
+ # Models used with ModelBuilder do not like batch size > 1.
509
+ # Let's change that.
510
+ for k in ["inputs", "inputs2"]:
511
+ if k not in data:
512
+ continue
513
+ if verbose:
514
+ print(f"[validate_model] set batch=1 for data[{k!r}]")
515
+ print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}")
516
+ cpl = CoupleInputsDynamicShapes(
517
+ tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
518
+ )
519
+ with register_additional_serialization_functions(patch_transformers=True): # type: ignore[arg-type]
520
+ data[k] = cpl.change_dynamic_dimensions(
521
+ desired_values=dict(batch=1), only_desired=True
522
+ )
523
+ if verbose:
524
+ print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
525
+
526
+
527
+ def _update_inputs_outputs(
528
+ data,
529
+ summary,
530
+ exporter,
531
+ iop,
532
+ mop,
533
+ dump_folder,
534
+ opset,
535
+ device,
536
+ dtype,
537
+ rewrite,
538
+ drop_inputs,
539
+ verbose,
540
+ second_input_keys,
541
+ model_id,
542
+ ):
587
543
  # modelbuilder needs different treatments sometimes, so
588
544
  # we mark it for later usage.
589
545
  # for example, it has different past_kv ordering than
@@ -670,7 +626,7 @@ def validate_model(
670
626
  for k in ["task", "size", "n_weights"]:
671
627
  summary[f"model_{k.replace('_','')}"] = data[k]
672
628
  summary["second_input_keys"] = ",".join(second_input_keys)
673
- summary["model_inputs_options"] = str(input_options or "")
629
+ summary["model_inputs_options"] = str(iop or "")
674
630
  summary["model_inputs"] = string_type(data["inputs"], with_shape=True)
675
631
  summary["model_shapes"] = string_type(data["dynamic_shapes"])
676
632
  summary["model_class"] = data["model"].__class__.__name__
@@ -687,6 +643,8 @@ def validate_model(
687
643
  ).replace(" ", "")
688
644
  summary["model_id"] = model_id
689
645
 
646
+
647
+ def _verbose_validate(data, second_input_keys, verbose):
690
648
  if verbose:
691
649
  print("[validate_model] --")
692
650
  print(f"[validate_model] task={data['task']}")
@@ -699,33 +657,30 @@ def validate_model(
699
657
  print(f"[validate_model] second_input_keys={second_input_keys}")
700
658
  print("[validate_model] --")
701
659
 
702
- if do_run:
703
- validation_begin = time.perf_counter()
704
-
705
- _validate_do_run_model(
706
- data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet
707
- )
708
- if second_input_keys:
709
- for k in second_input_keys:
710
- _validate_do_run_model(
711
- data,
712
- summary,
713
- k,
714
- f"run2{k[6:]}",
715
- f"run_expected2{k[6:]}",
716
- verbose,
717
- 1,
718
- 0,
719
- quiet,
720
- )
721
-
722
- summary["time_total_validation_torch"] = time.perf_counter() - validation_begin
723
660
 
661
+ def _call_exporter(
662
+ data,
663
+ summary,
664
+ exporter,
665
+ patch_kwargs,
666
+ stop_if_static,
667
+ verbose,
668
+ dump_folder,
669
+ quiet,
670
+ optimization,
671
+ do_run,
672
+ output_names,
673
+ exporter_options,
674
+ ):
724
675
  if exporter:
725
- print(
726
- f"[validate_model] -- export the model with {exporter!r}, "
727
- f"optimization={optimization!r}"
728
- )
676
+ expop = exporter_options or {}
677
+ if verbose:
678
+ print(
679
+ f"[validate_model] -- export the model with {exporter!r}, "
680
+ f"optimization={optimization!r}"
681
+ )
682
+ if expop:
683
+ print(f"[validate_model] -- exporter options {expop}")
729
684
  exporter_begin = time.perf_counter()
730
685
  if patch_kwargs:
731
686
  if verbose:
@@ -755,6 +710,7 @@ def validate_model(
755
710
  do_run=do_run,
756
711
  dump_folder=dump_folder,
757
712
  output_names=output_names,
713
+ exporter_options=expop,
758
714
  )
759
715
  else:
760
716
  data["inputs_export"] = data["inputs"]
@@ -768,11 +724,14 @@ def validate_model(
768
724
  do_run=do_run,
769
725
  dump_folder=dump_folder,
770
726
  output_names=output_names,
727
+ exporter_options=expop,
771
728
  )
772
729
 
773
730
  summary.update(summary_export)
774
731
  summary["time_total_exporter"] = time.perf_counter() - exporter_begin
775
732
 
733
+
734
+ def _dump_onnx_model(data, summary, dump_folder, verbose, exporter, folder_name):
776
735
  dump_stats = None
777
736
  if dump_folder:
778
737
  if "exported_program" in data:
@@ -837,26 +796,392 @@ def validate_model(
837
796
  ):
838
797
  if verbose:
839
798
  print("[validate_model] -- done (final)")
840
- if dump_stats:
841
- with open(dump_stats, "w") as f:
842
- for k, v in sorted(summary.items()):
843
- f.write(f":{k}:{v};\n")
799
+ return False, dump_stats
800
+ return True, dump_stats
801
+
802
+
803
+ def validate_model(
804
+ model_id: str,
805
+ task: Optional[str] = None,
806
+ do_run: bool = False,
807
+ exporter: Optional[str] = None,
808
+ do_same: bool = False,
809
+ verbose: int = 0,
810
+ dtype: Optional[Union[str, torch.dtype]] = None,
811
+ device: Optional[Union[str, torch.device]] = None,
812
+ same_as_pretrained: bool = False,
813
+ use_pretrained: bool = False,
814
+ optimization: Optional[str] = None,
815
+ quiet: bool = False,
816
+ patch: Union[bool, str, Dict[str, bool]] = False,
817
+ rewrite: bool = False,
818
+ stop_if_static: int = 1,
819
+ dump_folder: Optional[str] = None,
820
+ drop_inputs: Optional[List[str]] = None,
821
+ ortfusiontype: Optional[str] = None,
822
+ input_options: Optional[Dict[str, Any]] = None,
823
+ model_options: Optional[Dict[str, Any]] = None,
824
+ exporter_options: Optional[Dict[str, Any]] = None,
825
+ subfolder: Optional[str] = None,
826
+ opset: Optional[int] = None,
827
+ runtime: str = "onnxruntime",
828
+ repeat: int = 1,
829
+ warmup: int = 0,
830
+ inputs2: int = 1,
831
+ output_names: Optional[List[str]] = None,
832
+ ort_logs: bool = False,
833
+ quiet_input_sets: Optional[Set[str]] = None,
834
+ ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
835
+ """
836
+ Validates a model.
837
+ The function can also be called through the command line
838
+ :ref:`l-cmd-validate`.
839
+
840
+ :param model_id: model id to validate
841
+ :param task: task used to generate the necessary inputs,
842
+ can be left empty to use the default task for this model
843
+ if it can be determined
844
+ :param do_run: checks the model works with the defined inputs
845
+ :param exporter: exporter the model using this exporter,
846
+ available list: ``export-strict``, ``export-nostrict``, ...
847
+ see below
848
+ :param do_same: checks the discrepancies of the exported model
849
+ :param verbose: verbosity level
850
+ :param dtype: uses this dtype to check the model
851
+ :param device: do the verification on this device
852
+ :param same_as_pretrained: use a model equivalent to the trained,
853
+ this is not always possible
854
+ :param use_pretrained: use the trained model, not the untrained one
855
+ :param optimization: optimization to apply to the exported model,
856
+ depend on the the exporter
857
+ :param quiet: if quiet, catches exception if any issue
858
+ :param patch: applies patches (``patch_transformers=True, path_diffusers=True``)
859
+ if True before exporting
860
+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
861
+ a string can be used to specify only one of them
862
+ :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
863
+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
864
+ :param stop_if_static: stops if a dynamic dimension becomes static,
865
+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
866
+ :param dump_folder: dumps everything in a subfolder of this one
867
+ :param drop_inputs: drops this list of inputs (given their names)
868
+ :param ortfusiontype: runs ort fusion, the parameters defines the fusion type,
869
+ it accepts multiple values separated by ``|``,
870
+ see :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`
871
+ :param input_options: additional options to define the dummy inputs
872
+ used to export
873
+ :param model_options: additional options when creating the model such as
874
+ ``num_hidden_layers`` or ``attn_implementation``
875
+ :param exporter_options: additional options when exporting the model such as
876
+ ``report=True`` or ``verify=True``
877
+ :param subfolder: version or subfolders to uses when retrieving a model id
878
+ :param opset: onnx opset to use for the conversion
879
+ :param runtime: onnx runtime to use to check about discrepancies,
880
+ possible values ``onnxruntime``, ``torch``, ``orteval``,
881
+ ``orteval10``, ``ref`` only if `do_run` is true
882
+ :param repeat: number of time to measure the model
883
+ :param warmup: warmup the model first
884
+ :param inputs2: checks that other sets of inputs are running as well,
885
+ this ensures that the model does support dynamism, the value is used
886
+ as an increment to the first set of values (added to dimensions),
887
+ or an empty cache for example
888
+ :param output_names: output names the onnx exporter should use
889
+ :param ort_logs: increases onnxruntime verbosity when creating the session
890
+ :param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
891
+ even if quiet is False
892
+ :return: two dictionaries, one with some metrics,
893
+ another one with whatever the function produces
894
+
895
+ The following environment variables can be used to print out some
896
+ information:
897
+
898
+ * ``PRINT_CONFIG``: prints the model configuration
899
+
900
+ The following exporters are available:
901
+
902
+ * ``export-nostrict``: run :func:`torch.export.export` (..., strict=False)
903
+ * ``onnx-dynamo``: run :func:`torch.onnx.export` (...),
904
+ models can be optimized with ``optimization`` in ``("ir", "os_ort")``
905
+ * ``modelbuilder``: use :epkg:`ModelBuilder` to builds the onnx model
906
+ * ``custom``: custom exporter (see :epkg:`experimental-experiment`),
907
+ models can be optimized with ``optimization`` in
908
+ ``("default", "default+onnxruntime", "default+os_ort", "default+onnxruntime+os_ort")``
909
+
910
+ The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
911
+ exported model returns the same outputs as the original one, otherwise,
912
+ :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
913
+ if ``runtime == 'torch'`` or
914
+ :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
915
+ if ``runtime == 'orteval'`` or
916
+ :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
917
+ if ``runtime == 'ref'``,
918
+ ``orteval10`` increases the verbosity.
919
+
920
+ .. versionchanged:: 0.7.13
921
+ *inputs2* not only means a second set of inputs but many
922
+ such as ``input_empty_cache``
923
+ which refers to a set of inputs using an empty cache.
924
+ """
925
+ main_validation_begin = time.perf_counter()
926
+ cont, summary, data, dump_stats, second_input_keys = _validate_model_step1(
927
+ model_id=model_id,
928
+ do_same=do_same,
929
+ do_run=do_run,
930
+ patch=patch,
931
+ rewrite=rewrite,
932
+ dtype=dtype,
933
+ device=device,
934
+ optimization=optimization,
935
+ quiet=quiet,
936
+ drop_inputs=drop_inputs,
937
+ ortfusiontype=ortfusiontype,
938
+ stop_if_static=stop_if_static,
939
+ exporter=exporter,
940
+ verbose=verbose,
941
+ task=task,
942
+ runtime=runtime,
943
+ inputs2=inputs2,
944
+ input_options=input_options,
945
+ model_options=model_options,
946
+ exporter_options=exporter_options,
947
+ opset=opset,
948
+ output_names=output_names,
949
+ repeat=repeat,
950
+ warmup=warmup,
951
+ dump_folder=dump_folder,
952
+ subfolder=subfolder,
953
+ use_pretrained=use_pretrained,
954
+ same_as_pretrained=same_as_pretrained,
955
+ )
956
+ if dump_folder:
957
+ with open(dump_stats, "w") as f:
958
+ for k, v in sorted(summary.items()):
959
+ f.write(f":{k}:{v};\n")
960
+ if not cont:
844
961
  return summary, data
962
+ data, summary = _clean_data_remove_model_and_proto(data, summary)
963
+ _validate_model_step2(
964
+ summary=summary,
965
+ data=data,
966
+ do_run=do_run,
967
+ quiet=quiet,
968
+ verbose=verbose,
969
+ runtime=runtime,
970
+ repeat=repeat,
971
+ warmup=warmup,
972
+ second_input_keys=second_input_keys,
973
+ ort_logs=ort_logs,
974
+ quiet_input_sets=quiet_input_sets,
975
+ ortfusiontype=ortfusiontype,
976
+ model_id=model_id,
977
+ )
978
+
979
+ summary["time_total"] = time.perf_counter() - main_validation_begin
980
+
981
+ if verbose:
982
+ print("[validate_model] -- done (final)")
983
+ with open(dump_stats, "w") as f:
984
+ for k, v in sorted(summary.items()):
985
+ f.write(f":{k}:{v};\n")
986
+ return summary, data
987
+
988
+
989
+ def _clean_data_remove_model_and_proto(data, summary):
990
+ assert isinstance(data, dict) and isinstance(data, dict)
991
+ data = _clean_data_remove_model_and_proto_(data)
992
+ summary = _clean_data_remove_model_and_proto_(summary)
993
+ gc.collect()
994
+ return data, summary
995
+
996
+
997
+ def _clean_data_remove_model_and_proto_(obj):
998
+ if type(obj) is dict:
999
+ # do not use isinstance otherwise CausalLMOutputWithPast becomes a dictionary
1000
+ return {k: _clean_data_remove_model_and_proto_(v) for k, v in obj.items()}
1001
+ if isinstance(obj, list):
1002
+ return [_clean_data_remove_model_and_proto_(v) for v in obj]
1003
+ if isinstance(obj, tuple):
1004
+ return tuple(_clean_data_remove_model_and_proto_(v) for v in obj)
1005
+ if isinstance(obj, set):
1006
+ return {_clean_data_remove_model_and_proto_(v) for v in obj}
1007
+ if isinstance(obj, (torch.nn.Module, onnx.ModelProto)):
1008
+ return None
1009
+ return obj
1010
+
1011
+
1012
+ def _validate_model_step1(
1013
+ model_id,
1014
+ do_same,
1015
+ do_run,
1016
+ patch,
1017
+ rewrite,
1018
+ dtype,
1019
+ device,
1020
+ optimization,
1021
+ quiet,
1022
+ drop_inputs,
1023
+ ortfusiontype,
1024
+ stop_if_static,
1025
+ exporter,
1026
+ verbose,
1027
+ task,
1028
+ runtime,
1029
+ inputs2,
1030
+ input_options,
1031
+ model_options,
1032
+ exporter_options,
1033
+ opset,
1034
+ output_names,
1035
+ repeat,
1036
+ warmup,
1037
+ dump_folder,
1038
+ subfolder,
1039
+ use_pretrained,
1040
+ same_as_pretrained,
1041
+ ):
1042
+ assert not do_same or do_run, (
1043
+ f"Discrepancies cannot be measured if the model is not run, "
1044
+ f"do_run={do_run}, do_same={do_same}"
1045
+ )
1046
+ (
1047
+ summary,
1048
+ model_id,
1049
+ subfolder,
1050
+ same_as_pretrained,
1051
+ use_pretrained,
1052
+ dump_folder,
1053
+ folder_name,
1054
+ patch_kwargs,
1055
+ ) = _prepare_validation(
1056
+ model_id=model_id,
1057
+ subfolder=subfolder,
1058
+ same_as_pretrained=same_as_pretrained,
1059
+ use_pretrained=use_pretrained,
1060
+ patch=patch,
1061
+ rewrite=rewrite,
1062
+ do_run=do_run,
1063
+ dtype=dtype,
1064
+ device=device,
1065
+ optimization=optimization,
1066
+ quiet=quiet,
1067
+ drop_inputs=drop_inputs,
1068
+ ortfusiontype=ortfusiontype,
1069
+ stop_if_static=stop_if_static,
1070
+ exporter=exporter,
1071
+ runtime=runtime,
1072
+ inputs2=inputs2,
1073
+ input_options=input_options,
1074
+ model_options=model_options,
1075
+ exporter_options=exporter_options,
1076
+ opset=opset,
1077
+ task=task,
1078
+ verbose=verbose,
1079
+ output_names=output_names,
1080
+ dump_folder=dump_folder,
1081
+ )
1082
+
1083
+ data, iop, mop = _get_untrained_model_with_inputs(
1084
+ summary=summary,
1085
+ model_id=model_id,
1086
+ verbose=verbose,
1087
+ task=task,
1088
+ use_pretrained=use_pretrained,
1089
+ same_as_pretrained=same_as_pretrained,
1090
+ input_options=input_options,
1091
+ model_options=model_options,
1092
+ subfolder=subfolder,
1093
+ inputs2=inputs2,
1094
+ quiet=quiet,
1095
+ dump_folder=dump_folder,
1096
+ )
1097
+
1098
+ second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]
1099
+ if exporter == "modelbuilder":
1100
+ _update_data_for_modelbuilder(data, verbose)
1101
+
1102
+ _update_inputs_outputs(
1103
+ data=data,
1104
+ summary=summary,
1105
+ exporter=exporter,
1106
+ iop=iop,
1107
+ mop=mop,
1108
+ dump_folder=dump_folder,
1109
+ opset=opset,
1110
+ device=device,
1111
+ dtype=dtype,
1112
+ rewrite=rewrite,
1113
+ drop_inputs=drop_inputs,
1114
+ verbose=verbose,
1115
+ second_input_keys=second_input_keys,
1116
+ model_id=model_id,
1117
+ )
1118
+
1119
+ _verbose_validate(data, second_input_keys, verbose)
845
1120
 
846
1121
  if do_run:
847
- # Let's move the model to CPU to make sure it frees GPU memory.
848
- if verbose:
849
- # It does not really work for the time being and the model
850
- # gets loaded twice, one by torch, one by onnxruntime
851
- print("[validation_model] -- delete the model")
852
- for key in ["model", "onnx_program", "config"]:
853
- if key in data:
854
- del data[key]
855
- if device is not None and "cuda" in str(device).lower():
856
- torch.cuda.empty_cache()
857
- gc.collect()
858
- print("[validation_model] -- done")
1122
+ validation_begin = time.perf_counter()
1123
+
1124
+ _validate_do_run_model(
1125
+ data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet
1126
+ )
1127
+ if second_input_keys:
1128
+ for k in second_input_keys:
1129
+ _validate_do_run_model(
1130
+ data,
1131
+ summary,
1132
+ k,
1133
+ f"run2{k[6:]}",
1134
+ f"run_expected2{k[6:]}",
1135
+ verbose,
1136
+ 1,
1137
+ 0,
1138
+ quiet,
1139
+ )
1140
+
1141
+ summary["time_total_validation_torch"] = time.perf_counter() - validation_begin
1142
+
1143
+ _call_exporter(
1144
+ data=data,
1145
+ summary=summary,
1146
+ exporter=exporter,
1147
+ patch_kwargs=patch_kwargs,
1148
+ stop_if_static=stop_if_static,
1149
+ verbose=verbose,
1150
+ dump_folder=dump_folder,
1151
+ quiet=quiet,
1152
+ optimization=optimization,
1153
+ do_run=do_run,
1154
+ output_names=output_names,
1155
+ exporter_options=exporter_options,
1156
+ )
859
1157
 
1158
+ cont, dump_stats = _dump_onnx_model(
1159
+ data=data,
1160
+ summary=summary,
1161
+ dump_folder=dump_folder,
1162
+ verbose=verbose,
1163
+ exporter=exporter,
1164
+ folder_name=folder_name,
1165
+ )
1166
+ return cont, summary, data, dump_stats, second_input_keys
1167
+
1168
+
1169
+ def _validate_model_step2(
1170
+ summary,
1171
+ data,
1172
+ do_run,
1173
+ quiet,
1174
+ verbose,
1175
+ runtime,
1176
+ repeat,
1177
+ warmup,
1178
+ second_input_keys,
1179
+ ort_logs,
1180
+ quiet_input_sets,
1181
+ ortfusiontype,
1182
+ model_id,
1183
+ ):
1184
+ if do_run:
860
1185
  validation_begin = time.perf_counter()
861
1186
  summary_valid, data = validate_onnx_model(
862
1187
  data=data,
@@ -935,16 +1260,6 @@ def validate_model(
935
1260
  summary.update(summary_valid)
936
1261
 
937
1262
  _compute_final_statistics(summary)
938
- summary["time_total"] = time.perf_counter() - main_validation_begin
939
-
940
- if verbose:
941
- print("[validate_model] -- done (final)")
942
- if dump_stats:
943
- # Dumps again the statistics.
944
- with open(dump_stats, "w") as f:
945
- for k, v in sorted(summary.items()):
946
- f.write(f":{k}:{v};\n")
947
- return summary, data
948
1263
 
949
1264
 
950
1265
  def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
@@ -1028,7 +1343,7 @@ def _validate_do_run_model(
1028
1343
 
1029
1344
  summary[expected_tag] = string_type(expected, with_shape=True)
1030
1345
  if verbose:
1031
- print(f"[validate_model] done ([{tag}])")
1346
+ print(f"[validate_model] done ([{tag}]) - {string_type(expected, with_shape=True)}")
1032
1347
  data[expected_tag] = expected
1033
1348
  assert hash_inputs == string_type(data[key], with_shape=True), (
1034
1349
  f"The model execution did modified the inputs:\n"
@@ -1038,7 +1353,6 @@ def _validate_do_run_model(
1038
1353
 
1039
1354
 
1040
1355
  def _validate_do_run_exported_program(data, summary, verbose, quiet):
1041
-
1042
1356
  # We run a second time the model to check the patch did not
1043
1357
  # introduce any discrepancies
1044
1358
  if verbose:
@@ -1063,7 +1377,13 @@ def _validate_do_run_exported_program(data, summary, verbose, quiet):
1063
1377
  if "ERR_run_patched" in summary:
1064
1378
  return summary, data
1065
1379
 
1066
- disc = max_diff(data["run_expected"], expected)
1380
+ verbose_diff = int(os.environ.get("MAXDIFF", "0"))
1381
+ if verbose_diff >= 10:
1382
+ print("[_validate_do_run_exported_program] with inputs_export")
1383
+ disc = max_diff(data["run_expected"], expected, verbose=verbose_diff)
1384
+ assert not verbose_diff or (
1385
+ not np.isnan(disc["abs"]) and not np.isinf(disc["abs"])
1386
+ ), f"something went wrong disc={disc}"
1067
1387
  for k, v in disc.items():
1068
1388
  summary[f"disc_patched_{k}"] = str(v)
1069
1389
  if verbose:
@@ -1105,6 +1425,7 @@ def call_exporter(
1105
1425
  do_run: bool = False,
1106
1426
  dump_folder: Optional[str] = None,
1107
1427
  output_names: Optional[List[str]] = None,
1428
+ exporter_options: Optional[Dict[str, Any]] = None,
1108
1429
  ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
1109
1430
  """
1110
1431
  Calls an exporter on a model;
@@ -1118,6 +1439,7 @@ def call_exporter(
1118
1439
  :param do_run: runs and compute discrepancies
1119
1440
  :param dump_folder: to dump additional information
1120
1441
  :param output_names: list of output names to use with the onnx exporter
1442
+ :param exporter_options: exporter options
1121
1443
  :return: two dictionaries, one with some metrics,
1122
1444
  another one with whatever the function produces
1123
1445
  """
@@ -1133,6 +1455,7 @@ def call_exporter(
1133
1455
  verbose=verbose,
1134
1456
  optimization=optimization,
1135
1457
  do_run=do_run,
1458
+ exporter_options=exporter_options,
1136
1459
  )
1137
1460
  _restore_torch_export_export(summary)
1138
1461
  return summary, data
@@ -1145,6 +1468,7 @@ def call_exporter(
1145
1468
  verbose=verbose,
1146
1469
  optimization=optimization,
1147
1470
  output_names=output_names,
1471
+ exporter_options=exporter_options,
1148
1472
  )
1149
1473
  _restore_torch_export_export(summary)
1150
1474
  return summary, data
@@ -1158,6 +1482,7 @@ def call_exporter(
1158
1482
  optimization=optimization,
1159
1483
  dump_folder=dump_folder,
1160
1484
  output_names=output_names,
1485
+ exporter_options=exporter_options,
1161
1486
  )
1162
1487
  _restore_torch_export_export(summary)
1163
1488
  return summary, data
@@ -1170,6 +1495,7 @@ def call_exporter(
1170
1495
  verbose=verbose,
1171
1496
  optimization=optimization,
1172
1497
  output_names=output_names,
1498
+ exporter_options=exporter_options,
1173
1499
  )
1174
1500
  _restore_torch_export_export(summary)
1175
1501
  return summary, data
@@ -1189,6 +1515,7 @@ def call_torch_export_export(
1189
1515
  verbose: int = 0,
1190
1516
  optimization: Optional[str] = None,
1191
1517
  do_run: bool = False,
1518
+ exporter_options: Optional[Dict[str, Any]] = None,
1192
1519
  ):
1193
1520
  """
1194
1521
  Exports a model with :func:`torch.export.export`.
@@ -1201,9 +1528,11 @@ def call_torch_export_export(
1201
1528
  :param verbose: verbosity
1202
1529
  :param optimization: optimization to do
1203
1530
  :param do_run: runs and computes discrepancies
1531
+ :param exporter_options: additional options given to the exporter
1204
1532
  :return: two dictionaries, one with some metrics,
1205
1533
  another one with whatever the function produces
1206
1534
  """
1535
+ exporter_options = exporter_options or {}
1207
1536
  assert exporter in {
1208
1537
  "export",
1209
1538
  "export-strict",
@@ -1212,8 +1541,12 @@ def call_torch_export_export(
1212
1541
  assert not optimization, f"No optimization is implemented for exporter={exporter!r}"
1213
1542
  assert "model" in data, f"model is missing from data: {sorted(data)}"
1214
1543
  assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
1544
+ assert ("-strict" not in exporter) or ("strict" not in exporter_options), (
1545
+ f"Options strict cannot be specified in the exporter name {exporter!r} "
1546
+ f"and in the options {exporter_options}"
1547
+ )
1215
1548
  summary: Dict[str, Union[str, int, float]] = {}
1216
- strict = "-strict" in exporter
1549
+ strict = "-strict" in exporter or exporter_options.pop("strict", False)
1217
1550
  args, kwargs = split_args_kwargs(data["inputs_export"])
1218
1551
  ds = data.get("dynamic_shapes", None)
1219
1552
 
@@ -1223,6 +1556,7 @@ def call_torch_export_export(
1223
1556
  summary["export_args"] = string_type(args, with_shape=True)
1224
1557
  summary["export_kwargs"] = string_type(kwargs, with_shape=True)
1225
1558
  summary["export_dynamic_shapes"] = string_type(ds)
1559
+ summary["export_options"] = str(exporter_options)
1226
1560
 
1227
1561
  # There is an issue with DynamicShape [[],[]] becomes []
1228
1562
  dse = use_dyn_not_str(ds)
@@ -1249,7 +1583,9 @@ def call_torch_export_export(
1249
1583
  data,
1250
1584
  (
1251
1585
  lambda m=model, args=args, kws=kwargs, dse=dse, s=strict: (
1252
- torch.export.export(m, args, kwargs=kws, dynamic_shapes=dse, strict=s)
1586
+ torch.export.export(
1587
+ m, args, kwargs=kws, dynamic_shapes=dse, strict=s, **exporter_options
1588
+ )
1253
1589
  )
1254
1590
  ),
1255
1591
  )
@@ -1292,7 +1628,14 @@ def call_torch_export_export(
1292
1628
  if "ERR_export_export" in summary:
1293
1629
  return summary, data
1294
1630
 
1295
- disc = max_diff(data["run_expected"], expected)
1631
+ verbose_diff = int(os.environ.get("MAXDIFF", "0"))
1632
+ if verbose_diff >= 10:
1633
+ print("[call_torch_export_export] with inputs_export")
1634
+ disc = max_diff(data["run_expected"], expected, verbose=verbose_diff)
1635
+ assert not verbose_diff or (
1636
+ not np.isnan(disc["abs"]) and not np.isinf(disc["abs"])
1637
+ ), f"something went wrong disc={disc}"
1638
+
1296
1639
  for k, v in disc.items():
1297
1640
  summary[f"disc_exported_{k}"] = str(v)
1298
1641
  if verbose:
@@ -1463,6 +1806,9 @@ def validate_onnx_model(
1463
1806
  if verbose:
1464
1807
  print(f"[validate_onnx_model] -- keys={keys}")
1465
1808
  for k_input, k_expected, suffix in keys:
1809
+ if k_input == "inputs_prompt":
1810
+ # this must use onnx_generate
1811
+ continue
1466
1812
  # make_feeds
1467
1813
  assert k_input in data, f"Unable to find {k_input!r} in {sorted(data)}"
1468
1814
  assert k_expected in data, f"Unable to find {k_expected!r} in {sorted(data)}"
@@ -1509,7 +1855,16 @@ def validate_onnx_model(
1509
1855
  print(f"[validate_onnx_model] got={string_type(got, with_shape=True)}")
1510
1856
 
1511
1857
  # compute discrepancies
1512
- disc = max_diff(data[k_expected], got, flatten=True)
1858
+ verbose_diff = int(os.environ.get("MAXDIFF", "0"))
1859
+ if verbose_diff >= 10:
1860
+ print(
1861
+ f"[validate_onnx_model] k_input={k_input!r}, "
1862
+ f"k_expected={k_expected!r}, suffix={suffix!r}"
1863
+ )
1864
+ disc = max_diff(data[k_expected], got, flatten=True, verbose=verbose_diff)
1865
+ assert not verbose_diff or (
1866
+ not np.isnan(disc["abs"]) and not np.isinf(disc["abs"])
1867
+ ), f"something went wrong disc={disc}"
1513
1868
  if verbose:
1514
1869
  print(f"[validate_onnx_model] discrepancies={string_diff(disc)}")
1515
1870
  for k, v in disc.items():
@@ -1524,6 +1879,7 @@ def call_torch_export_onnx(
1524
1879
  verbose: int = 0,
1525
1880
  optimization: Optional[str] = None,
1526
1881
  output_names: Optional[List[str]] = None,
1882
+ exporter_options: Optional[Dict[str, Any]] = None,
1527
1883
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1528
1884
  """
1529
1885
  Exports a model into onnx.
@@ -1536,10 +1892,12 @@ def call_torch_export_onnx(
1536
1892
  :param verbose: verbosity
1537
1893
  :param optimization: optimization to do
1538
1894
  :param output_names: output names to use
1895
+ :param exporter_options: additional options to give the exporter
1539
1896
  :return: two dictionaries, one with some metrics,
1540
1897
  another one with whatever the function produces
1541
1898
  """
1542
1899
  available = {None, "", "ir", "os_ort", "ir+default"}
1900
+ exporter_options = exporter_options or {}
1543
1901
  assert (
1544
1902
  optimization in available
1545
1903
  ), f"unexpected value for optimization={optimization}, available={available}"
@@ -1567,6 +1925,7 @@ def call_torch_export_onnx(
1567
1925
  summary["export_dynamo"] = dynamo
1568
1926
  summary["export_args"] = string_type(args, with_shape=True)
1569
1927
  summary["export_kwargs"] = string_type(kwargs, with_shape=True)
1928
+ summary["export_exporter"] = str(exporter_options)
1570
1929
  opset = data.get("model_opset", None)
1571
1930
  if opset:
1572
1931
  summary["export_opset"] = opset
@@ -1594,6 +1953,11 @@ def call_torch_export_onnx(
1594
1953
  export_export_kwargs["output_names"] = output_names
1595
1954
  if opset:
1596
1955
  export_export_kwargs["opset_version"] = opset
1956
+ assert not (set(export_export_kwargs) & set(exporter_options)), (
1957
+ f"Some options were defined twice, "
1958
+ f"{set(export_export_kwargs) & set(exporter_options)}, "
1959
+ f"you should remove them from exporter_options={exporter_options}"
1960
+ )
1597
1961
  if verbose:
1598
1962
  print(
1599
1963
  f"[call_torch_export_onnx] export_export_kwargs="
@@ -1613,6 +1977,7 @@ def call_torch_export_onnx(
1613
1977
  args,
1614
1978
  kwargs=kws,
1615
1979
  **ekws,
1980
+ **exporter_options,
1616
1981
  )
1617
1982
  )
1618
1983
  ),
@@ -1685,6 +2050,7 @@ def call_torch_export_model_builder(
1685
2050
  verbose: int = 0,
1686
2051
  optimization: Optional[str] = None,
1687
2052
  output_names: Optional[List[str]] = None,
2053
+ exporter_options: Optional[Dict[str, Any]] = None,
1688
2054
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1689
2055
  """
1690
2056
  Exports a model into onnx with :epkg:`ModelBuilder`.
@@ -1696,11 +2062,13 @@ def call_torch_export_model_builder(
1696
2062
  :param verbose: verbosity
1697
2063
  :param optimization: optimization to do
1698
2064
  :param output_names: list of output names to use
2065
+ :param exporter_options: additional options to give the exporter
1699
2066
  :return: two dictionaries, one with some metrics,
1700
2067
  another one with whatever the function produces
1701
2068
  """
1702
2069
  from ..helpers.model_builder_helper import create_model_builder, save_model_builder
1703
2070
 
2071
+ exporter_options = exporter_options or {}
1704
2072
  assert optimization in (
1705
2073
  None,
1706
2074
  "",
@@ -1728,7 +2096,12 @@ def call_torch_export_model_builder(
1728
2096
  ], p=precision, pr=provider, cd=cache_dir: (
1729
2097
  save_model_builder(
1730
2098
  create_model_builder(
1731
- c, m, precision=p, execution_provider=pr, cache_dir=cd
2099
+ c,
2100
+ m,
2101
+ precision=p,
2102
+ execution_provider=pr,
2103
+ cache_dir=cd,
2104
+ **exporter_options,
1732
2105
  )
1733
2106
  )
1734
2107
  )
@@ -1845,6 +2218,7 @@ def call_torch_export_custom(
1845
2218
  optimization: Optional[str] = None,
1846
2219
  dump_folder: Optional[str] = None,
1847
2220
  output_names: Optional[List[str]] = None,
2221
+ exporter_options: Optional[Dict[str, Any]] = None,
1848
2222
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1849
2223
  """
1850
2224
  Exports a model into onnx.
@@ -1858,9 +2232,11 @@ def call_torch_export_custom(
1858
2232
  :param optimization: optimization to do
1859
2233
  :param dump_folder: to store additional information
1860
2234
  :param output_names: list of output names to use
2235
+ :param exporter_options: additional exporter options
1861
2236
  :return: two dictionaries, one with some metrics,
1862
2237
  another one with whatever the function produces
1863
2238
  """
2239
+ exporter_options = exporter_options or {}
1864
2240
  available = {
1865
2241
  "",
1866
2242
  "default",
@@ -1896,11 +2272,20 @@ def call_torch_export_custom(
1896
2272
  assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
1897
2273
  assert "model" in data, f"model is missing from data: {sorted(data)}"
1898
2274
  assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
2275
+ assert ("-strict" not in exporter) or ("strict" not in exporter_options), (
2276
+ f"Options strict cannot be specified in the exporter name {exporter!r} "
2277
+ f"and in the options {exporter_options}"
2278
+ )
2279
+ assert ("-fake" not in exporter) or ("fake" not in exporter_options), (
2280
+ f"Options strict cannot be specified in the exporter name {exporter!r} "
2281
+ f"and in the options {exporter_options}"
2282
+ )
1899
2283
  summary: Dict[str, Union[str, int, float]] = {}
1900
- strict = "-strict" in exporter
2284
+ strict = "-strict" in exporter or exporter_options.pop("strict", False)
1901
2285
  args, kwargs = split_args_kwargs(data["inputs_export"])
1902
2286
  ds = data.get("dynamic_shapes", None)
1903
- if "-fake" in exporter:
2287
+ fake = "-fake" in exporter or exporter_options.pop("fake", False)
2288
+ if fake:
1904
2289
  from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
1905
2290
 
1906
2291
  if verbose:
@@ -1923,8 +2308,10 @@ def call_torch_export_custom(
1923
2308
  summary["export_exporter"] = exporter
1924
2309
  summary["export_optimization"] = optimization or ""
1925
2310
  summary["export_strict"] = strict
2311
+ summary["export_fake"] = fake
1926
2312
  summary["export_args"] = string_type(args, with_shape=True)
1927
2313
  summary["export_kwargs"] = string_type(kwargs, with_shape=True)
2314
+ summary["export_options"] = str(exporter_options)
1928
2315
 
1929
2316
  from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
1930
2317
  from experimental_experiment.xbuilder import OptimizationOptions
@@ -1932,17 +2319,35 @@ def call_torch_export_custom(
1932
2319
  spl = optimization.split("+") if optimization else []
1933
2320
  os_ort = "os_ort" in spl
1934
2321
  optimization = "+".join(_ for _ in spl if _ != "os_ort")
1935
-
1936
- export_options = ExportOptions(
1937
- strict=strict,
1938
- decomposition_table=(
2322
+ inline = "-noinline" not in exporter or exporter_options.pop("inline", True)
2323
+ decomposition_table = (
2324
+ exporter_options.pop("decomposition_table")
2325
+ if "decomposition_table" in exporter_options
2326
+ else (
1939
2327
  "default"
1940
2328
  if ("-default" in exporter or "-dec" in exporter)
1941
2329
  else ("all" if ("-all" in exporter or "-decall" in exporter) else None)
1942
- ),
2330
+ )
2331
+ )
2332
+ large_model = bool(exporter_options.pop("large_model", True))
2333
+ return_optimize_report = bool(exporter_options.pop("return_optimize_report", True))
2334
+ export_modules_as_functions = bool(
2335
+ exporter_options.pop("export_modules_as_functions", False)
2336
+ )
2337
+ external_threshold = int(exporter_options.pop("external_threshold", 1024))
2338
+ summary["export_decomposition_table"] = str(decomposition_table)
2339
+ summary["export_inline"] = str(inline)
2340
+ summary["export_large_model"] = str(large_model)
2341
+ summary["export_return_optimize_report"] = str(return_optimize_report)
2342
+ summary["export_export_modules_as_functions"] = str(export_modules_as_functions)
2343
+ summary["export_external_threshold"] = str(external_threshold)
2344
+
2345
+ export_options = ExportOptions(
2346
+ strict=strict,
2347
+ decomposition_table=decomposition_table,
1943
2348
  save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
2349
+ **exporter_options,
1944
2350
  )
1945
- inline = "-noinline" not in exporter
1946
2351
  options = OptimizationOptions(patterns=optimization) if optimization else None
1947
2352
  model = data["model"]
1948
2353
  kws = dict(
@@ -1950,10 +2355,12 @@ def call_torch_export_custom(
1950
2355
  export_options=export_options,
1951
2356
  options=options,
1952
2357
  optimize=bool(optimization),
1953
- large_model=True,
1954
- return_optimize_report=True,
1955
2358
  verbose=max(verbose - 2, 0),
1956
2359
  inline=inline,
2360
+ large_model=large_model,
2361
+ return_optimize_report=return_optimize_report,
2362
+ export_modules_as_functions=export_modules_as_functions,
2363
+ external_threshold=external_threshold,
1957
2364
  )
1958
2365
  if opset:
1959
2366
  kws["target_opset"] = opset