onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,184 @@
1
+ import ast
2
+ import functools
3
+ from typing import Any, Dict, List, Optional
4
+
5
+
6
class OrToBitOrTransformer(ast.NodeTransformer):
    """
    Rewrites every ``or`` boolean operation into a chain of ``|`` operations.

    ``a or b or c`` becomes ``(a | b) | c``. Unlike ``or``, the operator ``|``
    does not short-circuit: every operand is evaluated. Other boolean
    operators such as ``and`` are left untouched.
    """

    def visit_BoolOp(self, node):
        # Children are transformed first so that nested ``or`` are rewritten too.
        self.generic_visit(node)
        if not isinstance(node.op, ast.Or):
            return node
        new_node = node.values[0]
        for value in node.values[1:]:
            # Every intermediate BinOp receives the location of the original
            # expression; otherwise compile() fails with a missing ``lineno``
            # as soon as the expression has more than two operands.
            new_node = ast.copy_location(
                ast.BinOp(left=new_node, op=ast.BitOr(), right=value), node
            )
        return new_node
15
+
16
+
17
def ast_or_into_bitor(node: ast.AST) -> ast.AST:
    """
    Replaces every operator ``or`` into ``|`` in the tree rooted at *node*.

    :param node: any AST node; its children may be modified inplace
        (:class:`ast.NodeTransformer` replaces fields inplace)
    :return: the transformed tree, a new object if *node* itself
        is an ``or`` expression
    """
    new_node = OrToBitOrTransformer().visit(node)
    # The transformer creates new BinOp nodes: make sure every node carries
    # a source location so that the result can be given to compile().
    return ast.fix_missing_locations(new_node)
21
+
22
+
23
@functools.lru_cache
def _rewrite_forward_clamp_float16() -> Dict[str, List[type]]:
    """
    Maps every alias to the :epkg:`transformers` encoder-layer classes whose
    method ``forward`` is rewritten by
    :func:`rewritings_transformers_clamp_float16`.

    The result is cached: every call returns the same dictionary.
    """
    import transformers

    # alias -> dotted paths of the classes, relative to ``transformers.models``
    dotted_paths = {
        "AutoformerEncoderLayer": ["autoformer.modeling_autoformer.AutoformerEncoderLayer"],
        "BartEncoderLayer": [
            "bart.modeling_bart.BartEncoderLayer",
            "plbart.modeling_plbart.PLBartEncoderLayer",
        ],
        "BigBirdPegasusEncoderLayer": [
            "bigbird_pegasus.modeling_bigbird_pegasus.BigBirdPegasusEncoderLayer"
        ],
        "BlenderbotSmallEncoderLayer": [
            "blenderbot_small.modeling_blenderbot_small.BlenderbotSmallEncoderLayer"
        ],
        "InformerEncoderLayer": ["informer.modeling_informer.InformerEncoderLayer"],
        "LEDEncoderLayer": ["led.modeling_led.LEDEncoderLayer"],
        "MarianEncoderLayer": ["marian.modeling_marian.MarianEncoderLayer"],
        "MvpEncoderLayer": ["mvp.modeling_mvp.MvpEncoderLayer"],
        "NllbMoeEncoderLayer": ["nllb_moe.modeling_nllb_moe.NllbMoeEncoderLayer"],
        "TimeSeriesTransformerEncoderLayer": [
            "time_series_transformer.modeling_time_series_transformer."
            "TimeSeriesTransformerEncoderLayer"
        ],
    }

    def resolve(path: str) -> type:
        # Same as writing ``transformers.models.<path>`` attribute by attribute.
        return functools.reduce(getattr, path.split("."), transformers.models)

    return {alias: [resolve(p) for p in paths] for alias, paths in dotted_paths.items()}
