onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
from logging import getLogger
|
|
2
|
+
from typing import Any, Dict, List, Optional, Union
|
|
3
|
+
from onnx import FunctionProto, ModelProto, NodeProto, TypeProto
|
|
4
|
+
from onnx.defs import get_schema
|
|
5
|
+
from onnx.reference import ReferenceEvaluator
|
|
6
|
+
from onnx.reference.op_run import OpRun
|
|
7
|
+
from .ops.op_add_add_mul_mul import (
|
|
8
|
+
AddAdd,
|
|
9
|
+
AddMul,
|
|
10
|
+
AddSharedInput,
|
|
11
|
+
MulAdd,
|
|
12
|
+
MulMul,
|
|
13
|
+
MulSharedInput,
|
|
14
|
+
MulSub,
|
|
15
|
+
SubMul,
|
|
16
|
+
)
|
|
17
|
+
from .ops.op_attention import Attention
|
|
18
|
+
from .ops.op_average_pool_grad import AveragePoolGrad
|
|
19
|
+
from .ops.op_bias_softmax import BiasSoftmax
|
|
20
|
+
from .ops.op_cast_like import CastLike_15, CastLike_19
|
|
21
|
+
from .ops.op_complex import ComplexModule, ToComplex
|
|
22
|
+
from .ops.op_concat import Concat
|
|
23
|
+
from .ops.op_constant_of_shape import ConstantOfShape
|
|
24
|
+
from .ops.op_fused_matmul import FusedMatMul
|
|
25
|
+
from .ops.op_gather import Gather
|
|
26
|
+
from .ops.op_gather_elements import GatherElements
|
|
27
|
+
from .ops.op_gather_grad import GatherGrad
|
|
28
|
+
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
|
|
29
|
+
from .ops.op_mul_sigmoid import MulSigmoid
|
|
30
|
+
from .ops.op_negxplus1 import NegXplus1
|
|
31
|
+
from .ops.op_qlinear_average_pool import QLinearAveragePool
|
|
32
|
+
from .ops.op_qlinear_conv import QLinearConv
|
|
33
|
+
from .ops.op_quick_gelu import QuickGelu
|
|
34
|
+
from .ops.op_replace_zero import ReplaceZero
|
|
35
|
+
from .ops.op_rotary import Rotary
|
|
36
|
+
from .ops.op_scan import Scan
|
|
37
|
+
from .ops.op_scatter_elements import ScatterElements
|
|
38
|
+
from .ops.op_scatternd_of_shape import MaskedScatterNDOfShape, ScatterNDOfShape
|
|
39
|
+
from .ops.op_simplified_layer_normalization import SimplifiedLayerNormalization
|
|
40
|
+
from .ops.op_skip_layer_normalization import SkipLayerNormalization
|
|
41
|
+
from .ops.op_slice import Slice_1, Slice_10
|
|
42
|
+
from .ops.op_transpose_cast import Transpose2DCastFP16, Transpose2DCastFP32
|
|
43
|
+
from .ops.op_tri_matrix import TriMatrix
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
logger = getLogger("onnx-diagnostic-eval")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ExtendedReferenceEvaluator(ReferenceEvaluator):
    """
    This class replaces the python implementation by custom implementation.
    The evaluator allows to test
    scenarios outside what an onnx backend bound to the official onnx
    operators definition could do such as optimization patterns
    involving onnxruntime contrib operators.

    ::

        from onnx_diagnostic.reference import ExtendedReferenceEvaluator
        ref = ExtendedReferenceEvaluator(...)

    The class overloads or adds the following operators by default:

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.reference import ExtendedReferenceEvaluator

        pprint.pprint(ExtendedReferenceEvaluator.default_ops)
    """

    # Custom/contrib operator implementations registered on top of the
    # standard onnx reference implementations.  Versioned classes
    # (``CastLike_15``, ``Slice_10``, ...) are resolved by ``filter_ops``.
    default_ops: List[type[OpRun]] = [
        AddAdd,
        AddMul,
        AddSharedInput,
        Attention,
        AveragePoolGrad,
        BiasSoftmax,
        Concat,
        CastLike_15,
        CastLike_19,
        ComplexModule,
        ConstantOfShape,
        FusedMatMul,
        Gather,
        GatherElements,
        GatherGrad,
        MaskedScatterNDOfShape,
        MemcpyFromHost,
        MemcpyToHost,
        MulAdd,
        MulMul,
        MulSharedInput,
        MulSigmoid,
        MulSub,
        NegXplus1,
        QLinearConv,
        QLinearAveragePool,
        QuickGelu,
        ReplaceZero,
        Rotary,
        Scan,
        ScatterElements,
        ScatterNDOfShape,
        SimplifiedLayerNormalization,
        SkipLayerNormalization,
        Slice_1,
        Slice_10,
        SubMul,
        ToComplex,
        Transpose2DCastFP16,
        Transpose2DCastFP32,
        TriMatrix,
    ]

    @staticmethod
    def filter_ops(proto, new_ops, opsets):
        """Resolve versioned operator classes against the model opsets.

        Classes whose name ends with ``_<int>`` (e.g. ``Slice_10``) are
        version-specific.  For every (domain, base name) pair, the highest
        version not exceeding the requested opset is kept and re-exposed
        under the unversioned name; all other classes pass through
        unchanged.
        """
        if opsets is None and isinstance(proto, (ModelProto, FunctionProto)):
            # Derive the opsets from the proto when not given explicitly.
            opsets = {d.domain: d.version for d in proto.opset_import}
        best = {}  # (domain, base name) -> (version, class)
        renamed = {}  # versioned class name -> class
        for cl in new_ops:
            if "_" not in cl.__name__:
                continue
            vers = cl.__name__.split("_")
            try:
                v = int(vers[-1])
            except ValueError:
                # not a version
                continue
            if opsets is not None and v > opsets.get(cl.op_domain, 1):
                # This version is newer than the model's opset; skip it.
                continue
            renamed[cl.__name__] = cl
            key = cl.op_domain, "_".join(vers[:-1])
            if key not in best or best[key][0] < v:
                best[key] = (v, cl)

        modified = []
        for cl in new_ops:
            if cl.__name__ not in renamed:
                modified.append(cl)
        for k, v in best.items():
            # Re-create the winning class under its unversioned name so the
            # evaluator dispatches on the plain operator name.
            atts = {"domain": k[0]}
            bases = (v[1],)
            if not hasattr(v[1], "op_schema"):
                atts["op_schema"] = get_schema(k[1], v[0], domain=v[1].op_domain)
            new_cl = type(k[1], bases, atts)
            modified.append(new_cl)

        new_ops = modified
        return new_ops

    def __init__(
        self,
        proto: Any,
        opsets: Optional[Dict[str, int]] = None,
        functions: Optional[List[Union[ReferenceEvaluator, FunctionProto]]] = None,
        verbose: int = 0,
        new_ops: Optional[List[type[OpRun]]] = None,
        **kwargs,
    ):
        """Build the evaluator, merging user ``new_ops`` with ``default_ops``.

        :param proto: model, function or node proto to evaluate
        :param opsets: opsets to assume, inferred from *proto* when ``None``
        :param functions: local functions, forwarded to the base class
        :param verbose: verbosity level (see :class:`ReferenceEvaluator`)
        :param new_ops: additional operator implementations; the default
            ones are always appended after the user's so user classes win
        :param kwargs: forwarded to :class:`ReferenceEvaluator`
        """
        if new_ops is None:
            new_ops = ExtendedReferenceEvaluator.default_ops
        else:
            # Copy so the caller's list is not mutated.
            new_ops = new_ops.copy()
            new_ops.extend(ExtendedReferenceEvaluator.default_ops)
        new_ops = ExtendedReferenceEvaluator.filter_ops(proto, new_ops, opsets)

        ReferenceEvaluator.__init__(
            self,
            proto,
            opsets=opsets,
            functions=functions,
            verbose=verbose,
            new_ops=new_ops,
            **kwargs,
        )

    def _log(self, level: int, pattern: str, *args: List[Any]) -> None:
        """Print when verbose enough, otherwise delegate to the module logger."""
        if level < self.verbose:
            new_args = [self._log_arg(a) for a in args]
            print(pattern % tuple(new_args))
        else:
            # Lazy %-formatting: args are only rendered if the level is enabled.
            logger.debug(pattern, *args)

    def run(self, *args, **kwargs):
        """See :meth:`onnx.reference.ReferenceEvaluator.run`."""
        if len(args) == 1 and isinstance(args[0], list):
            # Convenience: a bare list of values is zipped with input names.
            feeds = dict(zip(self.input_names, args[0]))
            return self.run(None, feeds, **kwargs)
        if isinstance(self.proto_, FunctionProto):
            # The base class does not support running a FunctionProto directly.
            return self._run_function(*args, **kwargs)
        return ReferenceEvaluator.run(self, *args, **kwargs)

    def _load_impl(self, node: NodeProto, input_types: TypeProto | None = None) -> Any:
        """Load the implementation for *node*, checking domain consistency."""
        res = super()._load_impl(node, input_types)
        assert (
            not hasattr(res, "op_domain") or res.op_domain == node.domain
        ), f"Domain mismatch {res.op_domain!r} != {node.domain} for node={node}"
        return res

    def _run_function(
        self,
        output_names,
        feed_inputs: Dict[str, Any],
        attributes: Optional[Dict[str, Any]] = None,
        intermediate: bool = False,
    ) -> Union[Dict[str, Any], List[Any]]:  # type: ignore
        """Execute ``self.proto_`` when it is a FunctionProto.

        :param output_names: names to return, defaults to the function outputs
        :param feed_inputs: input name -> value
        :param attributes: values for the function's linked attributes
        :param intermediate: when True, return every intermediate result
            as a dictionary instead of the requested outputs
        :return: list of outputs, or dict of all results when *intermediate*
        :raises RuntimeError: when a node input or a requested output is missing
        """
        if output_names is None:
            output_names = self.output_names

        # step 1: inputs and initializers
        results = {"": None}  # optional input
        results.update(self.rt_inits_)  # type: ignore[arg-type]
        results.update(feed_inputs)
        for k, v in self.rt_inits_.items():
            self._log(2, " +C %s: %s", k, v)  # type: ignore[arg-type]
        for k, v in feed_inputs.items():
            self._log(2, " +I %s: %s", k, v)  # type: ignore[arg-type]

        # step 2: execute nodes
        for node in self.rt_nodes_:
            self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
            for i in node.input:
                if i not in results:
                    raise RuntimeError(
                        f"Unable to find input {i!r} in known results {sorted(results)}, "
                        f"self.rt_inits_ has {sorted(self.rt_inits_)}, "
                        f"feed_inputs has {sorted(feed_inputs)}."
                    )
            inputs = [results[i] for i in node.input]
            linked_attributes = {}
            if node.has_linked_attribute and attributes:
                linked_attributes["linked_attributes"] = attributes
            if node.need_context():
                # Operators such as control-flow ops need access to all results.
                outputs = node.run(*inputs, context=results, **linked_attributes)
            else:
                outputs = node.run(*inputs, **linked_attributes)
            for name, value in zip(node.output, outputs):
                self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                results[name] = value

        # return the results
        if intermediate:
            return results

        for name in output_names:
            if name not in results:
                raise RuntimeError(
                    f"Unable to find output name {name!r} "
                    f"in {sorted(results)}, proto is\n{self.proto_}"
                )
        return [results[name] for name in output_names]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class AddAdd(OpRun):
    """Fused three-way element-wise addition ``x + y + z``."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z):
        acc = x + y
        acc = acc + z
        return (acc,)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MulMul(OpRun):
    """Fused three-way element-wise product ``x * y * z``."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z):
        acc = x * y
        acc = acc * z
        return (acc,)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AddMul(OpRun):
    """Fused ``(x + y) * z`` with an optional transpose of axes 1 and 2."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, transposeMiddle=None):
        out = (x + y) * z
        if transposeMiddle:
            # Swap the two middle axes of the 4-D result.
            out = np.transpose(out, axes=(0, 2, 1, 3))
        return (out,)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MulAdd(OpRun):
    """Fused ``(x * y) + z`` with an optional transpose of axes 1 and 2."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, transposeMiddle=None):
        out = (x * y) + z
        if transposeMiddle:
            # Swap the two middle axes of the 4-D result.
            out = np.transpose(out, axes=(0, 2, 1, 3))
        return (out,)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class SubMul(OpRun):
    """Fused ``(x - y) * z``; ``negative`` flips the subtraction order."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, negative=None):
        diff = (y - x) if negative else (x - y)
        return (diff * z,)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class MulSub(OpRun):
    """Fused ``(x * y) - z``; ``negative`` flips the subtraction order."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, negative=None):
        prod = x * y
        if negative:
            return (z - prod,)
        return (prod - z,)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class AddSharedInput(OpRun):
    """Two additions sharing the first operand: returns ``x + y`` and ``x + z``."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z):
        first = x + y
        second = x + z
        return (first, second)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class MulSharedInput(OpRun):
    """Two products sharing the first operand: returns ``x * y`` and ``x * z``."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z):
        first = x * y
        second = x * z
        return (first, second)
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import scipy.special as scipy_special
|
|
3
|
+
from onnx.reference.op_run import OpRun
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Attention(OpRun):
    """Reference implementation of the ``com.microsoft.Attention`` contrib op.

    Only the no-past, additive-attention-bias configuration traced below is
    supported; everything else is rejected with an assertion.
    """

    op_domain = "com.microsoft"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fix a bug onnx.reference.ReferenceEvaluator
        self._schema = None
        self.attributes_names_ = ["num_heads"]

    def _run(
        self,
        x,
        weights,
        bias,
        mask_index,
        past,
        attention_bias,
        num_heads=None,
    ):
        # x: input hidden states; the reshape below assumes a 3-D
        # (batch, sequence, hidden) layout — TODO confirm against callers.
        assert past is None, f"Attention not implemented if past == {past!r}"
        assert (
            num_heads == attention_bias.shape[1]
        ), f"num_heads={num_heads} not in attention_bias.shape={attention_bias.shape}"
        # weights packs Q, K and V projections along axis 1; split in thirds.
        d = weights.shape[1] // 3
        q_weights = weights[:, :d]
        k_weights = weights[:, d : d * 2]
        v_weights = weights[:, d * 2 :]

        # bias packs the matching Q, K and V biases the same way.
        d = bias.shape[0] // 3
        q_bias = bias[:d]
        k_bias = bias[d : d * 2]
        v_bias = bias[d * 2 :]

        # (batch, sequence, num_heads, head_size) with head_size inferred.
        shape_4d = (*x.shape[:2], num_heads, -1)

        # nodes
        # Positions where mask_index == 0 are masked out.
        mask_applied = mask_index == 0
        xqb = x @ q_weights + q_bias
        xqb_4d = xqb.reshape(shape_4d)
        xkb = x @ k_weights + k_bias
        xkb_4d = xkb.reshape(shape_4d)
        xvb = x @ v_weights + v_bias
        xvb_4d = xvb.reshape(shape_4d)
        # Move heads in front of the sequence axis: (batch, heads, seq, head).
        rot_xqb = np.transpose(xqb_4d, axes=(0, 2, 1, 3))
        rot_xkb = np.transpose(xkb_4d, axes=(0, 2, 1, 3))
        # NOTE(review): 0.125 hard-codes the 1/sqrt(head_size) scaling,
        # i.e. it presumably assumes head_size == 64 — confirm.
        matmul = 0.125 * rot_xqb @ np.transpose(rot_xkb, [0, 1, 3, 2])
        transpose_3 = np.transpose(xvb_4d, axes=(0, 2, 1, 3))
        add_322 = matmul + attention_bias
        # Masked logits are set to -inf so softmax assigns them zero weight.
        masked_fill_2 = np.where(mask_applied, -np.inf, add_322)
        softmax = scipy_special.softmax(masked_fill_2, axis=-1)
        # Re-zero masked positions (softmax over all -inf rows yields NaN).
        masked_fill_3 = np.where(mask_applied, 0, softmax)
        matmul_1 = masked_fill_3 @ transpose_3
        transpose_5 = np.transpose(matmul_1, axes=(0, 2, 1, 3))
        # Merge heads back and restore the input shape.
        view_3 = transpose_5.reshape(x.shape)
        return (view_3,)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class AveragePoolGrad(OpRun):
    """Gradient of 2-D AveragePool.

    Scatters each output-gradient element, scaled by ``1 / prod(kernel_shape)``,
    back over the input window that produced it.  Only NOTSET padding,
    ``ceil_mode == 0`` and ``count_include_pad == 1`` are supported.
    """

    def _run(
        self,
        out,
        auto_pad=None,
        ceil_mode=None,
        count_include_pad=None,
        kernel_shape=None,
        pads=None,
        strides=None,
    ):
        # All attributes are required for the shape arithmetic below.
        assert auto_pad is not None, "auto_pad is None"
        assert ceil_mode is not None, "ceil_mode is None"
        assert count_include_pad is not None, "count_include_pad is None"
        assert kernel_shape is not None, "kernel_shape is None"
        assert pads is not None, "pads is None"
        assert strides is not None, "strides is None"

        assert auto_pad == "NOTSET", f"Not implemented for autopad={auto_pad!r}"
        assert ceil_mode == 0, f"Not implemented for ceil_mode={ceil_mode!r}"
        assert (
            count_include_pad == 1
        ), f"Not implemented for count_include_pad={count_include_pad!r}"

        # Reconstruct the input (gradient) shape from the output shape by
        # inverting the pooling shape formula; batch and channel dims copied.
        grad_shape = list(out.shape[:2])
        for i in range(len(kernel_shape)):
            d = (
                out.shape[i + 2] * strides[i]
                + kernel_shape[i]
                - 1
                + sum(pads[i * 2 : i * 2 + 2])
            )
            grad_shape.append(d)

        grad = np.zeros(tuple(grad_shape), dtype=out.dtype)
        # Each input element in a window receives an equal share of the
        # output gradient.
        scale = (1.0 / np.prod(kernel_shape)).astype(out.dtype)
        if len(grad_shape) == 4:
            # 2D
            for batch in range(grad.shape[0]):
                for channel in range(grad.shape[1]):
                    for i in range(out.shape[2]):
                        # Window rows, clipped to the (unpadded) input extent.
                        t = max(i * strides[0] - pads[0], 0)
                        b = min(i * strides[0] - pads[0] + kernel_shape[0], grad.shape[2])
                        for j in range(out.shape[3]):
                            # Window columns, clipped likewise.
                            le = max(j * strides[1] - pads[2], 0)
                            ri = min(
                                j * strides[1] - pads[2] + kernel_shape[1],
                                grad.shape[3],
                            )

                            grad[batch, channel, t:b, le:ri] += (
                                out[batch, channel, i, j] * scale
                            )
        else:
            raise NotImplementedError(
                f"AveragePoolGrad is not implemented for shape={out.shape}."
            )

        return (grad.astype(out.dtype),)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BiasSoftmax(OpRun):
    """Softmax of ``x + y`` along ``axis`` (``com.microsoft`` contrib op)."""

    op_domain = "com.microsoft"

    def _run(self, x, y, axis=None, is_inner_broadcast=None):  # type: ignore
        assert (
            is_inner_broadcast == 0
        ), f"Not implemented for is_inner_broadcast={is_inner_broadcast}"
        biased = x + y
        # Subtract the row max for numerical stability before exponentiating.
        shifted = biased - biased.max(axis=axis, keepdims=True)
        num = np.exp(shifted)
        den = num.sum(axis=axis, keepdims=True)
        return ((num / den).astype(x.dtype),)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from onnx.onnx_pb import TensorProto
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
try:
|
|
5
|
+
from onnx.reference.ops.op_cast import (
|
|
6
|
+
bfloat16,
|
|
7
|
+
cast_to,
|
|
8
|
+
float8e4m3fn,
|
|
9
|
+
float8e4m3fnuz,
|
|
10
|
+
float8e5m2,
|
|
11
|
+
float8e5m2fnuz,
|
|
12
|
+
)
|
|
13
|
+
except ImportError:
|
|
14
|
+
bfloat16 = None
|
|
15
|
+
from onnx.reference.ops.op_cast import cast_to
|
|
16
|
+
from ...helpers.onnx_helper import np_dtype_to_tensor_dtype
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _cast_like(x, y, saturate):
    """Cast *x* to the (onnx) element type of *y*.

    The exotic onnx dtypes (bfloat16, float8 variants) all alias small
    integer numpy dtypes, so the dtype *name* stored in ``descr`` must be
    checked as well to tell them apart before mapping to a TensorProto type.
    ``saturate`` is forwarded to :func:`cast_to` (meaningful for float8).
    """
    if bfloat16 is not None:
        if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
            # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
            to = TensorProto.BFLOAT16
        elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn":
            to = TensorProto.FLOAT8E4M3FN
        elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz":
            to = TensorProto.FLOAT8E4M3FNUZ
        elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2":
            to = TensorProto.FLOAT8E5M2
        elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz":
            to = TensorProto.FLOAT8E5M2FNUZ
        else:
            # Plain numpy dtype.
            to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
    else:
        # Old onnx without the custom dtypes: only plain numpy dtypes exist.
        to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
    return (cast_to(x, to, saturate),)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class CastLike_15(OpRun):
    """CastLike before opset 19: saturation is always on."""

    def _run(self, x, y):  # type: ignore
        return _cast_like(x, y, True)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class CastLike_19(OpRun):
    """CastLike from opset 19: saturation is controlled by an attribute."""

    def _run(self, x, y, saturate=None):  # type: ignore
        return _cast_like(x, y, saturate)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ToComplex(OpRun):
    """Pack the last axis (real[, imaginary]) into a complex tensor."""

    op_domain = "ai.onnx.complex"

    def _run(self, x):
        last = x.shape[-1]
        assert last in (
            1,
            2,
        ), f"Unexpected shape {x.shape}, it should a tensor (..., 2)"
        real = x[..., 0]
        if last == 1:
            # No imaginary part provided.
            return (real + 0j,)
        return (real + 1j * x[..., 1],)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ComplexModule(OpRun):
    """Modulus (absolute value) of a complex tensor."""

    op_domain = "ai.onnx.complex"

    def _run(self, x):
        supported = (np.complex64, np.complex128)
        assert (
            x.dtype in supported
        ), f"Unexpected type {x.dtype}, it should a complex tensor"
        return (np.abs(x),)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from onnx.reference.op_run import OpRun
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Concat(OpRun):
    """Concat that pads missing trailing singleton dimensions first."""

    def _preprocess(self, a: np.ndarray, axis: int) -> np.ndarray:
        # When axis is beyond the tensor rank, append size-1 dimensions
        # until the axis exists.
        rank = len(a.shape)
        if axis < rank:  # type: ignore
            return a
        padded_shape = a.shape + (1,) * (axis + 1 - rank)  # type: ignore
        return a.reshape(padded_shape)

    def _run(self, *args, axis=None):  # type: ignore
        prepared = tuple(self._preprocess(t, axis) for t in args)
        return (np.concatenate(prepared, axis),)  # type: ignore
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
try:
|
|
5
|
+
import ml_dtypes
|
|
6
|
+
except ImportError:
|
|
7
|
+
ml_dtypes = None # type: ignore
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ConstantOfShape(OpRun):
    """ConstantOfShape with bfloat16 support via ``ml_dtypes``."""

    @staticmethod
    def _process(value):
        """Normalize the ``value`` attribute into a numpy scalar fill value.

        Accepts ``None`` (defaults to ``float32(0)``), python scalars and
        numpy arrays (0-d or size >= 1); rejects empty arrays and
        non-numeric types.
        """
        if (
            value is not None
            and ml_dtypes is not None
            and value.dtype == (np.uint16, [("bfloat16", "<u2")])
        ):
            # bfloat16 is stored as a named uint16 dtype; reinterpret it.
            value = value.view(ml_dtypes.bfloat16)
        cst = value[0] if isinstance(value, np.ndarray) and value.size > 0 else value
        if isinstance(value, np.ndarray):
            if not value.shape:
                # 0-d array: keep it as-is, indexing with [0] would fail.
                cst = value
            elif value.size > 0:
                cst = value.ravel()[0]
            else:
                raise ValueError(f"Unexpected fill_value={value!r}")
        # Promote plain python scalars to numpy scalars.
        if isinstance(cst, bool):
            cst = np.bool_(cst)
        elif isinstance(cst, int):
            cst = np.int64(cst)
        elif isinstance(cst, float):
            cst = np.float64(cst)
        elif cst is None:
            # Default fill value mandated by the ConstantOfShape spec.
            cst = np.float32(0)
        if ml_dtypes is not None and isinstance(cst, ml_dtypes.bfloat16):
            return cst
        if not isinstance(
            cst,
            (
                np.float16,
                np.float32,
                np.float64,
                np.int64,
                np.int32,
                np.int16,
                np.int8,
                np.uint64,
                np.uint32,
                np.uint16,
                np.uint8,
                np.bool_,
            ),
        ):
            raise TypeError(f"value must be a real not {type(cst)}")
        return cst

    def _run(self, data, value=None):
        # data holds the target shape; value is the optional fill attribute.
        cst = self._process(value)
        try:
            res = np.full(tuple(data), cst)
        except TypeError as e:
            raise RuntimeError(
                f"Unable to create a constant of shape "
                f"{data!r} with value {cst!r} "
                f"(raw value={value!r})."
            ) from e
        return (res,)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class FusedMatMul(OpRun):
    """``com.microsoft.FusedMatMul``: ``alpha * op(A) @ op(B)``.

    ``transA``/``transB`` transpose the two innermost axes of the
    corresponding operand; batch transposes are not implemented.
    """

    op_domain = "com.microsoft"

    @staticmethod
    def _transpose_last_two(m):
        # Swap the two innermost axes, leaving any batch axes untouched.
        # (Indexing via len-2/len-1 keeps the original's degenerate 1-D
        # behavior: it is a no-op rather than an error.)
        perm = list(range(len(m.shape)))
        dim = len(perm)
        perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2]
        return np.transpose(m, perm)

    def _run(
        self,
        A,
        B,
        alpha: float = 1,
        transA: int = 0,
        transB: int = 0,
        transBatchA: int = 0,
        transBatchB: int = 0,
    ):
        assert transBatchA == 0, f"Not implemented for transBatchA==1 and {A.shape}x{B.shape}"
        assert transBatchB == 0, f"Not implemented for transBatchB==1 and {A.shape}x{B.shape}"
        if transA:
            A = self._transpose_last_two(A)
        if transB:
            B = self._transpose_last_two(B)
        # Cast alpha to the operand dtype so the product dtype is preserved.
        a = np.array(alpha, dtype=A.dtype)
        return (np.matmul(A, B) * a,)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from onnx.reference.op_run import OpRun
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Gather(OpRun):
    """Gather with explicit handling of empty inputs or indices.

    Returns an empty tensor with the shape :func:`numpy.take` would
    produce (forced to have a zero dimension) instead of failing.
    """

    def _run(self, x, indices, axis=None):
        if x.size == 0 or indices.size == 0:
            # Compute the result shape numpy.take would produce.
            if axis is None:
                new_shape = indices.shape
            else:
                new_shape = (*x.shape[:axis], *indices.shape, *x.shape[axis + 1 :])
            if 0 not in new_shape:
                # Guarantee an empty result even when the computed shape
                # has no zero dimension (e.g. scalar indices).
                new_shape = (0, *new_shape[1:])
            return (np.empty(new_shape, dtype=x.dtype),)
        if not x.flags["C_CONTIGUOUS"]:
            x = np.ascontiguousarray(x)
        if not indices.flags["C_CONTIGUOUS"]:
            # BUG FIX: ndarray has no ``ascontiguousarray`` method; the
            # original ``indices.ascontiguousarray()`` raised AttributeError
            # for any non-contiguous indices array.
            indices = np.ascontiguousarray(indices)
        try:
            return (np.take(x, indices, axis=axis),)
        except TypeError:
            # distribution x86 requires int32.
            return (np.take(x, indices.astype(int), axis=axis),)