onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from onnx.reference.op_run import OpRun
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def gather_numpy_2(self: np.ndarray, index: np.ndarray) -> np.ndarray:
    """Fallback gather for 2-D inputs.

    For every row of *self*, pick the element addressed by the first entry
    of the matching row of *index*, then shape the picks like ``index``.
    Used when :func:`numpy.choose` rejects the inputs (e.g. too many rows).
    """
    picked = [row[idx_row[0]] for row, idx_row in zip(self, index)]
    return np.array(picked, dtype=self.dtype).reshape(index.shape)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def gather_numpy(self: np.ndarray, dim: int, index: np.ndarray) -> np.ndarray:
    """NumPy equivalent of ``torch.gather``: ``out[..., i, ...] =
    self[..., index[..., i, ...], ...]`` along axis *dim*.

    :param self: source array
    :param dim: axis along which to gather
    :param index: integer array, same shape as *self* except possibly on *dim*
    :return: gathered array shaped like *index*
    :raises ValueError: when the non-*dim* dimensions of *index* and *self* differ
    """
    index_rest = index.shape[:dim] + index.shape[dim + 1 :]
    data_rest = self.shape[:dim] + self.shape[dim + 1 :]
    if index_rest != data_rest:
        raise ValueError(
            f"Except for dimension {dim!r}, all dimensions of "
            f"index and self should be the same size."
        )
    # Move the gather axis to the front so np.choose can select along it.
    data_t = np.swapaxes(self, 0, dim)
    index_t = np.swapaxes(index, 0, dim)

    try:
        picked = np.choose(index_t, data_t, mode="wrap")
    except ValueError:
        # np.choose has hard limits (at most 32 choices); use the 2-D fallback.
        if index_t.ndim == 2 and data_t.ndim == 2:
            return gather_numpy_2(self, index)
        raise  # pragma: no cover

    return np.swapaxes(picked, 0, dim)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class GatherElements(OpRun):
    """ONNX GatherElements: ``torch.gather``-style element selection along *axis*."""

    def _run(self, data, indices, axis=None):
        try:
            gathered = gather_numpy(data, axis, indices)
        except TypeError:
            # distribution x86 requires int32.
            gathered = gather_numpy(data, axis, indices.astype(int))
        return (gathered,)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
