onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,4 @@
1
+ from .evaluator import ExtendedReferenceEvaluator
2
+ from .ort_evaluator import OnnxruntimeEvaluator
3
+ from .torch_evaluator import TorchOnnxEvaluator
4
+ from .report_results_comparison import ReportResultComparison
@@ -0,0 +1,254 @@
1
+ from logging import getLogger
2
+ from typing import Any, Dict, List, Optional, Union
3
+ from onnx import FunctionProto, ModelProto, NodeProto, TypeProto
4
+ from onnx.defs import get_schema
5
+ from onnx.reference import ReferenceEvaluator
6
+ from onnx.reference.op_run import OpRun
7
+ from .ops.op_add_add_mul_mul import (
8
+ AddAdd,
9
+ AddMul,
10
+ AddSharedInput,
11
+ MulAdd,
12
+ MulMul,
13
+ MulSharedInput,
14
+ MulSub,
15
+ SubMul,
16
+ )
17
+ from .ops.op_attention import Attention
18
+ from .ops.op_average_pool_grad import AveragePoolGrad
19
+ from .ops.op_bias_softmax import BiasSoftmax
20
+ from .ops.op_cast_like import CastLike_15, CastLike_19
21
+ from .ops.op_complex import ComplexModule, ToComplex
22
+ from .ops.op_concat import Concat
23
+ from .ops.op_constant_of_shape import ConstantOfShape
24
+ from .ops.op_fused_matmul import FusedMatMul
25
+ from .ops.op_gather import Gather
26
+ from .ops.op_gather_elements import GatherElements
27
+ from .ops.op_gather_grad import GatherGrad
28
+ from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
29
+ from .ops.op_mul_sigmoid import MulSigmoid
30
+ from .ops.op_negxplus1 import NegXplus1
31
+ from .ops.op_qlinear_average_pool import QLinearAveragePool
32
+ from .ops.op_qlinear_conv import QLinearConv
33
+ from .ops.op_quick_gelu import QuickGelu
34
+ from .ops.op_replace_zero import ReplaceZero
35
+ from .ops.op_rotary import Rotary
36
+ from .ops.op_scan import Scan
37
+ from .ops.op_scatter_elements import ScatterElements
38
+ from .ops.op_scatternd_of_shape import MaskedScatterNDOfShape, ScatterNDOfShape
39
+ from .ops.op_simplified_layer_normalization import SimplifiedLayerNormalization
40
+ from .ops.op_skip_layer_normalization import SkipLayerNormalization
41
+ from .ops.op_slice import Slice_1, Slice_10
42
+ from .ops.op_transpose_cast import Transpose2DCastFP16, Transpose2DCastFP32
43
+ from .ops.op_tri_matrix import TriMatrix
44
+
45
+
46
+ logger = getLogger("onnx-diagnostic-eval")
47
+
48
+
49
class ExtendedReferenceEvaluator(ReferenceEvaluator):
    """
    Drop-in replacement for :class:`onnx.reference.ReferenceEvaluator` that
    registers custom python implementations. It makes it possible to evaluate
    scenarios a backend restricted to the official onnx operator set cannot
    handle, such as optimization patterns involving onnxruntime contrib
    operators.

    ::

        from onnx_diagnostic.reference import ExtendedReferenceEvaluator
        ref = ExtendedReferenceEvaluator(...)

    The class overloads or adds the following operators by default:

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.reference import ExtendedReferenceEvaluator

        pprint.pprint(ExtendedReferenceEvaluator.default_ops)
    """

    # Operators added on top of (or overriding) the standard reference ones.
    default_ops: List[type[OpRun]] = [
        AddAdd,
        AddMul,
        AddSharedInput,
        Attention,
        AveragePoolGrad,
        BiasSoftmax,
        Concat,
        CastLike_15,
        CastLike_19,
        ComplexModule,
        ConstantOfShape,
        FusedMatMul,
        Gather,
        GatherElements,
        GatherGrad,
        MaskedScatterNDOfShape,
        MemcpyFromHost,
        MemcpyToHost,
        MulAdd,
        MulMul,
        MulSharedInput,
        MulSigmoid,
        MulSub,
        NegXplus1,
        QLinearConv,
        QLinearAveragePool,
        QuickGelu,
        ReplaceZero,
        Rotary,
        Scan,
        ScatterElements,
        ScatterNDOfShape,
        SimplifiedLayerNormalization,
        SkipLayerNormalization,
        Slice_1,
        Slice_10,
        SubMul,
        ToComplex,
        Transpose2DCastFP16,
        Transpose2DCastFP32,
        TriMatrix,
    ]

    @staticmethod
    def filter_ops(proto, new_ops, opsets):
        """
        Resolves versioned operator classes (named ``<Op>_<version>``) against
        the model opsets: for every (domain, op) pair, keeps only the most
        recent version not above the model's opset, and rebuilds it under the
        unversioned name so the evaluator can pick it up.
        """
        if opsets is None and isinstance(proto, (ModelProto, FunctionProto)):
            opsets = {d.domain: d.version for d in proto.opset_import}
        latest = {}
        versioned = {}
        for op_cls in new_ops:
            name = op_cls.__name__
            if "_" not in name:
                continue
            parts = name.split("_")
            try:
                version = int(parts[-1])
            except ValueError:
                # the suffix is not a version number
                continue
            if opsets is not None and version > opsets.get(op_cls.op_domain, 1):
                # too recent for this model
                continue
            versioned[name] = op_cls
            key = op_cls.op_domain, "_".join(parts[:-1])
            if key not in latest or latest[key][0] < version:
                latest[key] = (version, op_cls)

        # keep the non-versioned classes as-is
        kept = [op_cls for op_cls in new_ops if op_cls.__name__ not in versioned]
        # re-expose each winning versioned class under its base name
        for (domain, base_name), (version, op_cls) in latest.items():
            attrs = {"domain": domain}
            if not hasattr(op_cls, "op_schema"):
                attrs["op_schema"] = get_schema(base_name, version, domain=op_cls.op_domain)
            kept.append(type(base_name, (op_cls,), attrs))
        return kept

    def __init__(
        self,
        proto: Any,
        opsets: Optional[Dict[str, int]] = None,
        functions: Optional[List[Union[ReferenceEvaluator, FunctionProto]]] = None,
        verbose: int = 0,
        new_ops: Optional[List[type[OpRun]]] = None,
        **kwargs,
    ):
        if new_ops is None:
            new_ops = ExtendedReferenceEvaluator.default_ops
        else:
            # do not mutate the caller's list
            new_ops = [*new_ops, *ExtendedReferenceEvaluator.default_ops]
        new_ops = ExtendedReferenceEvaluator.filter_ops(proto, new_ops, opsets)

        super().__init__(
            proto,
            opsets=opsets,
            functions=functions,
            verbose=verbose,
            new_ops=new_ops,
            **kwargs,
        )

    def _log(self, level: int, pattern: str, *args: List[Any]) -> None:
        # prints when verbose enough, otherwise defers to the module logger
        if level < self.verbose:
            print(pattern % tuple(self._log_arg(a) for a in args))
        else:
            logger.debug(pattern, *args)

    def run(self, *args, **kwargs):
        """See :meth:`onnx.reference.ReferenceEvaluator.run`."""
        if len(args) == 1 and isinstance(args[0], list):
            # a bare list of inputs: map them onto the declared input names
            return self.run(None, dict(zip(self.input_names, args[0])), **kwargs)
        if isinstance(self.proto_, FunctionProto):
            return self._run_function(*args, **kwargs)
        return ReferenceEvaluator.run(self, *args, **kwargs)

    def _load_impl(self, node: NodeProto, input_types: TypeProto | None = None) -> Any:
        res = super()._load_impl(node, input_types)
        assert (
            not hasattr(res, "op_domain") or res.op_domain == node.domain
        ), f"Domain mismatch {res.op_domain!r} != {node.domain} for node={node}"
        return res

    def _run_function(
        self,
        output_names,
        feed_inputs: Dict[str, Any],
        attributes: Optional[Dict[str, Any]] = None,
        intermediate: bool = False,
    ) -> Union[Dict[str, Any], List[Any]]:  # type: ignore
        """
        Executes a :class:`FunctionProto` node by node.

        Returns every intermediate result when ``intermediate`` is True,
        otherwise the list of requested (or declared) outputs.
        """
        if output_names is None:
            output_names = self.output_names

        # step 1: inputs and initializers
        values = {"": None}  # optional input
        values.update(self.rt_inits_)  # type: ignore[arg-type]
        values.update(feed_inputs)
        for k, v in self.rt_inits_.items():
            self._log(2, " +C %s: %s", k, v)  # type: ignore[arg-type]
        for k, v in feed_inputs.items():
            self._log(2, " +I %s: %s", k, v)  # type: ignore[arg-type]

        # step 2: execute nodes
        for node in self.rt_nodes_:
            self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
            for i in node.input:
                if i not in values:
                    raise RuntimeError(
                        f"Unable to find input {i!r} in known results {sorted(values)}, "
                        f"self.rt_inits_ has {sorted(self.rt_inits_)}, "
                        f"feed_inputs has {sorted(feed_inputs)}."
                    )
            node_inputs = [values[i] for i in node.input]
            linked_attributes = {}
            if node.has_linked_attribute and attributes:
                linked_attributes["linked_attributes"] = attributes
            if node.need_context():
                outputs = node.run(*node_inputs, context=values, **linked_attributes)
            else:
                outputs = node.run(*node_inputs, **linked_attributes)
            for name, value in zip(node.output, outputs):
                self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                values[name] = value

        # return the results
        if intermediate:
            return values

        for name in output_names:
            if name not in values:
                raise RuntimeError(
                    f"Unable to find output name {name!r} "
                    f"in {sorted(values)}, proto is\n{self.proto_}"
                )
        return [values[name] for name in output_names]
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,68 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+
4
+
5
class AddAdd(OpRun):
    """Fused element-wise addition of three tensors (onnx_extended contrib op)."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z):
        partial = x + y
        return (partial + z,)
10
+
11
+
12
class MulMul(OpRun):
    """Fused element-wise product of three tensors (onnx_extended contrib op)."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z):
        partial = x * y
        return (partial * z,)
