onnx-diagnostic 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +18 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/ext_test_case.py +3 -1
- onnx_diagnostic/helpers/args_helper.py +1 -1
- onnx_diagnostic/helpers/doc_helper.py +143 -0
- onnx_diagnostic/helpers/helper.py +6 -5
- onnx_diagnostic/helpers/model_builder_helper.py +24 -8
- onnx_diagnostic/helpers/rt_helper.py +5 -1
- onnx_diagnostic/helpers/torch_helper.py +2 -0
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/torch_evaluator.py +648 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
- onnx_diagnostic/tasks/__init__.py +22 -1
- onnx_diagnostic/tasks/image_classification.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +3 -3
- onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
- onnx_diagnostic/torch_models/test_helper.py +133 -16
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/RECORD +39 -23
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/top_level.txt +0 -0
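The bulk of this release is a new torch-backed reference evaluator: reference/torch_evaluator.py plus the reference/torch_ops/ kernel modules shown in the hunks below. A minimal usage sketch follows, assuming the evaluator is exported from onnx_diagnostic.reference as TorchOnnxEvaluator and follows the run(output_names, feeds) convention of onnx.reference.ReferenceEvaluator; neither the name nor the signature is confirmed by this diff.

    import onnx
    import onnx.helper as oh
    import torch
    from onnx_diagnostic.reference import TorchOnnxEvaluator  # name assumed, see note above

    # Tiny one-node model so every step stays visible.
    model = oh.make_model(
        oh.make_graph(
            [oh.make_node("Softmax", ["X"], ["Y"], axis=-1)],
            "demo",
            [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [2, 4])],
            [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [2, 4])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
    )
    sess = TorchOnnxEvaluator(model)               # each node runs through a torch kernel
    got = sess.run(None, {"X": torch.rand(2, 4)})  # torch tensors as feeds (assumed)
    print(got[0].sum(dim=-1))                      # every row of the softmax sums to 1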
onnx_diagnostic/reference/torch_ops/nn_ops.py
@@ -0,0 +1,196 @@
+from typing import Optional, Tuple
+import onnx
+import torch
+from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+from . import OpRunKernel, OpRunTensor
+
+
+class AveragePool_11(OpRunKernel):
+    "AveragePool"
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
+        self.ceil_mode = bool(self.get_attribute_int(node, "ceil_mode", 0))
+        self.count_include_pad = bool(self.get_attribute_int(node, "count_include_pad", 0))
+        self.dilations = self.get_attribute_ints(node, "dilations", None)
+        self.kernel_shape: Tuple[int, ...] = (
+            self.get_attribute_ints(node, "kernel_shape") or tuple()
+        )
+        self.pads = self.get_attribute_ints(node, "pads", None)
+        self.strides = self.get_attribute_ints(node, "strides", None)
+
+    def run(self, x):
+        kernel_shape = self.kernel_shape
+        dilations = self.dilations or [1 for _ in x.shape[2:]]
+        strides = self.strides or [1 for _ in x.shape[2:]]
+        pads = self.pads or ([0 for _ in x.shape[2:]] * 2)
+        assert (
+            self.auto_pad == "NOTSET"
+        ), f"AveragePool not implemented for auto_pad={self.auto_pad!r}"
+        assert len(set(pads)) == 1, f"AveragePool not implemented for pads={pads}"
+        assert set(dilations) == {1}, f"AveragePool not implemented for dilations={dilations}"
+        avg_pool = getattr(torch.nn.functional, f"avg_pool{len(kernel_shape)}d")
+        return OpRunTensor(
+            avg_pool(
+                x.tensor,
+                kernel_size=tuple(kernel_shape),
+                stride=tuple(strides),
+                padding=pads[0],
+                ceil_mode=self.ceil_mode,
+                count_include_pad=self.count_include_pad,
+                # dilation=tuple(dilations),
+            )
+        )
+
+
+class Conv_11(OpRunKernel):
+    "Conv"
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
+        self.dilations = self.get_attribute_ints(node, "dilations", None)
+        self.group = self.get_attribute_int(node, "group", 1)
+        self.kernel_shape: Tuple[int, ...] = (
+            self.get_attribute_ints(node, "kernel_shape") or tuple()
+        )
+        self.pads = self.get_attribute_ints(node, "pads", None)
+        self.strides = self.get_attribute_ints(node, "strides", None)
+
+    def run(self, x, w, b=None):
+        kernel_shape = self.kernel_shape or w.shape[2:]
+        assert (
+            tuple(kernel_shape) == w.shape[-len(kernel_shape) :]
+        ), f"conv not implemented for kernel_shape={kernel_shape} and w.shape={w.shape}"
+        dilations = self.dilations or [1 for _ in x.shape[2:]]
+        strides = self.strides or [1 for _ in x.shape[2:]]
+
+        if self.auto_pad in {"SAME_LOWER", "SAME_UPPER"}:
+            head = []
+            tail = []
+            for i in range(len(x.shape) - 2):
+                d = x.shape[i + 2]
+                target_size = (d + strides[i] - 1) // strides[i]
+                pad_needed = (target_size - 1) * strides[i] + kernel_shape[i] - d
+                pad_head = (
+                    (pad_needed + 1) // 2 if self.auto_pad == "SAME_LOWER" else pad_needed // 2
+                )
+                pad_tail = pad_needed - pad_head
+                head.append(pad_head)
+                tail.append(pad_tail)
+            pads = head + tail
+        else:
+            pads = self.pads or ([0 for _ in x.shape[2:]] * 2)
+
+        assert len(set(pads)) == 1, (
+            f"conv not implemented for pads={pads}, "
+            f"auto_pad={self.auto_pad!r}, strides={strides}, "
+            f"x.shape={x.shape}, kernel_shape={kernel_shape}"
+        )
+
+        if b is None:
+            bias = None
+        else:
+            bias = b.tensor.squeeze()
+            if not bias.shape:
+                bias = bias.unsqueeze(0)
+        return OpRunTensor(
+            torch.nn.functional.conv2d(
+                x.tensor,
+                w.tensor,
+                bias=bias,
+                stride=tuple(strides),
+                padding=pads[0],
+                dilation=tuple(dilations),
+                groups=self.group,
+            )
+        )
+
+
+class LayerNormalization_17(OpRunKernel):
+    "LayerNormalization"
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.axis = self.get_attribute_int(node, "axis", -1)
+        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
+        self.stash_type = onnx_dtype_to_torch_dtype(
+            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
+        )
+        self.compute_std = len(node.output) > 1
+
+    def run(self, x, scale, bias=None):
+        original_dtype = x.dtype
+        if self.stash_type == torch.float32 and x.tensor.dtype != torch.float64:
+            xt = x.tensor
+            res = torch.nn.functional.layer_norm(
+                xt,
+                xt.shape[self.axis :],
+                weight=scale.tensor,
+                bias=None if bias is None else bias.tensor,
+                eps=self.epsilon,
+            )
+        else:
+            xt = x.tensor.to(self.stash_type)
+            res = torch.nn.functional.layer_norm(
+                xt,
+                xt.shape[self.axis :],
+                weight=scale.tensor.to(self.stash_type),
+                bias=None if bias is None else bias.tensor.to(self.stash_type),
+                eps=self.epsilon,
+            )
+        if not self.compute_std:
+            return OpRunTensor(res.to(original_dtype))
+        axes = tuple(range(len(xt.shape)))[self.axis :]
+        var, mean = torch.var_mean(xt, dim=axes, keepdim=False)
+        x_inv_std_dev = torch.reciprocal(torch.sqrt(var + self.epsilon))
+        return (
+            OpRunTensor(res.to(original_dtype)),
+            OpRunTensor(mean),
+            OpRunTensor(x_inv_std_dev),
+        )
+
+
+class Softmax_13(OpRunKernel):
+    "Softmax"
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.axis = self.get_attribute_int(node, "axis", -1)
+        assert isinstance(self.axis, int), f"Unexpected value for attribute axis={self.axis!r}"
+        # this is out of spec
+        stash_type = self.get_attribute_int(node, "stash_type", None)
+        self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)
+
+    def run(self, data: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(
+            torch.nn.functional.softmax(data.tensor, dim=self.axis, dtype=self.stash_type)
+        )
+
+
+class Tanh_6(OpRunKernel):
+    "Tanh"
+
+    def run(self, data: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(torch.nn.functional.tanh(data.tensor))
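A kernel from the hunk above can also be driven directly. The sketch below only relies on signatures visible in the diff (the (node, version, verbose) constructor and the OpRunTensor wrapper); the module path torch_ops/nn_ops.py is inferred from the +196 line count in the file list, and the behaviour of the OpRunKernel base class beyond what is shown is an assumption.

    import onnx.helper as oh
    import torch
    from onnx_diagnostic.reference.torch_ops import OpRunTensor
    from onnx_diagnostic.reference.torch_ops.nn_ops import LayerNormalization_17

    node = oh.make_node("LayerNormalization", ["X", "scale"], ["Y"], axis=-1, epsilon=1e-5)
    kernel = LayerNormalization_17(node, version=17)
    x = OpRunTensor(torch.randn(2, 3, 8))
    scale = OpRunTensor(torch.ones(8))
    y = kernel.run(x, scale)    # one output because the node declares a single output
    print(y.tensor.mean(-1))    # close to zero along the normalized axis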
onnx_diagnostic/reference/torch_ops/other_ops.py
@@ -0,0 +1,106 @@
+from typing import Optional
+import onnx
+import torch
+from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+from . import OpRunKernel, OpRunTensor
+
+
+class Cast_6(OpRunKernel):
+    "Cast"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        to = self.get_attribute_int(node, "to", 0)
+        assert isinstance(to, int), f"Unexpected value for attribute to={to!r}"
+        self.to = onnx_dtype_to_torch_dtype(to)
+        self.saturate = self.get_attribute_int(node, "saturate", 1)
+        assert self.saturate == 1, f"saturate={self.saturate} not implemented for Cast"
+
+    def run(self, data: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(data.tensor.to(self.to))
+
+
+class CastLike_15(OpRunKernel):
+    "CastLike"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.saturate = self.get_attribute_int(node, "saturate", 1)
+        assert self.saturate == 1, f"saturate={self.saturate} not implemented for CastLike"
+
+    def run(self, data: OpRunTensor, like: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(data.tensor.to(like.tensor.dtype))
+
+
+class Concat_1(OpRunKernel):
+    "Concat"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        axis = self.get_attribute_int(node, "axis", 0)
+        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
+        self.axis = axis
+
+    def run(self, *data: OpRunTensor) -> OpRunTensor:
+        assert data, f"No tensor to concatenate in node name {self.name!r}"
+        devices = [d.get_device() for d in data]
+        if len(set(devices)) == 1:
+            return OpRunTensor(torch.cat([t.tensor for t in data], axis=self.axis))
+        if (
+            data[0].dtype == torch.int64
+            and self.axis == 0
+            and max(d.tensor.ndim for d in data) == 1
+            and max(d.tensor.numel() for d in data) <= 8
+        ):
+            # This is a shape
+            return OpRunTensor(torch.cat([t.tensor.cpu() for t in data], axis=self.axis))
+        index = devices.index(max(devices))
+        device = data[index].tensor.device
+        return OpRunTensor(torch.cat([t.tensor.to(device) for t in data], axis=self.axis))
+
+
+class NonZero_13(OpRunKernel):
+    "NonZero"
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(torch.nonzero(x.tensor).T)
+
+
+class Tile_6(OpRunKernel):
+    "Tile"
+
+    def run(self, x: OpRunTensor, repeat: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(torch.tile(x.tensor, repeat.as_tuple_int))
+
+
+class Transpose_1(OpRunKernel):
+    "Transpose"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.perm = self.get_attribute_ints(node, "perm", None)
+
+    def run(self, data: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(torch.permute(data.tensor, self.perm))
+
+
+class Trilu_14(OpRunKernel):
+    "Trilu"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.upper = self.get_attribute_int(node, "upper", 1)
+
+    def run(self, data: OpRunTensor, k: Optional[OpRunTensor] = None) -> OpRunTensor:
+        diagonal = 0 if k is None else k.tensor.item()
+        if self.upper:
+            return OpRunTensor(torch.triu(data.tensor, diagonal=diagonal))
+        return OpRunTensor(torch.tril(data.tensor, diagonal=diagonal))
+
+
+class Where_9(OpRunKernel):
+    "Where"
+
+    def run(self, cond: OpRunTensor, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        tcond, tx, ty = self.same_device(cond.tensor, x.tensor, y.tensor)
+        return OpRunTensor(torch.where(tcond, tx, ty))
onnx_diagnostic/reference/torch_ops/reduce_ops.py
@@ -0,0 +1,130 @@
+from typing import Optional, Tuple
+import onnx
+import torch
+from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+from . import OpRunKernel, OpRunTensor
+
+
+class ReduceOp(OpRunKernel):
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.keepdims = bool(self.get_attribute_int(node, "keepdims", 1))
+        self.noop_with_empty_axes = bool(
+            self.get_attribute_int(node, "noop_with_empty_axes", 0)
+        )
+        assert isinstance(
+            self.keepdims, bool
+        ), f"Unexpected value for attribute keepdims={self.keepdims!r}"
+        assert isinstance(self.noop_with_empty_axes, bool), (
+            f"Unexpected value for attribute "
+            f"noop_with_empty_axes={self.noop_with_empty_axes!r}"
+        )
+        assert (
+            not self.noop_with_empty_axes
+        ), f"Not implemented with noop_with_empty_axes={self.noop_with_empty_axes}"
+        # this is out of spec
+        stash_type = self.get_attribute_int(node, "stash_type", None)
+        self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)
+
+
+class ReduceOpAxes(ReduceOp):
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.axes: Tuple[int, ...] = self.get_attribute_ints(node, "axes") or tuple()
+
+
+class ReduceMax_18(ReduceOp):
+    """ReduceMax"""
+
+    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+        if axes is None:
+            assert (
+                not self.keepdims
+            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+            return OpRunTensor(x.tensor.max())
+        taxes = axes.as_tuple_int
+        if len(taxes) == 1:
+            t = x.tensor.max(taxes[0], keepdim=self.keepdims)
+            return OpRunTensor(t.values)
+        t = x.tensor
+        for a in reversed(taxes):
+            t = t.max(a, keepdim=self.keepdims).values
+        return OpRunTensor(t)
+
+
+class ReduceMean_18(ReduceOp):
+    """ReduceMean"""
+
+    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+        if axes is None:
+            assert (
+                not self.keepdims
+            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+            return OpRunTensor(torch.mean(x.tensor))
+        taxes = axes.as_tuple_int
+        if len(taxes) == 1:
+            t = x.tensor.mean(taxes[0], keepdim=self.keepdims)
+            return OpRunTensor(t)
+        t = x.tensor.mean(taxes, keepdim=self.keepdims)
+        return OpRunTensor(t)
+
+
+class ReduceMin_17(ReduceOpAxes):
+    """ReduceMin"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+        axes = self.axes
+        if not axes:
+            assert (
+                not self.keepdims
+            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+            return OpRunTensor(x.tensor.min())
+        taxes = tuple(axes)
+        if len(taxes) == 1:
+            t = x.tensor.min(taxes[0], keepdim=self.keepdims)
+            return OpRunTensor(t.values)
+        t = x.tensor
+        for a in reversed(taxes):
+            t = t.min(a, keepdim=self.keepdims).values
+        return OpRunTensor(t)
+
+
+class ReduceMin_18(ReduceOp):
+    """ReduceMin"""
+
+    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+        if axes is None:
+            assert (
+                not self.keepdims
+            ), f"axes is empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+            return OpRunTensor(torch.min(x.tensor))
+        taxes = axes.as_tuple_int
+        if len(taxes) == 1:
+            t = x.tensor.min(taxes[0], keepdim=self.keepdims)
+            return OpRunTensor(t.values)
+        t = x.tensor
+        for a in reversed(taxes):
+            t = t.min(a, keepdim=self.keepdims).values
+        return OpRunTensor(t)
+
+
+class ReduceSum_13(ReduceOp):
+    """ReduceSum"""
+
+    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+        if axes is None:
+            assert (
+                not self.keepdims
+            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+            return OpRunTensor(torch.sum(x.tensor))
+        taxes = axes.as_tuple_int
+        if len(taxes) == 1:
+            t = x.tensor.sum(taxes[0], keepdim=self.keepdims)
+            return OpRunTensor(t)
+        t = x.tensor.sum(taxes, keepdim=self.keepdims)
+        return OpRunTensor(t)
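ReduceSum_13 above takes the reduction axes as a second runtime input (the opset-13 convention) rather than as an attribute. A small check of the keepdims path, under the same assumptions as the earlier sketch (module path reduce_ops.py inferred from the +130 line count in the file list):

    import onnx.helper as oh
    import torch
    from onnx_diagnostic.reference.torch_ops import OpRunTensor
    from onnx_diagnostic.reference.torch_ops.reduce_ops import ReduceSum_13

    node = oh.make_node("ReduceSum", ["X", "axes"], ["Y"], keepdims=1)
    kernel = ReduceSum_13(node, version=13)
    x = OpRunTensor(torch.arange(6, dtype=torch.float32).reshape(2, 3))
    axes = OpRunTensor(torch.tensor([1], dtype=torch.int64))
    print(kernel.run(x, axes).tensor)  # tensor([[ 3.], [12.]])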
onnx_diagnostic/reference/torch_ops/sequence_ops.py
@@ -0,0 +1,65 @@
+from typing import Optional
+import onnx
+import torch
+from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+from . import OpRunKernel, OpRunSequence, OpRunTensor
+
+
+class OpRunOpSequence(OpRunKernel):
+    "Ancestor for kernels using sequences."
+
+
+class ConcatFromSequence_11(OpRunOpSequence):
+    "ConcatFromSequence"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        axis = self.get_attribute_int(node, "axis", None)
+        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
+        self.axis = axis
+        self.new_axis = self.get_attribute_int(node, "new_axis", 0)
+
+    def run(self, input_sequence: OpRunSequence) -> OpRunTensor:
+        assert isinstance(
+            input_sequence, OpRunSequence
+        ), f"Unexpected type {type(input_sequence)} for input_sequence"
+        seq = input_sequence.sequence
+        if self.new_axis == 1:
+            if self.axis == -1:
+                seq2 = [s.unsqueeze(len(s.shape)) for s in seq]
+                res = torch.cat(seq2, axis=-1)
+            else:
+                seq2 = [s.expand(self.axis) for s in seq]
+                res = torch.cat(seq2, axis=self.axis)
+        else:
+            res = torch.cat(seq, axis=self.axis)
+        return OpRunTensor(res)
+
+
+class SequenceEmpty_11(OpRunOpSequence):
+    "SequenceEmpty"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.dtype = onnx_dtype_to_torch_dtype(
+            self.get_attribute_int(node, "dtype", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
+        )
+
+    def run(self) -> OpRunSequence:
+        return OpRunSequence(dtype=self.dtype)
+
+
+class SequenceInsert_11(OpRunOpSequence):
+    "SequenceInsert"
+
+    def run(
+        self,
+        input_sequence: OpRunSequence,
+        tensor: OpRunTensor,
+        position: Optional[OpRunTensor] = None,
+    ) -> OpRunSequence:
+        assert isinstance(input_sequence, OpRunSequence), (
+            f"Unexpected type {type(input_sequence)} for input_sequence: "
+            f"{input_sequence.string_type()}"
+        )
+        return input_sequence.insert_at(tensor, position)
onnx_diagnostic/reference/torch_ops/shape_ops.py
@@ -0,0 +1,121 @@
+from typing import Optional, Tuple
+import onnx
+import torch
+from . import OpRunKernel, OpRunTensor
+
+
+class ConstantOfShape_9(OpRunKernel):
+    "ConstantOfShape"
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        """
+        Returns True if the kernel needs a device to be efficiently initialized.
+        """
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        device: Optional[torch.device] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        value = self.get_attribute_tensor(node, "value")
+        if value is None:
+            value = torch.tensor([0], dtype=torch.float32)
+        self.dtype = value.dtype
+        self.device = device
+        self.value = value[0]
+
+    def run(self, shape: OpRunTensor) -> OpRunTensor:
+        # The device is unknown as shapes usually take place on CPU.
+        return OpRunTensor(
+            torch.full(
+                shape.as_tuple_int, fill_value=self.value, dtype=self.dtype, device=self.device
+            )
+        )
+
+
+class Expand_8(OpRunKernel):
+    "Expand"
+
+    def run(self, data: OpRunTensor, shape: OpRunTensor) -> OpRunTensor:
+        ishape = tuple(-1 if i == 1 else i for i in shape.as_tuple_int)
+        return OpRunTensor(data.tensor.expand(ishape))
+
+
+class Reshape_14(OpRunKernel):
+    "Reshape"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.allowzero = self.get_attribute_int(node, "allowzero", 0)
+
+    def run(self, data: OpRunTensor, shape: OpRunTensor) -> OpRunTensor:
+        ishape = shape.as_tuple_int
+        assert ishape is not None, f"Unexpected return for shape={shape!r}"
+        if not self.allowzero and 0 in ishape:
+            xshape = data.tensor.shape
+            new_shape = []
+            for i, s in enumerate(ishape):
+                new_shape.append(xshape[i] if s == 0 else s)
+            return OpRunTensor(data.tensor.reshape(new_shape))
+        return OpRunTensor(data.tensor.reshape(ishape))
+
+
+class Shape_15(OpRunKernel):
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.start = self.get_attribute_int(node, "start", 0)
+        self.end = self.get_attribute_int(node, "end", None)
+
+    def run(self, data: OpRunTensor) -> OpRunTensor:
+        shape = data.shape
+        sh = shape[self.start :] if self.end is None else shape[self.start : self.end]
+        return OpRunTensor(torch.tensor(sh, dtype=torch.int64), is_constant=True)
+
+
+class Split_18(OpRunKernel):
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
+        super().__init__(node, version, verbose=verbose)
+        self.axis = self.get_attribute_int(node, "axis", 0)
+        self.num_outputs = self.get_attribute_int(node, "num_outputs", None)
+
+    def run(
+        self, data: OpRunTensor, split: Optional[OpRunTensor] = None
+    ) -> Tuple[OpRunTensor, ...]:
+        if split is None:
+            assert isinstance(
+                self.num_outputs, int
+            ), f"Incompatibilities: split is None and num_outputs={self.num_outputs}"
+            size = data.tensor.shape[self.axis]
+            split_size = (
+                size // self.num_outputs
+                if size % self.num_outputs == 0
+                else size // self.num_outputs + 1
+            )
+            spl = torch.split(data.tensor, split_size, dim=self.axis)
+        else:
+            spl = torch.split(data.tensor, split.as_tuple_int, dim=self.axis)
+        return tuple(OpRunTensor(t) for t in spl)
+
+
+class Squeeze_13(OpRunKernel):
+    "Squeeze"
+
+    def run(self, data: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+        if axes is None:
+            return OpRunTensor(data.tensor.squeeze())
+        return OpRunTensor(data.tensor.squeeze(axes.as_tuple_int))
+
+
+class Unsqueeze_13(OpRunKernel):
+    "Unsqueeze"
+
+    def run(self, data: OpRunTensor, axes: OpRunTensor) -> OpRunTensor:
+        t = data.tensor
+        for i in axes.as_tuple_int:
+            t = t.unsqueeze(i)
+        return OpRunTensor(t)