from onnx.reference.ops.op_scatternd import _scatter_nd_impl
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class GatherGrad(OpRun):
    """com.microsoft GatherGrad: scatter *updates* into a zero tensor of *shape*.

    The gradient of Gather is a ScatterND into zeros, so this delegates to
    the reference ScatterND implementation.
    """

    op_domain = "com.microsoft"

    def _run(self, shape, indices, updates, reduction=None):
        zeros = np.zeros(shape, dtype=updates.dtype)
        return (_scatter_nd_impl(zeros, indices, updates, reduction=reduction),)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def sigmoid(x):  # type: ignore
    """Numerically stable scalar sigmoid.

    Branches on the sign of *x* so that ``exp`` is only ever evaluated on a
    non-positive argument, avoiding overflow for large ``|x|``.
    """
    if x > 0:
        return 1 / (1 + np.exp(-x))
    ex = np.exp(x)
    return ex / (1 + ex)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MulSigmoid(OpRun):
    """onnx_extended MulSigmoid: computes ``X * sigmoid(X)`` element-wise (SiLU)."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def __init__(self, onnx_node, run_params):  # type: ignore
        OpRun.__init__(self, onnx_node, run_params)
        # vectorized form of the stable scalar sigmoid
        self.vf = np.vectorize(sigmoid)

    def _run(self, X):
        if len(X.shape) == 0:
            # 0-d array: call the scalar function directly
            return ((X * sigmoid(X)).astype(X.dtype),)
        if X.size == 0:
            # empty tensor passes through unchanged
            return (X,)
        gate = self.vf(X)
        return ((X * gate).astype(X.dtype),)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from onnx.reference.op_run import OpRun
|
|
2
|
+
from onnx.reference.ops.op_average_pool import AveragePool_19 as AveragePool
|
|
3
|
+
from onnx.reference.ops.op_dequantize_linear import DequantizeLinear_19 as DequantizeLinear
|
|
4
|
+
from onnx.reference.ops.op_quantize_linear import QuantizeLinear_19 as QuantizeLinear
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class QLinearAveragePool(OpRun):
    """com.microsoft QLinearAveragePool: dequantize, AveragePool, requantize.

    Only the channel-first layout is supported (``channels_last`` must be
    absent or 0).
    """

    op_domain = "com.microsoft"

    def _run(
        self,
        x,
        x_scale,
        x_zero_point,
        y_scale,
        y_zero_point,
        auto_pad=None,
        ceil_mode=None,
        channels_last=None,
        count_include_pad=None,
        kernel_shape=None,
        pads=None,
        strides=None,
    ):
        assert channels_last in (
            None,
            0,
        ), f"QLinearAveragePool not implemented if channels_last={channels_last}"
        # Reference semantics: run the float AveragePool on dequantized data.
        real_x = DequantizeLinear.eval(x, x_scale, x_zero_point)
        pooled = AveragePool.eval(
            real_x,
            auto_pad=auto_pad,
            ceil_mode=ceil_mode,
            count_include_pad=count_include_pad,
            kernel_shape=kernel_shape,
            pads=pads,
            strides=strides,
        )
        return (QuantizeLinear.eval(pooled, y_scale, y_zero_point),)
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
from onnx.defs import OpSchema
|
|
3
|
+
from onnx.helper import make_attribute
|
|
4
|
+
from onnx.reference.op_run import OpRun
|
|
5
|
+
from onnx.reference.ops.op_conv import Conv
|
|
6
|
+
from onnx.reference.ops.op_dequantize_linear import DequantizeLinear_19 as DequantizeLinear
|
|
7
|
+
from onnx.reference.ops.op_quantize_linear import QuantizeLinear_19 as QuantizeLinear
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _switch_dims_nchw_nhwc(dims: Tuple[int, ...], from_nchw_to_nhwc: bool):
|
|
11
|
+
if len(dims) == 4:
|
|
12
|
+
if from_nchw_to_nhwc:
|
|
13
|
+
return (dims[0], *dims[2:], dims[1])
|
|
14
|
+
return (dims[0], dims[-1], *dims[1:-1])
|
|
15
|
+
if len(dims) == 3:
|
|
16
|
+
if from_nchw_to_nhwc:
|
|
17
|
+
return (*dims[1:], dims[0])
|
|
18
|
+
return (dims[-1], *dims[:-1])
|
|
19
|
+
raise NotImplementedError(f"Unable to process shape={dims}")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class QLinearConv(OpRun):
    """com.microsoft QLinearConv reference implementation.

    Dequantizes x, w (and optionally the bias B), runs the float Conv
    reference, then requantizes the result with (y_scale, y_zero_point).
    """

    op_domain = "com.microsoft"

    # Hand-written schema: the com.microsoft domain ships no python schema,
    # so the reference evaluator needs one to bind inputs/attributes.
    op_schema = OpSchema(
        "QLinearConv",
        "com.microsoft",
        1,
        inputs=[
            OpSchema.FormalParameter("x", "T"),
            OpSchema.FormalParameter("x_scale", "T"),
            OpSchema.FormalParameter("x_zero_point", "T1"),
            OpSchema.FormalParameter("w", "T"),
            OpSchema.FormalParameter("w_scale", "T"),
            OpSchema.FormalParameter("w_zero_point", "T2"),
            OpSchema.FormalParameter("y_scale", "T"),
            OpSchema.FormalParameter("y_zero_point", "T3"),
            OpSchema.FormalParameter(
                "B", "T3", param_option=OpSchema.FormalParameterOption.Optional
            ),
        ],
        outputs=[OpSchema.FormalParameter("y", "T3")],
        type_constraints=[
            ("T", ["tensor(float)"], ""),
            ("T1", ["tensor(int8)", "tensor(uint8)"], ""),
            ("T2", ["tensor(int8)", "tensor(uint8)"], ""),
            ("T3", ["tensor(int8)", "tensor(uint8)"], ""),
        ],
        attributes=[
            OpSchema.Attribute("auto_pad", make_attribute("auto_pad", "NOTSET"), ""),
            OpSchema.Attribute("kernel_shape", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("dilations", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("strides", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("pads", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("group", make_attribute("group", 1), ""),
            OpSchema.Attribute("channels_last", make_attribute("channels_last", 0), ""),
        ],
    )

    def _run(
        self,
        x,
        x_scale,
        x_zero_point,
        w,
        w_scale,
        w_zero_point,
        y_scale,
        y_zero_point,
        B=None,
        auto_pad=None,
        channels_last=None,
        dilations=None,
        group=None,
        kernel_shape=None,
        pads=None,
        strides=None,
    ):
        dqx = DequantizeLinear.eval(x, x_scale, x_zero_point)
        dqw = DequantizeLinear.eval(w, w_scale, w_zero_point)
        if channels_last:
            # NOTE(review): reshape changes only the shape, not the memory
            # order — a true NHWC->NCHW conversion would need a transpose.
            # Confirm against onnxruntime before relying on channels_last=1.
            dqx = dqx.reshape(_switch_dims_nchw_nhwc(x.shape, False))
        # Bias is quantized with scale x_scale * w_scale and zero point 0.
        dqb = (
            DequantizeLinear.eval(B, x_scale * w_scale, 0).astype(dqx.dtype)
            if B is not None
            else None
        )
        y = Conv.eval(
            dqx,
            dqw,
            dqb,
            auto_pad=auto_pad,
            dilations=dilations,
            group=group,
            kernel_shape=kernel_shape,
            pads=pads,
            strides=strides,
        )
        if channels_last:
            # NOTE(review): same reshape-vs-transpose concern as above.
            y = y.reshape(_switch_dims_nchw_nhwc(y.shape, True))
        qy = QuantizeLinear.eval(y, y_scale, y_zero_point)
        return (qy,)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def sigmoid(x):  # type: ignore
    """Scalar logistic function, written to stay numerically stable.

    ``exp`` is applied only to non-positive values, so neither branch can
    overflow for large ``|x|``.
    """
    if x > 0:
        return 1 / (1 + np.exp(-x))
    ex = np.exp(x)
    return ex / (1 + ex)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class QuickGelu(OpRun):
    """com.microsoft QuickGelu: ``X * sigmoid(alpha * X)`` element-wise."""

    op_domain = "com.microsoft"

    def __init__(self, onnx_node, run_params):  # type: ignore
        OpRun.__init__(self, onnx_node, run_params)
        # vectorized form of the stable scalar sigmoid
        self.vf = np.vectorize(sigmoid)

    def _run(self, X, alpha=1.0):
        if len(X.shape) == 0:
            # 0-d array: use the scalar sigmoid directly
            return ((X * sigmoid(X * alpha)).astype(X.dtype),)
        if X.size == 0:
            # empty tensor passes through unchanged
            return (X,)
        gate = self.vf(X * alpha)
        return ((X * gate).astype(X.dtype),)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from onnx.reference.op_run import OpRun
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class ReplaceZero(OpRun):
    """onnx_extended ReplaceZero.

    Replaces either the zero entries (``equal`` truthy) or the non-zero
    entries (``equal`` falsy) of ``X`` with the scalar ``by``.  ``X`` is
    never modified in place.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, X, by=None, equal=None):
        # flatten() always returns a copy, so the extra X.copy() the
        # original code performed was redundant.
        flat = X.flatten()
        if equal:
            flat[flat == 0] = by
        else:
            flat[flat != 0] = by
        return (flat.reshape(X.shape),)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from onnx.reference.op_run import OpRun
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Rotary(OpRun):
    """onnx_extended Rotary: rotate the two halves of the last dimension.

    With ``side="left"`` the second half moves to the front and the first
    half is negated into the back; otherwise the roles are mirrored.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, X, splits=None, side=None):
        # Only an even split into two equal halves is supported.
        assert splits is None or (
            splits.shape == (2,) and splits[0] == splits[1]
        ), f"Unexpected split value {splits}"
        half = X.shape[-1] // 2
        first, second = X[..., :half], X[..., half:]
        out = X.copy()
        if side == "left":
            out[..., :half] = second
            out[..., half:] = -first
        else:
            out[..., :half] = -second
            out[..., half:] = first
        return (out,)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.ops.op_scan import Scan as _Scan
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Scan(_Scan):
    """Scan variant that also exposes the evaluator's context to the body.

    Unlike the base onnx reference Scan, each iteration seeds the body's
    inputs with ``context`` (results already produced by the outer graph),
    so the body may silently reference outer values.
    """

    def need_context(self) -> bool:
        """Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return True

    def _run(
        self,
        *args,
        context=None,
        body=None,
        num_scan_inputs=None,
        scan_input_axes=None,
        scan_input_directions=None,
        scan_output_axes=None,
        scan_output_directions=None,
        attributes=None,
    ):
        # _common_run_shape (inherited) splits *args* into loop-state
        # variables and scan inputs and resolves the in/out names.
        (
            num_loop_state_vars,
            _num_scan_outputs,
            _output_directions,
            _max_dir_out,
            _output_axes,
            _max_axe_out,
            state_names_in,
            state_names_out,
            scan_names_in,
            scan_names_out,
            scan_values,
            states,
        ) = self._common_run_shape(*args)

        # Number of iterations = extent of the first scan input along its axis.
        max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]
        # One accumulator list per scan output.
        results = [[] for _ in scan_names_out]  # type: ignore

        for it in range(max_iter):
            # Body inputs: outer context + current states + slice `it`
            # of every scan input.
            inputs = context.copy()
            inputs.update(dict(zip(state_names_in, states)))
            inputs.update({name: value[it] for name, value in zip(scan_names_in, scan_values)})

            try:
                outputs_list = self._run_body(inputs)  # type: ignore
            except TypeError as e:
                raise TypeError(
                    f"Unable to call 'run' for type '{type(self.body)}'."  # type: ignore
                ) from e

            outputs = dict(zip(self.output_names, outputs_list))
            # Carry the new states into the next iteration.
            states = [outputs[name] for name in state_names_out]
            # Stash each scan output with a fresh leading axis for stacking.
            for i, name in enumerate(scan_names_out):
                results[i].append(np.expand_dims(outputs[name], axis=0))

        # Final outputs = last states followed by the stacked scan outputs.
        for res in results:
            conc = np.vstack(res)
            states.append(conc)
        return self._check_and_fix_outputs(tuple(states))
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from onnx.reference.op_run import OpRun
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def scatter_elements(data, indices, updates, axis=0, reduction=None):  # type: ignore
    """Scatter *updates* into a copy of *data* along *axis* (ONNX ScatterElements).

    :param data: source array (never modified)
    :param indices: integer array, same rank as *data*; entry at position
        ``pos`` selects the target coordinate along *axis*
    :param updates: values to write, same shape as *indices*
    :param axis: scatter axis, may be negative
    :param reduction: one of None/"none", "add", "mul", "min", "max";
        unknown values fall back to plain assignment, as before
    :return: new array shaped like *data*

    Generalized from the original rank/axis-specific nested loops (which
    only supported ranks 1-4 and, for rank 4, axes 0 and 3) to any rank and
    axis.  Duplicate target positions are still reduced sequentially in C
    order, matching the original nested-loop semantics.
    """
    if reduction == "add":

        def f(x, y):
            return x + y

    elif reduction == "min":

        def f(x, y):
            return np.minimum(x, y)

    elif reduction == "max":

        def f(x, y):
            return np.maximum(x, y)

    elif reduction == "mul":

        def f(x, y):
            return x * y

    else:
        # default: last write wins
        def f(x, y):
            return y

    if axis < 0:
        axis = data.ndim + axis

    scattered = np.copy(data)
    # np.ndindex iterates in C order, exactly like the nested for-loops it
    # replaces, so sequential reduction of duplicate targets is preserved.
    for pos in np.ndindex(*indices.shape):
        target = pos[:axis] + (indices[pos],) + pos[axis + 1 :]
        scattered[target] = f(scattered[target], updates[pos])
    return scattered
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ScatterElements(OpRun):
    """ONNX ScatterElements reference implementation (delegates to scatter_elements)."""

    def _run(self, data, indices, updates, axis=None, reduction=None):  # type: ignore
        return (scatter_elements(data, indices, updates, axis=axis, reduction=reduction),)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