17
+
18
+
19
class AddMul(OpRun):
    """Fused ``(x + y) * z`` with an optional transpose of axes 1 and 2."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, transposeMiddle=None):
        out = (x + y) * z
        if transposeMiddle:
            # swap the two middle axes of the 4-D result
            out = np.transpose(out, axes=[0, 2, 1, 3])
        return (out,)
27
+
28
+
29
class MulAdd(OpRun):
    """Fused ``(x * y) + z`` with an optional transpose of axes 1 and 2."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, transposeMiddle=None):
        out = (x * y) + z
        if transposeMiddle:
            # swap the two middle axes of the 4-D result
            out = np.transpose(out, axes=[0, 2, 1, 3])
        return (out,)
37
+
38
+
39
class SubMul(OpRun):
    """Fused ``(x - y) * z``; ``negative`` flips the subtraction order."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, negative=None):
        diff = (y - x) if negative else (x - y)
        return (diff * z,)
46
+
47
+
48
class MulSub(OpRun):
    """Fused ``(x * y) - z``; ``negative`` flips the subtraction order."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, negative=None):
        prod = x * y
        if negative:
            return (z - prod,)
        return (prod - z,)
55
+
56
+
57
class AddSharedInput(OpRun):
    """Two additions sharing the first operand: returns ``x + y`` and ``x + z``."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z):
        first, second = x + y, x + z
        return (first, second)
62
+
63
+
64
class MulSharedInput(OpRun):
    """Two products sharing the first operand: returns ``x * y`` and ``x * z``."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z):
        first, second = x * y, x * z
        return (first, second)
