onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/api.py ADDED
@@ -0,0 +1,15 @@
1
+ from typing import Any
2
+
3
+
4
class TensorLike:
    """
    Minimal tensor-like interface used as a placeholder type.

    Concrete subclasses must provide :attr:`dtype` and :attr:`shape`;
    reading either property on this base class raises
    :class:`NotImplementedError`.
    """

    @property
    def dtype(self) -> Any:
        """Element type of the tensor; subclasses must override it."""
        message = "dtype must be overwritten."
        raise NotImplementedError(message)

    @property
    def shape(self) -> Any:
        """Dimensions of the tensor; subclasses must override it."""
        message = "shape must be overwritten."
        raise NotImplementedError(message)
onnx_diagnostic/doc.py ADDED
@@ -0,0 +1,100 @@
1
+ from typing import Optional
2
+ import numpy as np
3
+
4
+
5
def get_latest_pypi_version(package_name="onnx-diagnostic", timeout: float = 30) -> str:
    """
    Returns the latest version of a package published on PyPI.

    :param package_name: name of the package on PyPI
    :param timeout: timeout in seconds for the HTTP request; without one,
        ``requests.get`` may wait indefinitely on an unresponsive server
    :return: the value stored under ``info/version`` in the PyPI JSON API response
    :raises AssertionError: if the HTTP status code is not 200
    """
    import requests

    url = f"https://pypi.org/pypi/{package_name}/json"
    response = requests.get(url, timeout=timeout)

    assert response.status_code == 200, f"Unable to retrieve the version response={response}"
    data = response.json()
    return data["info"]["version"]
17
+
18
+
19
def update_version_package(version: str, package_name="onnx-diagnostic") -> str:
    """
    Compares ``version`` with the latest release published on PyPI.

    Only the first two dot-separated components (``major.minor``) are compared.
    If they match, ``version`` is returned unchanged; otherwise the string
    ``"<major>.<minor>.dev"`` built from ``version`` is returned.
    """
    latest = get_latest_pypi_version(package_name)

    def _major_minor(full: str) -> str:
        # keeps at most the first two components, e.g. "0.8.1" -> "0.8"
        return ".".join(full.split(".")[:2])

    latest_prefix = _major_minor(latest)
    prefix = _major_minor(version)
    if latest_prefix != prefix:
        return f"{prefix}.dev"
    return version
25
+
26
+
27
def reset_torch_transformers(gallery_conf, fname):
    """
    Resets state between examples; meant as a :epkg:`sphinx-gallery` callback.

    It selects the ``ggplot`` matplotlib style and calls ``torch._dynamo.reset()``.
    ``gallery_conf`` and ``fname`` are accepted to match the callback
    signature but are not used.
    """
    import matplotlib.pyplot as pyplot
    import torch

    pyplot.style.use("ggplot")
    torch._dynamo.reset()
