onnx-diagnostic 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +35 -5
  4. onnx_diagnostic/export/control_flow.py +511 -0
  5. onnx_diagnostic/export/control_flow_research.py +135 -0
  6. onnx_diagnostic/ext_test_case.py +33 -9
  7. onnx_diagnostic/helpers/cache_helper.py +217 -203
  8. onnx_diagnostic/helpers/helper.py +6 -2
  9. onnx_diagnostic/helpers/log_helper.py +39 -5
  10. onnx_diagnostic/helpers/memory_peak.py +2 -0
  11. onnx_diagnostic/helpers/mini_onnx_builder.py +55 -3
  12. onnx_diagnostic/helpers/onnx_helper.py +13 -16
  13. onnx_diagnostic/helpers/rt_helper.py +579 -15
  14. onnx_diagnostic/helpers/torch_helper.py +5 -0
  15. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  16. onnx_diagnostic/tasks/text2text_generation.py +1 -0
  17. onnx_diagnostic/tasks/text_generation.py +84 -54
  18. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  19. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  20. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  21. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -1
  22. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +563 -61
  23. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  24. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  25. onnx_diagnostic/torch_models/validate.py +620 -213
  26. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/METADATA +1 -1
  27. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/RECORD +30 -28
  28. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.8.0"
6
+ __version__ = "0.8.2"
7
7
  __author__ = "Xavier Dupré"
@@ -265,7 +265,7 @@ def get_parser_config() -> ArgumentParser:
265
265
  "--mop",
266
266
  metavar="KEY=VALUE",
267
267
  nargs="*",
268
- help="Additional model options, use to change some parameters of the model, "
268
+ help="Additional model options, used to change some parameters of the model, "
269
269
  "example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager",
270
270
  action=_ParseDict,
271
271
  )
@@ -442,11 +442,17 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
442
442
  default=True,
443
443
  action=_BoolOrParseDictPatch,
444
444
  nargs="*",
445
- help="Applies patches before exporting, it can be a boolean "
446
- "to enable to disable the patches or be more finetuned. It is possible to "
447
- "disable patch for torch by adding "
448
- '--patch "patch_sympy=False" --patch "patch_torch=False", '
449
- "default is True.",
445
+ help=textwrap.dedent(
446
+ """
447
+ Applies patches before exporting, it can be a boolean
448
+ to enable to disable the patches or be more finetuned
449
+ (default is True). It is possible to disable patch for torch
450
+ by adding:
451
+ --patch "patch_sympy=False" --patch "patch_torch=False"
452
+ """.strip(
453
+ "\n"
454
+ )
455
+ ),
450
456
  )
451
457
  parser.add_argument(
452
458
  "--rewrite",
@@ -476,10 +482,16 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
476
482
  "--inputs2",
477
483
  default=1,
478
484
  type=int,
479
- help="Validates or exports the model on a second set of inputs\n"
480
- "to check the exported model supports dynamism. The values is used "
481
- "as an increment to the first set of inputs. A high value may trick "
482
- "a different behavior in the model and missed by the exporter.",
485
+ help=textwrap.dedent(
486
+ """
487
+ Validates or exports the model on a second set of inputs
488
+ to check the exported model supports dynamism. The values is used
489
+ as an increment to the first set of inputs. A high value may trick
490
+ a different behavior in the model and missed by the exporter.
491
+ """.strip(
492
+ "\n"
493
+ )
494
+ ),
483
495
  )
484
496
  parser.add_argument(
485
497
  "--runtime",
@@ -512,9 +524,15 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
512
524
  parser.add_argument(
513
525
  "--ortfusiontype",
514
526
  required=False,
515
- help="Applies onnxruntime fusion, this parameter should contain the\n"
516
- "model type or multiple values separated by `|`. `ALL` can be used\n"
517
- "to run them all.",
527
+ help=textwrap.dedent(
528
+ """
529
+ Applies onnxruntime fusion, this parameter should contain the
530
+ model type or multiple values separated by `|`. `ALL` can be used
531
+ to run them all.
532
+ """.strip(
533
+ "\n"
534
+ )
535
+ ),
518
536
  )
519
537
  parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
520
538
  parser.add_argument("--dtype", help="Changes dtype if necessary.")
@@ -523,18 +541,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
523
541
  "--iop",
524
542
  metavar="KEY=VALUE",
525
543
  nargs="*",
526
- help="Additional input options, use to change the default"
527
- "inputs use to export, example:\n --iop cls_cache=SlidingWindowCache"
528
- "\n --iop cls_cache=StaticCache",
544
+ help=textwrap.dedent(
545
+ """
546
+ Additional input options, used to change the default
547
+ inputs use to export. Examples:
548
+ --iop cls_cache=SlidingWindowCache
549
+ --iop cls_cache=StaticCache
550
+ """.strip(
551
+ "\n"
552
+ )
553
+ ),
529
554
  action=_ParseDict,
530
555
  )
531
556
  parser.add_argument(
532
557
  "--mop",
533
558
  metavar="KEY=VALUE",
534
559
  nargs="*",
535
- help="Additional model options, use to change some parameters of the model, "
536
- "example:\n --mop attn_implementation=sdpa --mop attn_implementation=eager\n "
537
- "--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
560
+ help=textwrap.dedent(
561
+ """
562
+ Additional model options, used to change some parameters
563
+ of the model. Example:
564
+ --mop attn_implementation=sdpa --mop attn_implementation=eager"
565
+ --mop "rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}"
566
+ """.strip(
567
+ "\n"
568
+ )
569
+ ),
538
570
  action=_ParseDict,
539
571
  )
540
572
  if name == "validate":