@@ -0,0 +1,60 @@
1
+ import numpy as np
2
+ import scipy.special as scipy_special
3
+ from onnx.reference.op_run import OpRun
4
+
5
+
6
class Attention(OpRun):
    """
    Partial implementation of the com.microsoft ``Attention`` contrib op:
    no ``past`` state, a boolean-style ``mask_index`` and a mandatory
    ``attention_bias`` whose second axis equals ``num_heads``.
    """

    op_domain = "com.microsoft"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fix a bug onnx.reference.ReferenceEvaluator
        self._schema = None
        self.attributes_names_ = ["num_heads"]

    def _run(
        self,
        x,
        weights,
        bias,
        mask_index,
        past,
        attention_bias,
        num_heads=None,
    ):
        assert past is None, f"Attention not implemented if past == {past!r}"
        assert (
            num_heads == attention_bias.shape[1]
        ), f"num_heads={num_heads} not in attention_bias.shape={attention_bias.shape}"

        # packed QKV projection weights: three equal slices along axis 1
        third = weights.shape[1] // 3
        w_q = weights[:, :third]
        w_k = weights[:, third : third * 2]
        w_v = weights[:, third * 2 :]

        # packed QKV biases: three equal slices
        third = bias.shape[0] // 3
        b_q = bias[:third]
        b_k = bias[third : third * 2]
        b_v = bias[third * 2 :]

        head_shape = (*x.shape[:2], num_heads, -1)

        # positions where mask_index == 0 are masked out
        masked = mask_index == 0
        q_heads = (x @ w_q + b_q).reshape(head_shape)
        k_heads = (x @ w_k + b_k).reshape(head_shape)
        v_heads = (x @ w_v + b_v).reshape(head_shape)
        q_t = np.transpose(q_heads, axes=(0, 2, 1, 3))
        k_t = np.transpose(k_heads, axes=(0, 2, 1, 3))
        v_t = np.transpose(v_heads, axes=(0, 2, 1, 3))
        # NOTE(review): 0.125 looks like a hard-coded 1/sqrt(head_dim) for
        # head_dim == 64 — confirm against the graphs this op is used in.
        scores = 0.125 * q_t @ np.transpose(k_t, [0, 1, 3, 2])
        scores = scores + attention_bias
        scores = np.where(masked, -np.inf, scores)
        probs = scipy_special.softmax(scores, axis=-1)
        probs = np.where(masked, 0, probs)
        context = probs @ v_t
        merged = np.transpose(context, axes=(0, 2, 1, 3))
        return (merged.reshape(x.shape),)
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+
4
+
5
class AveragePoolGrad(OpRun):
    """
    Gradient of AveragePool, implemented only for the 2-D case with
    ``auto_pad="NOTSET"``, ``ceil_mode=0`` and ``count_include_pad=1``.
    Every pooled output value is spread evenly over its input window.
    """

    def _run(
        self,
        out,
        auto_pad=None,
        ceil_mode=None,
        count_include_pad=None,
        kernel_shape=None,
        pads=None,
        strides=None,
    ):
        assert auto_pad is not None, "auto_pad is None"
        assert ceil_mode is not None, "ceil_mode is None"
        assert count_include_pad is not None, "count_include_pad is None"
        assert kernel_shape is not None, "kernel_shape is None"
        assert pads is not None, "pads is None"
        assert strides is not None, "strides is None"

        assert auto_pad == "NOTSET", f"Not implemented for autopad={auto_pad!r}"
        assert ceil_mode == 0, f"Not implemented for ceil_mode={ceil_mode!r}"
        assert (
            count_include_pad == 1
        ), f"Not implemented for count_include_pad={count_include_pad!r}"

        # Reconstruct the input (gradient) spatial sizes from the pooled
        # output. NOTE(review): ``out*stride + kernel - 1 + pads`` differs
        # from the usual inverse of the pooling shape formula — confirm
        # against the producer of these gradients.
        grad_shape = list(out.shape[:2])
        for i, kernel_dim in enumerate(kernel_shape):
            size = (
                out.shape[i + 2] * strides[i]
                + kernel_dim
                - 1
                + sum(pads[i * 2 : i * 2 + 2])
            )
            grad_shape.append(size)

        grad = np.zeros(tuple(grad_shape), dtype=out.dtype)
        # each output contributes 1/|window| to every cell of its window
        scale = (1.0 / np.prod(kernel_shape)).astype(out.dtype)
        if len(grad_shape) != 4:
            raise NotImplementedError(
                f"AveragePoolGrad is not implemented for shape={out.shape}."
            )

        # 2D
        for batch in range(grad.shape[0]):
            for channel in range(grad.shape[1]):
                for i in range(out.shape[2]):
                    top = max(i * strides[0] - pads[0], 0)
                    bottom = min(
                        i * strides[0] - pads[0] + kernel_shape[0], grad.shape[2]
                    )
                    for j in range(out.shape[3]):
                        left = max(j * strides[1] - pads[2], 0)
                        right = min(
                            j * strides[1] - pads[2] + kernel_shape[1],
                            grad.shape[3],
                        )
                        grad[batch, channel, top:bottom, left:right] += (
                            out[batch, channel, i, j] * scale
                        )

        return (grad.astype(out.dtype),)