56
+
57
+
58
@functools.lru_cache
def known_transformers_rewritings_clamp_float16() -> Dict[str, str]:
    """
    Returns the known class names of :epkg:`transformers` whose code must be
    rewritten, each class name mapped to an alias. This alias is then given to
    :func:`rewritings_transformers_clamp_float16` to rewrite the encoder layers
    because of a specific control flow. Every alias is also mapped to itself.

    The dictionary is cached: every call returns the same object.

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.torch_export_patches.patch_module_helper import (
            known_transformers_rewritings_clamp_float16,
        )

        pprint.pprint(known_transformers_rewritings_clamp_float16())
    """
    _alias = {
        "AutoformerEncoder": "AutoformerEncoderLayer",
        "AutoformerEncoderLayer": "AutoformerEncoderLayer",
        "AutoformerForPrediction": "AutoformerEncoderLayer",
        "AutoformerModel": "AutoformerEncoderLayer",
        "BartEncoderLayer": "BartEncoderLayer",
        "BartForConditionalGeneration": "BartEncoderLayer",
        "BartModel": "BartEncoderLayer",
        # the encoder layer maps to itself like every other alias
        "BigBirdPegasusEncoderLayer": "BigBirdPegasusEncoderLayer",
        "BigBirdPegasusForConditionalGeneration": "BigBirdPegasusEncoderLayer",
        "BigBirdPegasusForQuestionAnswering": "BigBirdPegasusEncoderLayer",
        "BigBirdPegasusForCausalLM": "BigBirdPegasusEncoderLayer",
        "BlenderbotSmallEncoderLayer": "BlenderbotSmallEncoderLayer",
        "BlenderbotSmallForConditionalGeneration": "BlenderbotSmallEncoderLayer",
        "BlenderbotSmallForCausalLM": "BlenderbotSmallEncoderLayer",
        "InformerEncoderLayer": "InformerEncoderLayer",
        "InformerForPrediction": "InformerEncoderLayer",
        "LEDEncoderLayer": "LEDEncoderLayer",
        "LEDClassificationHead": "LEDEncoderLayer",
        "LEDForConditionalGeneration": "LEDEncoderLayer",
        "MarianEncoderLayer": "MarianEncoderLayer",
        "MarianEncoder": "MarianEncoderLayer",
        "MarianModel": "MarianEncoderLayer",
        "MarianMTModel": "MarianEncoderLayer",
        "MvpEncoderLayer": "MvpEncoderLayer",
        "MvpPrompt": "MvpEncoderLayer",
        "MvpForConditionalGeneration": "MvpEncoderLayer",
        "MvpForSequenceClassification": "MvpEncoderLayer",
        "MvpForQuestionAnswering": "MvpEncoderLayer",
        "MvpForCausalLM": "MvpEncoderLayer",
        "NllbMoeEncoderLayer": "NllbMoeEncoderLayer",
        "NllbMoeForConditionalGeneration": "NllbMoeEncoderLayer",
        "PLBartEncoderLayer": "BartEncoderLayer",
        "PLBartForConditionalGeneration": "BartEncoderLayer",
        "TimeSeriesTransformerEncoderLayer": "TimeSeriesTransformerEncoderLayer",
        "TimeSeriesTransformerForPrediction": "TimeSeriesTransformerEncoderLayer",
    }
    return _alias
113
+
114
+
115
def rewritings_transformers_clamp_float16(cls_name: str) -> List[Dict[str, Any]]:
    """
    Returns the rewritings needed for known control flows equal to this:

    .. code-block:: python

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    *cls_name* is an alias. It is mapped to a list of classes whose method
    ``forward`` must be rewritten. Here is the known list:

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.torch_export_patches.patch_module_helper import (
            _rewrite_forward_clamp_float16,
        )

        pprint.pprint(_rewrite_forward_clamp_float16())

    Function `_rewrite_forward_clamp_float16` lists, for every alias,
    the layer classes to rewrite.

    :param cls_name: alias, one of the keys returned by
        ``_rewrite_forward_clamp_float16()``
    :return: one dictionary per class with keys ``filter_node`` (selects the
        ``if`` statements whose test is not a single name), ``pre_rewriter``
        (:func:`ast_or_into_bitor`, presumably applied by the caller before
        the rewriting) and ``function`` (the method ``forward`` of the class)
    :raises AssertionError: if *cls_name* is unknown
    """
    _known = _rewrite_forward_clamp_float16()

    # Explicit raise instead of ``assert`` so that the check survives ``python -O``
    # while keeping the same exception type for callers.
    if cls_name not in _known:
        raise AssertionError(f"cls_name={cls_name!r} unknown in {sorted(_known)}.")

    def _filter_node(node):
        # Only ``if`` statements with a test more complex than a single name.
        return isinstance(node, ast.If) and not isinstance(node.test, ast.Name)

    return [
        dict(filter_node=_filter_node, pre_rewriter=ast_or_into_bitor, function=cls.forward)
        for cls in _known[cls_name]
    ]
160
+
161
+
162
def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
    """
    Returns the rewritings a class needs because of its control flow,
    or None if the class is not known to need any.
    See :func:`known_transformers_rewritings_clamp_float16`.

    :param cls_name: name of the class
    :return: a list of rewritings or None

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.torch_export_patches.patch_module_helper import (
            code_needing_rewriting,
        )

        pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
    """
    # alias values are always strings, so None only means "unknown class"
    alias = known_transformers_rewritings_clamp_float16().get(cls_name)
    if alias is None:
        return None
    return rewritings_transformers_clamp_float16(alias)