@@ -566,9 +598,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
566
598
  parser.add_argument(
567
599
  "--quiet-input-sets",
568
600
  default="",
569
- help="Avoids raising an exception when an input sets does not work with "
570
- "the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
601
+ help=textwrap.dedent(
602
+ """
603
+ Avoids raising an exception when an input sets does not work with
604
+ the exported model. Example:
605
+ --quiet-input-sets=inputs,inputs22
606
+ """.strip(
607
+ "\n"
608
+ )
609
+ ),
571
610
  )
611
+ parser.add_argument(
612
+ "--expop",
613
+ metavar="KEY=VALUE",
614
+ nargs="*",
615
+ help=textwrap.dedent(
616
+ """
617
+ Additional exporter options, use to change some parameters
618
+ of the model. Examples:
619
+ --expop report=True
620
+ --expop report=True --expop verify=True
621
+ """.strip(
622
+ "\n"
623
+ )
624
+ ),
625
+ action=_ParseDict,
626
+ )
572
627
  return parser
573
628
 
574
629
 
@@ -634,6 +689,7 @@ def _cmd_validate(argv: List[Any]):
634
689
  output_names=(
635
690
  None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
636
691
  ),
692
+ exporter_options=args.expop,
637
693
  )
638
694
  print("")
639
695
  print("-- summary --")
@@ -940,7 +996,7 @@ def get_parser_agg() -> ArgumentParser:
940
996
  "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
941
997
  "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
942
998
  "n_node_gqa,n_node_layer_normalization,n_node_layer_normalization23,"
943
- "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
999
+ "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,n_node_random,"
944
1000
  "n_node_constant,n_node_shape,n_node_expand,"
945
1001
  "n_node_function,n_node_initializer,n_node_scatter,"
946
1002
  "time_export_unbiased,onnx_n_nodes_no_cst,n_node_initializer_small",
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Sequence, Optional, Tuple, Union
1
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
2
2
  import torch
3
3
 
4
4
 
@@ -14,6 +14,10 @@ def to_onnx(
14
14
  output_names: Optional[List[str]] = None,
15
15
  output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
16
16
  exporter: str = "onnx-dynamo",
17
+ exporter_kwargs: Optional[Dict[str, Any]] = None,
18
+ save_ep: Optional[str] = None,
19
+ optimize: bool = True,
20
+ use_control_flow_dispatcher: bool = False,
17
21
  ) -> Any:
18
22
  """
19
23
  Common API for exporters. By default, the models are optimized to use the
@@ -32,6 +36,11 @@ def to_onnx(
32
36
  :param output_names: to change the output of the onnx model
33
37
  :param output_dynamic_shapes: to overwrite the dynamic shapes names
34
38
  :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
39
+ :param exporter_kwargs: additional parameters sent to the exporter
40
+ :param save_ep: saves the exported program
41
+ :param optimize: optimizes the model
42
+ :param use_control_flow_dispatcher: use the dispatcher created to supported
43
+ custom loops (see :func:`onnx_diagnostic.export.control_flow.loop_for`)
35
44
  :return: the output of the selected exporter, usually a structure including
36
45
  an onnx model
37
46
 
@@ -48,9 +57,23 @@ def to_onnx(
48
57
  )
49
58
  """
50
59
  if exporter == "custom":
51
- from experimental_experiment.torch_interpreter import to_onnx as _to_onnx
60
+ from experimental_experiment.torch_interpreter import (
61
+ to_onnx as _to_onnx,
62
+ ExportOptions,
63
+ )
52
64
  from experimental_experiment.xbuilder import OptimizationOptions
53
65
 
66
+ if use_control_flow_dispatcher:
67
+ from .control_flow import create_global_dispatcher
68
+
69
+ dispatcher = create_global_dispatcher()
70
+
71
+ options = None
72
+ if exporter_kwargs is not None:
73
+ options = exporter_kwargs.pop("options", None)
74
+ if options is None:
75
+ options = OptimizationOptions(patterns="default+onnxruntime")
76
+
54
77
  return _to_onnx(
55
78
  mod,
56
79
  args=args,
@@ -63,7 +86,10 @@ def to_onnx(
63
86
  dynamic_shapes=dynamic_shapes,
64
87
  large_model=True,
65
88
  output_dynamic_shapes=output_dynamic_shapes,
66
- options=OptimizationOptions(patterns="default+onnxruntime"),
89
+ export_options=ExportOptions(save_ep=save_ep),
90
+ options=options,
91
+ **(exporter_kwargs or {}),
92
+ dispatcher=dispatcher if use_control_flow_dispatcher else None,
67
93
  )
68
94
  if exporter in ("dynamo", "onnx-dynamo"):
69
95
  import onnxscript.rewriter.ort_fusions as ort_fusions
@@ -80,9 +106,12 @@ def to_onnx(
80
106
  opset_version=target_opset,
81
107
  dynamic_shapes=dynamic_shapes,
82
108
  dynamo=True,
109
+ **(exporter_kwargs or {}),
83
110
  )
84
- ort_fusions.optimize_for_ort(epo.model)
85
- epo.save(filename)
111
+ if optimize:
112
+ ort_fusions.optimize_for_ort(epo.model)
113
+ if filename:
114
+ epo.save(filename, external_data=True)
86
115
  return epo
87
116
 
88
117
  if exporter == "modelbuilder":
@@ -117,6 +146,7 @@ def to_onnx(
117
146
  precision=str(first_float[0].dtype).split(".")[-1],
118
147
  execution_provider="cuda" if first.is_cuda else "cpu",
119
148
  cache_dir=os.path.dirname(filename),
149
+ **(exporter_kwargs or {}),
120
150
  )
121
151
  save_model_builder(onx, os.path.dirname(filename))
122
152
  return onx