@@ -0,0 +1,16 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+
4
+
5
class BiasSoftmax(OpRun):
    """com.microsoft ``BiasSoftmax``: softmax of ``x + y`` along ``axis``."""

    op_domain = "com.microsoft"

    def _run(self, x, y, axis=None, is_inner_broadcast=None):  # type: ignore
        assert (
            is_inner_broadcast == 0
        ), f"Not implemented for is_inner_broadcast={is_inner_broadcast}"
        biased = x + y
        # subtract the max for numerical stability before exponentiating
        shifted = biased - biased.max(axis=axis, keepdims=1)  # type: ignore
        expo = np.exp(shifted)
        expo /= expo.sum(axis=axis, keepdims=1)  # type: ignore
        return (expo.astype(x.dtype),)
@@ -0,0 +1,46 @@
1
from onnx.onnx_pb import TensorProto
from onnx.reference.op_run import OpRun

try:
    from onnx.reference.ops.op_cast import (
        bfloat16,
        cast_to,
        float8e4m3fn,
        float8e4m3fnuz,
        float8e5m2,
        float8e5m2fnuz,
    )
except ImportError:
    # older/newer onnx without the custom float types
    bfloat16 = None
    from onnx.reference.ops.op_cast import cast_to

# Fixed: this helper is needed by _cast_like on BOTH import paths (it is the
# fallback for every standard dtype), so importing it only inside the except
# branch raised NameError whenever the try branch succeeded.
from ...helpers.onnx_helper import np_dtype_to_tensor_dtype
17
+
18
+
19
def _cast_like(x, y, saturate):
    """Casts ``x`` to the element type of ``y`` (CastLike semantics)."""
    to = None
    if bfloat16 is not None:
        # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16,
        # so the descr name is checked to disambiguate the custom dtypes.
        for custom_dtype, descr_name, tensor_type in (
            (bfloat16, "bfloat16", TensorProto.BFLOAT16),
            (float8e4m3fn, "e4m3fn", TensorProto.FLOAT8E4M3FN),
            (float8e4m3fnuz, "e4m3fnuz", TensorProto.FLOAT8E4M3FNUZ),
            (float8e5m2, "e5m2", TensorProto.FLOAT8E5M2),
            (float8e5m2fnuz, "e5m2fnuz", TensorProto.FLOAT8E5M2FNUZ),
        ):
            if y.dtype == custom_dtype and y.dtype.descr[0][0] == descr_name:
                to = tensor_type
                break
    if to is None:
        to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
    return (cast_to(x, to, saturate),)
