onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,121 @@
from typing import Any, Dict, Optional
import onnx
import torch
from . import OpRunKernel, OpRunTensor


class OpRunControlFlow(OpRunKernel):
    """Common ancestor for control flows."""

    @classmethod
    def has_subgraphs(cls) -> bool:
        """Returns True if the kernel has subgraphs."""
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        assert (
            parent is not None
        ), f"parent must be specified for operator {self.__class__.__name__!r}"
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                rt = parent.__class__(
                    att.g,
                    providers=parent.providers,
                    opsets=parent.opsets,
                    local_functions=parent.functions,
                    verbose=parent.verbose,
                    custom_kernels=parent.custom_kernels,
                )
                setattr(self, att.name, rt)


class If_1(OpRunControlFlow):
    "If"

    def run(self, cond, context: Optional[Dict[str, Any]] = None):
        rt = self.then_branch if cond.tensor.item() else self.else_branch  # type: ignore[attr-defined]
        return rt.run_with_values(context=context)


class Loop_16(OpRunControlFlow):
    "Loop"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
        verbose: int = 0,
    ):
        super().__init__(node, version, parent, verbose=verbose)
        self.output_index = {n: i for i, n in enumerate(self.body.output_names)}
        self.N = len(self.body.input_names) - 2
        self.K = len(self.body.output_names) - self.N - 1

    def run(self, M, cond, *args, context: Optional[Dict[str, Any]] = None):
        if args:
            v_initial = args[0]
            args = args[1:]
        else:
            v_initial = None
        assert M is None or hasattr(
            M, "dtype"
        ), f"M must be empty or an array but its type is {type(M)}."
        body = self.body
        loop_inputs = body.input_names
        inputs = dict.fromkeys(loop_inputs)
        if v_initial is not None:
            inputs[loop_inputs[2]] = v_initial
        cond_name = body.output_names[0]
        if args:
            begin = len(loop_inputs) - len(args)
            all_inputs = loop_inputs[begin:]
            for name, val in zip(all_inputs, args):
                inputs[name] = val
        if context is not None:
            for a in context:
                inputs[a] = context[a]

        k_carried_away = [[] for i in range(self.K)]  # type: ignore
        it = 0
        while (cond is None or cond.tensor is None or cond.tensor.item()) and (
            M is None or M.tensor is None or it < M.tensor.item()
        ):
            if len(body.input_names) > 0 and body.input_names[0] is not None:
                inputs[body.input_names[0]] = OpRunTensor(
                    torch.tensor(it, dtype=None if M is None else M.dtype)
                )
            if len(body.input_names) > 1 and body.input_names[1] is not None:
                inputs[body.input_names[1]] = cond
            outputs = list(
                self.body.run_with_values(
                    *[inputs[k] for k in self.body.input_names], context=context
                )
            )
            if self.K > 0:
                for k in range(self.K):
                    k_carried_away[k].append(outputs[-self.K + k])
            index_cond = self.output_index[cond_name]
            cond = outputs[index_cond]
            assert (
                cond is not None
            ), f"Condition {cond_name!r} returned by the subgraph cannot be None."
            for i, o in zip(body.input_names[2:], body.output_names[1:]):
                inputs[i] = outputs[self.output_index[o]]
            it += 1

        if it == 0:
            outputs = [inputs[i] for i in body.input_names[2:]]
        else:
            outputs = outputs[1 : 1 + self.N]
        outputs.extend([OpRunTensor(torch.cat(x, axis=0)) for x in k_carried_away])
        while len(outputs) < len(self.body.output_names):
            outputs.append(OpRunTensor(torch.empty(())))
        return tuple(outputs)
@@ -0,0 +1,36 @@
from typing import Optional
import onnx
import torch
from . import OpRunKernel, OpRunTensor


class Range_11(OpRunKernel):
    """Range"""

    @classmethod
    def device_dependent(cls) -> bool:
        """
        Returns True if the kernel needs a device to be efficiently initialized.
        """
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        device: Optional[torch.device] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.device = device

    def run(self, starts: OpRunTensor, limit: OpRunTensor, delta: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(
            torch.arange(
                starts.tensor,
                limit.tensor,
                delta.tensor,
                dtype=starts.dtype,
                device=self.device,
            )
        )
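
Range_11 simply forwards the three scalar inputs to torch.arange, keeping the dtype of the start input and placing the result on the kernel's device. A minimal sketch of the equivalent call with plain Python scalars (illustrative values, not package code):

import torch

start, limit, delta = 2, 11, 3
out = torch.arange(start, limit, delta, dtype=torch.int64)
print(out)  # tensor([2, 5, 8])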
@@ -0,0 +1,196 @@
from typing import Optional, Tuple
import onnx
import torch
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
from . import OpRunKernel, OpRunTensor


class AveragePool_11(OpRunKernel):
    "AveragePool"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
        self.ceil_mode = bool(self.get_attribute_int(node, "ceil_mode", 0))
        self.count_include_pad = bool(self.get_attribute_int(node, "count_include_pad", 0))
        self.dilations = self.get_attribute_ints(node, "dilations", None)
        self.kernel_shape: Tuple[int, ...] = (
            self.get_attribute_ints(node, "kernel_shape") or tuple()
        )
        self.pads = self.get_attribute_ints(node, "pads", None)
        self.strides = self.get_attribute_ints(node, "strides", None)

    def run(self, x):
        kernel_shape = self.kernel_shape
        dilations = self.dilations or [1 for _ in x.shape[2:]]
        strides = self.strides or [1 for _ in x.shape[2:]]
        pads = self.pads or ([0 for _ in x.shape[2:]] * 2)
        assert (
            self.auto_pad == "NOTSET"
        ), f"conv not implemented for auto_pad={self.auto_pad!r}"
        assert len(set(pads)) == 1, f"conv not implemented for pads={pads}"
        assert set(dilations) == {1}, f"conv not implemented for dilations={dilations}"
        avg_pool = getattr(torch.nn.functional, f"avg_pool{len(kernel_shape)}d")
        return OpRunTensor(
            avg_pool(
                x.tensor,
                kernel_size=tuple(kernel_shape),
                stride=tuple(strides),
                padding=pads[0],
                ceil_mode=self.ceil_mode,
                count_include_pad=self.count_include_pad,
                # dilation=tuple(dilations),
            )
        )


class Conv_11(OpRunKernel):
    "Conv"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
        self.dilations = self.get_attribute_ints(node, "dilations", None)
        self.group = self.get_attribute_int(node, "group", 1)
        self.kernel_shape: Tuple[int, ...] = (
            self.get_attribute_ints(node, "kernel_shape") or tuple()
        )
        self.pads = self.get_attribute_ints(node, "pads", None)
        self.strides = self.get_attribute_ints(node, "strides", None)

    def run(self, x, w, b=None):
        kernel_shape = self.kernel_shape or w.shape[2:]
        assert (
            tuple(kernel_shape) == w.shape[-len(kernel_shape) :]
        ), f"conv not implemented for kernel_shape={kernel_shape} and w.shape={w.shape}"
        dilations = self.dilations or [1 for _ in x.shape[2:]]
        strides = self.strides or [1 for _ in x.shape[2:]]

        if self.auto_pad in {"SAME_LOWER", "SAME_UPPER"}:
            head = []
            tail = []
            for i in range(len(x.shape) - 2):
                d = x.shape[i + 2]
                target_size = (d + strides[i] - 1) // strides[i]
                pad_needed = (target_size - 1) * strides[i] + kernel_shape[i] - d
                pad_head = (
                    (pad_needed + 1) // 2 if self.auto_pad == "SAME_LOWER" else pad_needed // 2
                )
                pad_tail = pad_needed - pad_head
                head.append(pad_head)
                tail.append(pad_tail)
            pads = head + tail
        else:
            pads = self.pads or ([0 for _ in x.shape[2:]] * 2)

        assert len(set(pads)) == 1, (
            f"conv not implemented for pads={pads}, "
            f"auto_pad={self.auto_pad!r}, strides={strides}, "
            f"x.shape={x.shape}, kernel_shape={kernel_shape}"
        )

        if b is None:
            bias = None
        else:
            bias = b.tensor.squeeze()
            if not bias.shape:
                bias = bias.unsqueeze(0)
        return OpRunTensor(
            torch.nn.functional.conv2d(
                x.tensor,
                w.tensor,
                bias=bias,
                stride=tuple(strides),
                padding=pads[0],
                dilation=tuple(dilations),
                groups=self.group,
            )
        )


class LayerNormalization_17(OpRunKernel):
    "LayerNormalization"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.axis = self.get_attribute_int(node, "axis", -1)
        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
        self.stash_type = onnx_dtype_to_torch_dtype(
            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
        )
        self.compute_std = len(node.output) > 1

    def run(self, x, scale, bias=None):
        original_dtype = x.dtype
        if self.stash_type == torch.float32 and x.tensor.dtype != torch.float64:
            xt = x.tensor
            res = torch.nn.functional.layer_norm(
                xt,
                xt.shape[self.axis :],
                weight=scale.tensor,
                bias=None if bias is None else bias.tensor,
                eps=self.epsilon,
            )
        else:
            xt = x.tensor.to(self.stash_type)
            res = torch.nn.functional.layer_norm(
                xt,
                xt.shape[self.axis :],
                weight=scale.tensor.to(self.stash_type),
                bias=None if bias is None else bias.tensor.to(self.stash_type),
                eps=self.epsilon,
            )
        if not self.compute_std:
            return OpRunTensor(res.to(original_dtype))
        axes = tuple(range(len(xt.shape)))[self.axis :]
        mean, var = torch.var(xt, dim=axes, keepdim=False)
        x_inv_std_dev = torch.reciprocal(torch.sqrt(var + self.epsilon))
        return (
            OpRunTensor(res.to(original_dtype)),
            OpRunTensor(mean),
            OpRunTensor(x_inv_std_dev),
        )


class Softmax_13(OpRunKernel):
    "Softmax"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.axis = self.get_attribute_int(node, "axis", -1)
        assert isinstance(self.axis, int), f"Unexpected value for attribute axis={self.axis!r}"
        # this is out of spec
        stash_type = self.get_attribute_int(node, "stash_type", None)
        self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)

    def run(self, data: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(
            torch.nn.functional.softmax(data.tensor, dim=self.axis, dtype=self.stash_type)
        )


class Tanh_6(OpRunKernel):
    "Tanh"

    def run(self, data: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(torch.nn.functional.tanh(data.tensor))
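
The SAME_LOWER/SAME_UPPER branch in Conv_11 above computes, per spatial dimension, the padding needed so that the output size is ceil(d / stride); SAME_UPPER puts the odd extra pixel on the tail, SAME_LOWER on the head. A worked sketch of that arithmetic for one dimension (illustrative numbers, not package code):

d, stride, kernel = 13, 2, 3
target_size = (d + stride - 1) // stride              # ceil(13 / 2) = 7
pad_needed = (target_size - 1) * stride + kernel - d  # 6 * 2 + 3 - 13 = 2
pad_head = pad_needed // 2                            # SAME_UPPER: 1
pad_tail = pad_needed - pad_head                      # 1
print(target_size, pad_needed, pad_head, pad_tail)    # 7 2 1 1

The kernel then asserts len(set(pads)) == 1 because it passes a single scalar padding value to torch.nn.functional.conv2d, so only symmetric padding is supported on this path.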
@@ -0,0 +1,106 @@
from typing import Optional
import onnx
import torch
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
from . import OpRunKernel, OpRunTensor


class Cast_6(OpRunKernel):
    "Cast"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        to = self.get_attribute_int(node, "to", 0)
        assert isinstance(to, int), f"Unexpected value for attribute to={to!r}"
        self.to = onnx_dtype_to_torch_dtype(to)
        self.saturate = self.get_attribute_int(node, "saturate", 1)
        assert self.saturate == 1, f"saturate={self.saturate} not implemented for Cast"

    def run(self, data: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(data.tensor.to(self.to))


class CastLike_15(OpRunKernel):
    "Cast"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.saturate = self.get_attribute_int(node, "saturate", 1)
        assert self.saturate == 1, f"saturate={self.saturate} not implemented for CastLike"

    def run(self, data: OpRunTensor, like: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(data.tensor.to(like.tensor.dtype))


class Concat_1(OpRunKernel):
    "Concat"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        axis = self.get_attribute_int(node, "axis", 0)
        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
        self.axis = axis

    def run(self, *data: OpRunTensor) -> OpRunTensor:
        assert data, f"No tensor to concatenate in node name {self.name!r}"
        devices = [d.get_device() for d in data]
        if len(set(devices)) == 1:
            return OpRunTensor(torch.cat([t.tensor for t in data], axis=self.axis))
        if (
            data[0].dtype == torch.int64
            and self.axis == 0
            and max(d.tensor.ndim for d in data) == 1
            and max(d.tensor.numel() for d in data) <= 8
        ):
            # This is a shape
            return OpRunTensor(torch.cat([t.tensor.cpu() for t in data], axis=self.axis))
        index = devices.index(max(devices))
        device = data[index].tensor.device
        return OpRunTensor(torch.cat([t.tensor.to(device) for t in data], axis=self.axis))


class NonZero_13(OpRunKernel):
    "NonZero"

    def run(self, x: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(torch.nonzero(x.tensor).T)


class Tile_6(OpRunKernel):
    "Tile"

    def run(self, x: OpRunTensor, repeat: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(torch.tile(x.tensor, repeat.as_tuple_int))


class Transpose_1(OpRunKernel):
    "Transpose"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.perm = self.get_attribute_ints(node, "perm", None)

    def run(self, data: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(torch.permute(data.tensor, self.perm))


class Trilu_14(OpRunKernel):
    "Trilu"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.upper = self.get_attribute_int(node, "upper", 1)

    def run(self, data: OpRunTensor, k: Optional[OpRunTensor] = None) -> OpRunTensor:
        diagonal = 0 if k is None else k.tensor.item()
        if self.upper:
            return OpRunTensor(torch.triu(data.tensor, diagonal=diagonal))
        return OpRunTensor(torch.tril(data.tensor, diagonal=diagonal))


class Where_9(OpRunKernel):
    "Where"

    def run(self, cond: OpRunTensor, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        tcond, tx, ty = self.same_device(cond.tensor, x.tensor, y.tensor)
        return OpRunTensor(torch.where(tcond, tx, ty))
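
Concat_1 above resolves mixed-device inputs with a small heuristic: same device means a direct concat, tiny 1-D int64 inputs are treated as shape pieces and gathered on CPU, and anything else is moved to the input with the highest device index (torch's get_device() returns -1 for CPU and the ordinal for CUDA). A minimal sketch of just that decision rule (illustrative helper, not package code):

def pick_concat_device(devices, dtype0, ndims, numels, axis):
    if len(set(devices)) == 1:
        return devices[0]
    if dtype0 == "int64" and axis == 0 and max(ndims) == 1 and max(numels) <= 8:
        return "cpu"  # looks like a shape computation
    return max(devices)  # highest device index wins

print(pick_concat_device([-1, 0], "int64", [1, 1], [2, 2], 0))        # cpu
print(pick_concat_device([-1, 0], "float32", [4, 4], [64, 64], 1))    # 0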
@@ -0,0 +1,130 @@
from typing import Optional, Tuple
import onnx
import torch
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
from . import OpRunKernel, OpRunTensor


class ReduceOp(OpRunKernel):
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.keepdims = bool(self.get_attribute_int(node, "keepdims", 1))
        self.noop_with_empty_axes = bool(
            self.get_attribute_int(node, "noop_with_empty_axes", 0)
        )
        assert isinstance(
            self.keepdims, bool
        ), f"Unexpected value for attribute keepdims={self.keepdims!r}"
        assert isinstance(self.noop_with_empty_axes, bool), (
            f"Unexpected value for attribute "
            f"noop_with_empty_axes={self.noop_with_empty_axes!r}"
        )
        assert (
            not self.noop_with_empty_axes
        ), f"Not implemented with noop_with_empty_axes={self.noop_with_empty_axes}"
        # this is out of spec
        stash_type = self.get_attribute_int(node, "stash_type", None)
        self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)


class ReduceOpAxes(ReduceOp):
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.axes: Tuple[int, ...] = self.get_attribute_ints(node, "axes") or tuple()


class ReduceMax_18(ReduceOp):
    """ReduceMax"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(x.tensor.max())
        taxes = axes.as_tuple_int
        if len(taxes) == 1:
            t = x.tensor.max(taxes[0], keepdim=self.keepdims)
            return OpRunTensor(t.values)
        t = x.tensor
        for a in reversed(taxes):
            t = t.max(a, keepdim=self.keepdims).values
        return OpRunTensor(t)


class ReduceMean_18(ReduceOp):
    """ReduceMean"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(torch.mean(x.tensor))
        taxes = axes.as_tuple_int
        if len(taxes) == 1:
            t = x.tensor.mean(taxes[0], keepdim=self.keepdims)
            return OpRunTensor(t)
        t = x.tensor.mean(taxes, keepdim=self.keepdims)
        return OpRunTensor(t)


class ReduceMin_17(ReduceOpAxes):
    """ReduceMin"""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        axes = self.axes
        if not axes:
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(x.tensor.min())
        taxes = tuple(axes)
        if len(taxes) == 1:
            t = x.tensor.min(taxes[0], keepdim=self.keepdims)
            return OpRunTensor(t.values)
        t = x.tensor
        for a in reversed(taxes):
            t = t.min(a, keepdim=self.keepdims).values
        return OpRunTensor(t)


class ReduceMin_18(ReduceOp):
    """ReduceMin"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            assert (
                not self.keepdims
            ), f"axes is empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(torch.min(x.tensor))
        taxes = axes.as_tuple_int
        if len(taxes) == 1:
            t = x.tensor.min(taxes[0], keepdim=self.keepdims)
            return OpRunTensor(t.values)
        t = x.tensor
        for a in reversed(taxes):
            t = t.min(a, keepdim=self.keepdims).values
        return OpRunTensor(t)


class ReduceSum_13(ReduceOp):
    """ReduceSum"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(torch.sum(x.tensor))
        taxes = axes.as_tuple_int
        if len(taxes) == 1:
            t = x.tensor.sum(taxes[0], keepdim=self.keepdims)
            return OpRunTensor(t)
        t = x.tensor.sum(taxes, keepdim=self.keepdims)
        return OpRunTensor(t)
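
ReduceMax_18 and the min kernels above reduce over several axes by repeating the single-axis torch.Tensor.max / .min call, walking the axes from the last one backwards so that earlier axis indices stay valid when keepdim is False. A short sketch (not package code) checking this against torch.amax:

import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
axes = (1, 2)

t = x
for a in reversed(axes):            # last axis first, so indices 1 and 2 stay valid
    t = t.max(a, keepdim=False).values

assert torch.equal(t, torch.amax(x, dim=axes))
print(t)  # tensor([11., 23.])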
@@ -0,0 +1,65 @@
from typing import Optional
import onnx
import torch
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
from . import OpRunKernel, OpRunSequence, OpRunTensor


class OpRunOpSequence(OpRunKernel):
    "Ancestor for kernel using sequences."


class ConcatFromSequence_11(OpRunOpSequence):
    "ConcatFromSequence"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        axis = self.get_attribute_int(node, "axis", None)
        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
        self.axis = axis
        self.new_axis = self.get_attribute_int(node, "new_axis", 0)

    def run(self, input_sequence: OpRunSequence) -> OpRunTensor:
        assert isinstance(
            input_sequence, OpRunSequence
        ), f"Unexpected type {type(input_sequence)} for input_sequence"
        seq = input_sequence.sequence
        if self.new_axis == 1:
            if self.axis == -1:
                seq2 = [s.unsqueeze(len(s.shape)) for s in seq]
                res = torch.cat(seq2, axis=-1)
            else:
                seq2 = [s.expand(self.axis) for s in seq]
                res = torch.cat(seq2, axis=self.axis)
        else:
            res = torch.cat(seq, axis=self.axis)
        return OpRunTensor(res)


class SequenceEmpty_11(OpRunOpSequence):
    "SequenceEmpty"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.dtype = onnx_dtype_to_torch_dtype(
            self.get_attribute_int(node, "dtype", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
        )

    def run(self) -> OpRunSequence:
        return OpRunSequence(dtype=self.dtype)


class SequenceInsert_11(OpRunOpSequence):
    "SequenceInsert"

    def run(
        self,
        input_sequence: OpRunSequence,
        tensor: OpRunTensor,
        position: Optional[OpRunTensor] = None,
    ) -> OpRunSequence:
        assert isinstance(input_sequence, OpRunSequence), (
            f"Unexpected type {type(input_sequence)} for input_sequence: "
            f"{input_sequence.string_type()}"
        )
        return input_sequence.insert_at(tensor, position)
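
For reference, ONNX ConcatFromSequence distinguishes new_axis=0 (concatenate along an existing axis) from new_axis=1 (insert a new axis before concatenating, i.e. stacking). A minimal sketch of the intended semantics with plain torch calls (illustrative, not package code):

import torch

seq = [torch.ones(2, 3), torch.zeros(2, 3)]
print(torch.cat(seq, dim=0).shape)    # new_axis=0 -> torch.Size([4, 3])
print(torch.stack(seq, dim=0).shape)  # new_axis=1 -> torch.Size([2, 2, 3])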