from onnx.reference.ops.op_scatternd import _scatter_nd_impl
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ScatterNDOfShape(OpRun):
    """onnx_extended ScatterNDOfShape: ScatterND into a fresh zero tensor of *shape*."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, shape, indices, updates, reduction=None, strategy=None):
        # strategy is accepted but unused — presumably a CUDA kernel hint
        # with no effect on the result; confirm against onnx_extended.
        zeros = np.zeros(shape, dtype=updates.dtype)
        return (_scatter_nd_impl(zeros, indices, updates, reduction=reduction),)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MaskedScatterNDOfShape(OpRun):
    """onnx_extended MaskedScatterNDOfShape.

    Same as ScatterNDOfShape, except updates whose index equals
    ``maskedValue`` are zeroed out before scattering.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, shape, indices, updates, reduction=None, maskedValue=None):
        zeros = np.zeros(shape, dtype=updates.dtype)
        masked_updates = np.where(indices == maskedValue, 0, updates)
        return (_scatter_nd_impl(zeros, indices, masked_updates, reduction=reduction),)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from onnx.reference.op_run import OpRun
|
|
2
|
+
from onnx.reference.ops.op_layer_normalization import _layer_normalization
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class SkipLayerNormalization(OpRun):
    """com.microsoft SkipLayerNormalization: LayerNorm(x + skip [+ bias]) on the
    last axis; also returns the pre-normalization sum as the final output."""

    op_domain = "com.microsoft"

    def _run(self, x, skip, gamma=None, beta=None, bias=None, epsilon=None):
        total = x + skip if bias is None else x + skip + bias
        norm = _layer_normalization(total, gamma, beta, axis=-1, epsilon=epsilon)
        return (*norm, total)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from onnx.reference.ops.op_slice import SliceCommon
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Slice_10(SliceCommon):
    """Slice (opset 10+): thin wrapper — SliceCommon already implements it."""

    def __init__(self, onnx_node, run_params):
        super().__init__(onnx_node, run_params)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Slice_1(SliceCommon):
    """Slice (opset 1): starts/ends/axes come from attributes, not inputs.

    Empty attribute lists are normalized to ``None`` so SliceCommon treats
    them as absent.
    """

    def __init__(self, onnx_node, run_params):
        # Removed a leftover debug `print(onnx_node)` that wrote the whole
        # node proto to stdout on every instantiation.
        SliceCommon.__init__(self, onnx_node, run_params)
        for name in ["starts", "ends", "steps", "axes"]:
            if not hasattr(self, name):
                continue
            value = getattr(self, name)
            if value is not None and len(value) == 0:
                setattr(self, name, None)

    def _run(self, data, axes=None, ends=None, starts=None):
        # SliceCommon expects positional order (data, starts, ends, axes).
        return SliceCommon._run(self, data, starts, ends, axes)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Transpose2DCastFP16(OpRun):
    """onnx_extended Transpose2DCastFP16: transpose a 2-D tensor, cast to float16."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, X):
        transposed = X.T
        return (transposed.astype(np.float16),)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Transpose2DCastFP32(OpRun):
    """onnx_extended Transpose2DCastFP32: transpose a 2-D tensor, cast to float32."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, X):
        transposed = X.T
        return (transposed.astype(np.float32),)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from onnx.reference.op_run import OpRun
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TriMatrix(OpRun):
    """onnx_extended TriMatrix.

    Builds a ``(shape[0], shape[1])`` matrix from the three constants
    ``csts = (lower, diag, upper)``: one value below the diagonal, one on
    it, one above it.  The dtype follows ``csts``.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, shape, csts):
        below, on_diag, above = list(csts)
        rows = np.arange(shape[0], dtype=np.int32).reshape((-1, 1))
        cols = np.arange(shape[1], dtype=np.int32).reshape((1, -1))
        out = np.empty(tuple(shape), dtype=csts.dtype)
        out[rows > cols] = below
        out[rows < cols] = above
        out[rows == cols] = on_diag
        return (out,)
|