37
+
38
+
39
class CastLike_15(OpRun):
    """CastLike before opset 19: saturation is implicitly enabled."""

    def _run(self, x, y):  # type: ignore
        return _cast_like(x, y, True)
42
+
43
+
44
class CastLike_19(OpRun):
    """CastLike from opset 19: the ``saturate`` attribute is explicit."""

    def _run(self, x, y, saturate=None):  # type: ignore
        return _cast_like(x, y, saturate)
@@ -0,0 +1,26 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+
4
+
5
class ToComplex(OpRun):
    """Packs a real tensor ``(..., 1)`` or ``(..., 2)`` into a complex one."""

    op_domain = "ai.onnx.complex"

    def _run(self, x):
        assert x.shape[-1] in (
            1,
            2,
        ), f"Unexpected shape {x.shape}, it should a tensor (..., 2)"
        real = x[..., 0]
        if x.shape[-1] == 1:
            # no imaginary component provided
            return (real + 0j,)
        return (real + 1j * x[..., 1],)
16
+
17
+
18
class ComplexModule(OpRun):
    """Modulus (absolute value) of a complex tensor."""

    op_domain = "ai.onnx.complex"

    def _run(self, x):
        assert x.dtype in (
            np.complex64,
            np.complex128,
        ), f"Unexpected type {x.dtype}, it should a complex tensor"
        modulus = np.abs(x)
        return (modulus,)
@@ -0,0 +1,15 @@
1
+ import numpy as np
2
+
3
+ from onnx.reference.op_run import OpRun
4
+
5
+
6
class Concat(OpRun):
    """Concat that pads trailing unit axes when ``axis`` exceeds a rank."""

    def _preprocess(self, a: np.ndarray, axis: int) -> np.ndarray:
        # append unit dimensions until axis is a valid index
        rank = len(a.shape)  # type: ignore
        if axis < rank:
            return a
        padded_shape = a.shape + (1,) * (axis + 1 - rank)  # type: ignore
        return a.reshape(padded_shape)

    def _run(self, *args, axis=None):  # type: ignore
        pieces = [self._preprocess(a, axis) for a in args]
        return (np.concatenate(pieces, axis),)  # type: ignore