34
+
35
+
36
def plot_legend(
    text: str, text_bottom: str = "", color: str = "green", fontsize: int = 15
) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Draws a small ``(2, 2)`` figure with ``END`` at the top and ``text``
    in a colored box in the middle.

    :param text: main text, drawn centered inside a box
    :param text_bottom: optional small text drawn near the bottom right corner
    :param color: face color of the box surrounding ``text``
    :param fontsize: font size used for ``text``
    :return: the axes holding the drawing (axis lines hidden)
    """
    import matplotlib.pyplot as plt

    ax = plt.figure(figsize=(2, 2)).add_subplot()
    ax.axis([0, 5, 0, 5])
    ax.text(2.5, 4, "END", fontsize=10, horizontalalignment="center")

    box_style = {"facecolor": color, "alpha": 0.5, "pad": 10}
    ax.text(
        2.5,
        2.5,
        text,
        fontsize=fontsize,
        bbox=box_style,
        horizontalalignment="center",
        verticalalignment="center",
    )

    if text_bottom:
        ax.text(4.5, 0.5, text_bottom, fontsize=7, horizontalalignment="right")

    ax.grid(False)
    ax.set_axis_off()
    return ax
59
+
60
+
61
def rotate_align(ax, angle=15, align="right"):
    """
    Rotates every x tick label of ``ax`` by ``angle`` and sets its
    horizontal alignment to ``align`` (``"right"`` by default). Returns ``ax``.
    """
    tick_labels = ax.get_xticklabels()
    for tick_label in tick_labels:
        tick_label.set_rotation(angle)
        tick_label.set_horizontalalignment(align)
    return ax
67
+
68
+
69
def save_fig(ax, name: str):
    """
    Applies ``tight_layout`` to the figure owning ``ax`` and saves that figure.

    :param ax: axes whose parent figure is laid out and saved
    :param name: destination passed to ``Figure.savefig``
    :return: ``ax``
    """
    fig = ax.get_figure()
    # Lay out the figure that is actually saved: ``plt.tight_layout()`` acts
    # on the *current* figure, which is not necessarily the one owning ``ax``.
    fig.tight_layout()
    fig.savefig(name)
    return ax
77
+
78
+
79
def title(ax: "plt.axes", title: str) -> "plt.axes":  # noqa: F821
    """Sets ``title`` as the title of ``ax``; returns the same axes for chaining."""
    target = ax
    target.set_title(title)
    return target
83
+
84
+
85
def plot_histogram(
    tensor: np.ndarray,
    ax: Optional["plt.axes"] = None,  # noqa: F821
    bins: int = 30,
    color: str = "orange",
    alpha: float = 0.7,
) -> "plt.axes":  # noqa: F821
    """
    Plots the histogram of the values of a tensor with a logarithmic y-axis.

    :param tensor: values to plot, passed as-is to ``ax.hist``
    :param ax: axes to draw on; if None, the current axes are retrieved
        with ``plt.gca()`` and cleared first
    :param bins: number of bins
    :param color: bar color
    :param alpha: bar transparency
    :return: the axes drawn on
    """
    if ax is None:
        import matplotlib.pyplot as plt

        ax = plt.gca()
        ax.cla()
    # Honor the caller-supplied histogram settings instead of fixed values.
    ax.hist(tensor, bins=bins, color=color, alpha=alpha)
    ax.set_yscale("log")
    return ax
@@ -0,0 +1,2 @@
1
+ from .dynamic_shapes import CoupleInputsDynamicShapes, ModelInputs
2
+ from .validate import validate_ep
@@ -0,0 +1,124 @@
1
+ from typing import Any, Dict, List, Sequence, Optional, Tuple, Union
2
+ import torch
3
+
4
+
5
def to_onnx(
    mod: Union["torch.nn.Module", "torch.fx.GraphModule"],  # noqa: F821
    args: Optional[Sequence["torch.Tensor"]] = None,  # noqa: F821
    kwargs: Optional[Dict[str, "torch.Tensor"]] = None,  # noqa: F821
    input_names: Optional[Sequence[str]] = None,
    target_opset: Optional[Union[int, Dict[str, int]]] = None,
    verbose: int = 0,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    filename: Optional[str] = None,
    output_names: Optional[List[str]] = None,
    output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    exporter: str = "onnx-dynamo",
) -> Any:
    """
    Common API for exporters. By default, the models are optimized to use the
    most efficient kernels implemented in :epkg:`onnxruntime`.

    :param mod: torch model
    :param args: unnamed arguments
    :param kwargs: named arguments
    :param input_names: input names for the onnx model (optional)
    :param target_opset: opset to target, if not specified, each converter
        keeps its default value
    :param verbose: verbosity level (only forwarded to the ``custom`` exporter)
    :param dynamic_shapes: dynamic shapes, usually a nested structure
        including a dictionary for each tensor
    :param filename: output filename; mandatory for ``modelbuilder``,
        optional for ``onnx-dynamo`` (the model is saved only if given)
    :param output_names: to change the output of the onnx model
    :param output_dynamic_shapes: to overwrite the dynamic shapes names
        (only supported by the ``custom`` exporter)
    :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
    :return: the output of the selected exporter, usually a structure including
        an onnx model
    :raises AssertionError: if the arguments are not supported by the exporter
    :raises ValueError: if the exporter is unknown

    A simple example:

    .. code-block:: python

        to_onnx(
            model,
            kwargs=inputs,
            dynamic_shapes=ds,
            exporter=exporter,
            filename=filename,
        )
    """
    if exporter == "custom":
        from experimental_experiment.torch_interpreter import to_onnx as _to_onnx
        from experimental_experiment.xbuilder import OptimizationOptions

        return _to_onnx(
            mod,
            args=args,
            kwargs=kwargs,
            input_names=input_names,
            output_names=output_names,
            target_opset=target_opset,
            verbose=verbose,
            filename=filename,
            dynamic_shapes=dynamic_shapes,
            large_model=True,
            output_dynamic_shapes=output_dynamic_shapes,
            options=OptimizationOptions(patterns="default+onnxruntime"),
        )

    if exporter in ("dynamo", "onnx-dynamo"):
        # Validate before importing onnxscript so misuse fails fast.
        assert (
            not output_dynamic_shapes
        ), f"output_dynamic_shapes not supported for exporter={exporter!r}"

        import onnxscript.rewriter.ort_fusions as ort_fusions

        epo = torch.onnx.export(
            mod,
            args=args or tuple(),
            kwargs=kwargs,
            input_names=input_names,
            output_names=output_names,
            opset_version=target_opset,
            dynamic_shapes=dynamic_shapes,
            dynamo=True,
        )
        ort_fusions.optimize_for_ort(epo.model)
        if filename:
            # ``filename`` defaults to None: saving needs a destination,
            # otherwise the program is only returned to the caller.
            epo.save(filename)
        return epo

    if exporter == "modelbuilder":
        # Validate the arguments before importing the model builder helpers.
        assert filename, f"filename must be specified for exporter={exporter!r}"
        assert (
            not output_dynamic_shapes
        ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
        assert hasattr(mod, "config"), f"configuration is missing in model class {type(mod)}"
        assert not args, f"only kwargs can be defined with exporter={exporter!r}"
        # The None check avoids a TypeError from list(None) when kwargs is missing.
        assert kwargs is not None and list(kwargs) == [
            "input_ids",
            "attention_mask",
            "past_key_values",
        ], (
            f"Only a specified set of inputs is supported for exporter={exporter!r}, "
            f"but it is {list(kwargs) if kwargs is not None else None}"
        )

        import os
        from ..helpers import flatten_object, string_type
        from ..helpers.model_builder_helper import create_model_builder, save_model_builder

        flat_inputs = flatten_object(kwargs, drop_keys=True)
        first = flat_inputs[0]
        # The precision given to the model builder is taken from the first
        # floating point input.
        first_float = [
            t
            for t in flat_inputs
            if t.dtype in {torch.float32, torch.double, torch.float16, torch.bfloat16}
        ]
        assert first_float, (
            f"Unable to find a float tensor in the inputs "
            f"{string_type(kwargs, with_shape=True)}"
        )
        onx = create_model_builder(
            mod.config,
            mod,
            precision=str(first_float[0].dtype).split(".")[-1],
            execution_provider="cuda" if first.is_cuda else "cpu",
            cache_dir=os.path.dirname(filename),
        )
        save_model_builder(onx, os.path.dirname(filename))
        return onx

    raise ValueError(f"Unknown exporter={exporter!r}")