@@ -0,0 +1,67 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+
4
+ try:
5
+ import ml_dtypes
6
+ except ImportError:
7
+ ml_dtypes = None # type: ignore
8
+
9
+
10
class ConstantOfShape(OpRun):
    """ConstantOfShape with support for bfloat16 fill values (via ml_dtypes)."""

    @staticmethod
    def _process(value):
        """
        Normalizes the ``value`` attribute into a numpy scalar usable by
        :func:`numpy.full`. Raises ``ValueError`` for an empty array and
        ``TypeError`` for unsupported scalar types.
        """
        if (
            value is not None
            and ml_dtypes is not None
            and value.dtype == (np.uint16, [("bfloat16", "<u2")])
        ):
            # reinterpret the raw uint16 payload as bfloat16
            value = value.view(ml_dtypes.bfloat16)
        # Fixed: the previous code computed ``value[0]`` up front whenever
        # ``value.size > 0``, which raises IndexError for 0-d arrays and was
        # dead code otherwise (immediately recomputed below).
        cst = value
        if isinstance(value, np.ndarray):
            if not value.shape:
                # 0-d array: keep as-is
                cst = value
            elif value.size > 0:
                cst = value.ravel()[0]
            else:
                raise ValueError(f"Unexpected fill_value={value!r}")
        if isinstance(cst, bool):
            cst = np.bool_(cst)
        elif isinstance(cst, int):
            cst = np.int64(cst)
        elif isinstance(cst, float):
            cst = np.float64(cst)
        elif cst is None:
            # ONNX default fill value
            cst = np.float32(0)
        if ml_dtypes is not None and isinstance(cst, ml_dtypes.bfloat16):
            return cst
        if not isinstance(
            cst,
            (
                np.float16,
                np.float32,
                np.float64,
                np.int64,
                np.int32,
                np.int16,
                np.int8,
                np.uint64,
                np.uint32,
                np.uint16,
                np.uint8,
                np.bool_,
            ),
        ):
            raise TypeError(f"value must be a real not {type(cst)}")
        return cst

    def _run(self, data, value=None):
        cst = self._process(value)
        try:
            res = np.full(tuple(data), cst)
        except TypeError as e:
            raise RuntimeError(
                f"Unable to create a constant of shape "
                f"{data!r} with value {cst!r} "
                f"(raw value={value!r})."
            ) from e
        return (res,)
@@ -0,0 +1,31 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+
4
+
5
class FusedMatMul(OpRun):
    """com.microsoft ``FusedMatMul``: ``alpha * op(A) @ op(B)``."""

    op_domain = "com.microsoft"

    @staticmethod
    def _swap_last_two(m):
        # transpose the two innermost axes, leaving batch axes untouched
        order = list(range(len(m.shape)))
        order[-2], order[-1] = order[-1], order[-2]
        return np.transpose(m, order)

    def _run(
        self,
        A,
        B,
        alpha: float = 1,
        transA: int = 0,
        transB: int = 0,
        transBatchA: int = 0,
        transBatchB: int = 0,
    ):
        assert transBatchA == 0, f"Not implemented for transBatchA==1 and {A.shape}x{B.shape}"
        assert transBatchB == 0, f"Not implemented for transBatchB==1 and {A.shape}x{B.shape}"
        if transA:
            A = self._swap_last_two(A)
        if transB:
            B = self._swap_last_two(B)
        # cast alpha so the result keeps A's dtype
        scale = np.array(alpha, dtype=A.dtype)
        return (np.matmul(A, B) * scale,)
@@ -0,0 +1,29 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from __future__ import annotations
5
+
6
+ import numpy as np
7
+
8
+ from onnx.reference.op_run import OpRun
9
+
10
+
11
class Gather(OpRun):
    """Gather via :func:`numpy.take`, with explicit empty-input handling."""

    def _run(self, x, indices, axis=None):
        if x.size == 0 or indices.size == 0:
            # build the shape the output would have, then force it empty
            if axis is None:
                new_shape = indices.shape
            else:
                new_shape = (*x.shape[:axis], *indices.shape, *x.shape[axis + 1 :])
            if 0 not in new_shape:
                new_shape = (0, *new_shape[1:])
            return (np.empty(new_shape, dtype=x.dtype),)
        if not x.flags["C_CONTIGUOUS"]:
            x = np.ascontiguousarray(x)
        if not indices.flags["C_CONTIGUOUS"]:
            # Fixed: ndarray has no ``ascontiguousarray`` method; the previous
            # ``indices.ascontiguousarray()`` raised AttributeError whenever
            # indices was non-contiguous.
            indices = np.ascontiguousarray(indices)
        try:
            return (np.take(x, indices, axis=axis),)
        except TypeError:
            # distribution x86 requires int32.
            return (np.take(x, indices.astype(int), axis=axis),)