onnx_diagnostic-0.6.0-py3-none-any.whl → onnx_diagnostic-0.6.2-py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +18 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/ext_test_case.py +3 -1
- onnx_diagnostic/helpers/args_helper.py +1 -1
- onnx_diagnostic/helpers/doc_helper.py +143 -0
- onnx_diagnostic/helpers/helper.py +6 -5
- onnx_diagnostic/helpers/model_builder_helper.py +24 -8
- onnx_diagnostic/helpers/rt_helper.py +5 -1
- onnx_diagnostic/helpers/torch_helper.py +2 -0
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/torch_evaluator.py +648 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
- onnx_diagnostic/tasks/__init__.py +22 -1
- onnx_diagnostic/tasks/image_classification.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +3 -3
- onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
- onnx_diagnostic/torch_models/test_helper.py +133 -16
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/RECORD +39 -23
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/top_level.txt +0 -0
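
The headline addition in 0.6.2 is a torch-backed reference evaluator (onnx_diagnostic/reference/torch_evaluator.py, +648) together with the torch_ops kernels shown in the hunks below. As a rough orientation, here is a minimal sketch of how the evaluator might be driven. Only the names TorchOnnxEvaluator, OpRunTensor and run_with_values appear in the visible hunks; the exact constructor and return value are assumptions inferred from how the control-flow kernels instantiate nested evaluators, not a verified public API.

# Hypothetical usage sketch; the TorchOnnxEvaluator constructor and the shape of
# the run_with_values result are assumptions, see torch_evaluator.py for the real API.
import onnx
import onnx.helper as oh
import torch

from onnx_diagnostic.reference import TorchOnnxEvaluator
from onnx_diagnostic.reference.torch_ops import OpRunTensor

# A tiny model: Z = X + Y
model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Add", ["X", "Y"], ["Z"])],
        "demo",
        [
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [2]),
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [2]),
        ],
        [oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [2])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

rt = TorchOnnxEvaluator(model, verbose=1)  # assumed: accepts a ModelProto like it accepts subgraphs
# run_with_values takes and returns OpRunValue wrappers (see _op_run.py below).
out = rt.run_with_values(
    OpRunTensor(torch.tensor([1.0, 2.0])),
    OpRunTensor(torch.tensor([3.0, 4.0])),
)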
onnx_diagnostic/reference/torch_ops/_op_run.py
@@ -0,0 +1,335 @@
+from typing import Any, Dict, List, Optional, Union, Tuple
+import onnx
+import torch
+from ...api import TensorLike
+from ...helpers import string_type
+from ...helpers.torch_helper import to_tensor
+
+
+class OpRunValue(TensorLike):
+    """Defines a value for the runtime, a tensor or a sequence."""
+
+    __slots__ = ("cached", "is_constant", "sequence", "tensor")
+
+    @classmethod
+    def is_sequence(cls) -> bool:
+        "Tells if it is sequence."
+        raise NotImplementedError("is_sequence must be overwritten.")
+
+
+class OpRunTensor(OpRunValue):
+    """
+    Wrapper around a tensor.
+
+    :param tensor: torch.Tensor
+    :param is_constant: is it a constant
+    :param may_cpu: change the device the tensor is if
+        more appropriate
+    """
+
+    def __init__(self, tensor, is_constant: bool = False, may_cpu: bool = False):
+        assert isinstance(tensor, torch.Tensor), (
+            f"Unexpected type {type(tensor)}, "
+            f"__name__={getattr(tensor, '__name__', 'no name')}"
+        )
+        assert tensor is None or tensor.numel() != 1 or tensor.item() != -666666
+        self.tensor = (
+            tensor.cpu()
+            if may_cpu
+            and len(tensor.shape) == 1
+            and tensor.numel() < 8
+            and tensor.dtype == torch.int64
+            and tensor.get_device() >= 0
+            else tensor
+        )
+        self.is_constant = is_constant
+        self.cached: Optional[Tuple[int, ...]] = None
+
+    @classmethod
+    def is_sequence(cls) -> bool:
+        "Tells if it is sequence."
+        return False
+
+    def to(self, to: Any) -> "OpRunTensor":
+        "Changes the device."
+        return OpRunTensor(self.tensor.to(to))
+
+    def string_type(self) -> str:
+        "Returns information about the value as a string."
+        s = string_type(self.tensor, with_shape=True, with_min_max=True, with_device=True)
+        if self.is_constant:
+            return f"CST({s})"
+        return s
+
+    def __repr__(self) -> str:
+        "usual"
+        if self.is_constant:
+            return (
+                f"{self.__class__.__name__}"
+                f"({string_type(self.tensor, with_shape=True)}, is_constant=True)"
+            )
+        return f"{self.__class__.__name__}({string_type(self.tensor, with_shape=True)})"
+
+    @property
+    def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
+        "Returns either a tensor or a sequence."
+        return self.tensor
+
+    @property
+    def shape(self):
+        "shape (torch shape)"
+        return self.tensor.shape
+
+    @property
+    def dtype(self):
+        "dtype (torch dtype)"
+        return self.tensor.dtype
+
+    def _tensor_as_tuple_int(self) -> Tuple[int, ...]:
+        return tuple(map(int, self.tensor))
+
+    def numel(self) -> int:
+        "Returns the number of elements."
+        return 0 if self.tensor is None else self.tensor.numel()
+
+    def get_device(self) -> int:
+        "Returns the device id."
+        return -1 if self.tensor is None else self.tensor.get_device()
+
+    @property
+    def device(self):
+        "Returns the device."
+        return -1 if self.tensor is None else self.tensor.device
+
+    @property
+    def as_tuple_int(self) -> Tuple[int, ...]:
+        "value as int"
+        if self.is_constant:
+            if self.cached is None:
+                self.cached = self._tensor_as_tuple_int()
+            return self.cached
+        return self._tensor_as_tuple_int()
+
+    def copy(self) -> "OpRunTensor":
+        "Shallow copy."
+        return self.__class__(self.tensor)
+
+
+class OpRunSequence(OpRunValue):
+    """Defines a sequence."""
+
+    def __init__(
+        self, sequence: Optional[List[torch.Tensor]] = None, dtype: torch.dtype = torch.float32
+    ):
+        self.tensor = torch.tensor(-666666, dtype=dtype)
+        self.is_shape = False
+        self.sequence = sequence or []
+        self.cached: Optional[Tuple[int, ...]] = None
+        assert all(
+            isinstance(s, torch.Tensor) for s in self.sequence
+        ), f"Unexpected type in sequence {[type(s) for s in self.sequence]}"
+
+    @property
+    def dtype(self):
+        "dtype (torch dtype)"
+        return self.tensor.dtype
+
+    @property
+    def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
+        "Returns either a tensor or a sequence."
+        return self.sequence
+
+    @classmethod
+    def is_sequence(cls) -> bool:
+        "Tells if it is sequence."
+        return True
+
+    def insert_at(
+        self, tensor: torch.Tensor, position: Optional[OpRunTensor] = None
+    ) -> "OpRunSequence":
+        "Inserts a value at a given position."
+        assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
+        new_seq = OpRunSequence()
+        seq = self.sequence.copy()
+        new_seq.sequence = seq
+        if position is None:
+            seq.append(tensor.tensor)
+        else:
+            seq.insert(int(position.tensor.item()), tensor.tensor)
+        return new_seq
+
+    def copy(self) -> "OpRunSequence":
+        "Shallow copy."
+        return self.__class__(self.sequence, dtype=self.dtype)
+
+    def string_type(self) -> str:
+        "Returns a string which can be printed."
+        return string_type(self.sequence, with_shape=True)
+
+
+class OpRunKernel:
+    """
+    Main class. Every kernel should inherit from it.
+    It does not copy the proto.
+    """
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        """
+        Returns True if the kernel needs a device to be efficiently initialized.
+        """
+        return False
+
+    @classmethod
+    def has_subgraphs(cls) -> bool:
+        """Returns True if the kernel has subgraphs."""
+        return False
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        verbose: int = 0,
+        custom_kernels: Optional[Dict[Tuple[str, str], type]] = None,
+    ):
+        assert isinstance(
+            node, onnx.NodeProto
+        ), f"node must be a NodeProto but node is {type(node)}"
+        self.op_type = node.op_type
+        self.domain = node.domain
+        self.input = node.input
+        self.output = node.output
+        self.verbose = verbose
+        self.custom_kernels = custom_kernels
+        if version is None:
+            name = self.__class__.__name__.split("_")
+            assert (
+                len(name) == 2
+            ), f"Cannot guess version from name={self.__class__.__name__!r}"
+            version = int(name[1])
+        self.version = version
+        self.name = node.name
+
+    def __str__(self) -> str:
+        "usual"
+        if self.domain:
+            return (
+                f"{self.op_type}[{self.domain}]({', '.join(self.input)}) "
+                f"-> {', '.join(self.output)}"
+            )
+        return f"{self.op_type}({', '.join(self.input)}) -> {', '.join(self.output)}"
+
+    def run(
+        self, *args: Optional[OpRunValue]
+    ) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
+        "Kernel implementation."
+        raise NotImplementedError(
+            f"Method run is not implemented for kernel {self.__class__.__name__!r}"
+        )
+
+    def _find_attribute(self, node: onnx.NodeProto, name: str):
+        for att in node.attribute:
+            if att.name == name:
+                return att
+        return None
+
+    def get_attribute_float(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[float] = None
+    ) -> Optional[float]:
+        """
+        Returns an attribute as an int.
+
+        :param node: NodeProto
+        :param name: name
+        :param default_value: default_value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else float(att.f)
+
+    def get_attribute_int(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[int] = None
+    ) -> Optional[int]:
+        """
+        Returns an attribute as an int.
+
+        :param node: NodeProto
+        :param name: name
+        :param default_value: default_value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else int(att.i)
+
+    def get_attribute_ints(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[Tuple[int, ...]] = None
+    ) -> Optional[Tuple[int, ...]]:
+        """
+        Returns an attribute as a tuple of ints.
+
+        :param node: NodeProto
+        :param name: name
+        :param default_value: default_value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else tuple(map(int, att.ints))
+
+    def get_attribute_string(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[str] = None
+    ) -> Optional[str]:
+        """
+        Returns an attribute as a tuple of ints.
+
+        :param node: NodeProto
+        :param name: name
+        :param default_value: default_value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else att.s.decode("utf-8")
+
+    def get_attribute_tensor(self, node: onnx.NodeProto, name: str) -> Optional[torch.Tensor]:
+        """
+        Returns an attribute as a torch tensor.
+
+        :param node: NodeProto
+        :param name: name
+        :param default_value: default_value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        if att is None:
+            return None
+        return to_tensor(att.t)
+
+    def same_device(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        """Puts all tensors on the same device."""
+        devices = [t.get_device() for t in tensors]
+        if len(set(devices)) == 1:
+            return tuple(tensors)
+        index = devices.index(max(devices))
+        device = tensors[index].device
+        return tuple(t.to(device) for t in tensors)
+
+
+class OpRunFunction(OpRunKernel):
+    """
+    Defines a kernel based on a local functions.
+    """
+
+    def __init__(
+        self,
+        runtime: "onnx_diagnostic.reference.TorchOnnxEvaluator",  # noqa: F821
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.runtime = runtime
+        self.input_names = runtime.input_names
+
+    def run(
+        self, *args: Optional[OpRunValue]
+    ) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
+        return self.runtime.run_with_values(*args)
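
OpRunKernel fixes the contract every kernel in torch_ops follows: the class name encodes the operator type and the opset version it implements (Add_1, Pow_12, ...), attributes are read through the get_attribute_* helpers, and run consumes and returns the OpRunTensor/OpRunSequence wrappers defined above. Below is a minimal sketch of a custom kernel written against that contract; the Neg_1 class is hypothetical, and the (domain, op_type) key layout for custom_kernels is an assumption based only on the Dict[Tuple[str, str], type] annotation in OpRunKernel.__init__.

import torch
from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor

class Neg_1(OpRunKernel):
    "Hypothetical kernel for Neg (opset 1), following the OpRunKernel contract."

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # Inputs and outputs are OpRunTensor wrappers around torch tensors.
        return OpRunTensor(-x.tensor)

# custom_kernels is typed Dict[Tuple[str, str], type] in OpRunKernel.__init__;
# the ("", "Neg") key layout (domain, op_type) is an assumption.
custom = {("", "Neg"): Neg_1}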
onnx_diagnostic/reference/torch_ops/access_ops.py
@@ -0,0 +1,94 @@
+from typing import Optional
+import onnx
+import torch
+from . import OpRunKernel, OpRunTensor
+
+
+class Gather_1(OpRunKernel):
+    "Gather"
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        axis = self.get_attribute_int(node, "axis", 0)
+        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
+        self.axis = axis
+
+    def run(self, x, indices):
+        if indices.tensor.numel() == 0:
+            return torch.empty((0,), dtype=x.tensor.dtype, device=x.tensor.device)
+        ind = [slice(0, s) for s in x.shape]
+        ind[self.axis] = indices.tensor
+        return OpRunTensor(x.tensor[tuple(ind)])
+
+
+class ScatterND_16(OpRunKernel):
+    "ScatterND"
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.reduction = self.get_attribute_string(node, "reduction", "none")
+
+    def run(
+        self, data: OpRunTensor, indices: OpRunTensor, updates: OpRunTensor
+    ) -> OpRunTensor:
+        # This implementation is not efficient.
+        grids = torch.meshgrid(*[torch.arange(s) for s in indices.shape[:-1]], indexing="ij")
+        stacked = torch.stack(grids, dim=-1)
+        index = stacked.reshape(-1, len(indices.shape) - 1)
+        output = data.tensor.clone()
+        for i in index:
+            if self.reduction == "add":
+                output[indices.tensor[i]] += updates.tensor[i]
+            elif self.reduction == "mul":
+                output[indices.tensor[i]] *= updates.tensor[i]
+            elif self.reduction == "max":
+                output[indices.tensor[i]] = torch.maximum(
+                    output[indices.tensor[i]], updates.tensor[i]
+                )
+            elif self.reduction == "min":
+                output[indices.tensor[i]] = torch.minimum(
+                    output[indices.tensor[i]], updates.tensor[i]
+                )
+            else:
+                output[indices.tensor[i]] = updates.tensor[i]
+        return OpRunTensor(output)
+
+
+class Slice_13(OpRunKernel):
+    "Slice"
+
+    def run(
+        self,
+        data: OpRunTensor,
+        starts: OpRunTensor,
+        ends: OpRunTensor,
+        axes: Optional[OpRunTensor] = None,
+        steps: Optional[OpRunTensor] = None,
+    ) -> OpRunTensor:
+        if axes is None:
+            if steps is None:
+                slices = [slice(s, e) for s, e in zip(starts.tensor, ends.tensor)]
+            else:
+                slices = [
+                    slice(s, e, d) for s, e, d in zip(starts.tensor, ends.tensor, steps.tensor)
+                ]
+        else:
+            if steps is None:
+                slices = [slice(0, a) for a in data.shape]
+                for s, e, a in zip(starts.tensor, ends.tensor, axes.tensor):
+                    slices[a] = slice(s, e)
+            else:
+                slices = [slice(0, a) for a in data.shape]
+                for s, e, a, d in zip(starts.tensor, ends.tensor, axes.tensor, steps.tensor):
+                    slices[a] = slice(s, e, d)
+        return OpRunTensor(data.tensor[tuple(slices)])
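
Each kernel can also be exercised on its own from a bare NodeProto, which is convenient when debugging a single operator. A small sketch using Gather_1 from the hunk above (it assumes only the signatures shown there; onnx.helper.make_node is the standard onnx API):

import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.access_ops import Gather_1

node = oh.make_node("Gather", ["X", "I"], ["Y"], axis=0)
kernel = Gather_1(node)  # version defaults to the suffix of the class name (1)
out = kernel.run(
    OpRunTensor(torch.arange(6).reshape(3, 2)),
    OpRunTensor(torch.tensor([2, 0])),
)
print(out.string_type())  # OpRunTensor wrapping the gathered rows [2] and [0]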
onnx_diagnostic/reference/torch_ops/binary_ops.py
@@ -0,0 +1,108 @@
+import torch
+from . import OpRunKernel, OpRunTensor
+
+
+class OpRunBinary(OpRunKernel):
+    "Binary Op"
+
+    def run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        if x.get_device() != y.get_device():
+            if x.get_device() >= 0:
+                y = y.to(x.device)
+            else:
+                x = x.to(y.device)
+        return self._run(x, y)
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        raise NotImplementedError(f"Operator {self.__class__.__name__!r} is not complete.")
+
+
+class And_1(OpRunBinary):
+    """And"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor & y.tensor)
+
+
+class Add_1(OpRunBinary):
+    """Add"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor + y.tensor)
+
+
+class Div_1(OpRunBinary):
+    """Div"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor / y.tensor)
+
+
+class Equal_1(OpRunBinary):
+    """Equal"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor == y.tensor)
+
+
+class Greater_1(OpRunBinary):
+    """Greater"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor > y.tensor)
+
+
+class GreaterOrEqual_1(OpRunBinary):
+    """GreaterOrEqual"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor >= y.tensor)
+
+
+class Less_1(OpRunBinary):
+    """Less"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor < y.tensor)
+
+
+class LessOrEqual_1(OpRunBinary):
+    """LessOrEqual"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor <= y.tensor)
+
+
+class MatMul_1(OpRunBinary):
+    """MatMul"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor @ y.tensor)
+
+
+class Mul_1(OpRunBinary):
+    """Mul"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor * y.tensor)
+
+
+class Or_1(OpRunBinary):
+    """Or"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor | y.tensor)
+
+
+class Pow_12(OpRunBinary):
+    """Pow"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(torch.pow(x.tensor, y.tensor))
+
+
+class Sub_1(OpRunBinary):
+    """Sub"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor - y.tensor)
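
OpRunBinary.run aligns devices before dispatching to _run: when the two operands live on different devices, the CPU-side one is moved onto the other operand's device. A short sketch with Add_1 (the CUDA branch is only meaningful when a GPU is available; everything else follows the signatures in the hunk above):

import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.binary_ops import Add_1

kernel = Add_1(oh.make_node("Add", ["A", "B"], ["C"]))
a = OpRunTensor(torch.tensor([1.0, 2.0]))
b = OpRunTensor(torch.tensor([10.0, 20.0]))
if torch.cuda.is_available():
    b = b.to("cuda")  # run() will move `a` onto the same CUDA device before _run
c = kernel.run(a, b)
print(c.tensor)  # tensor([11., 22.])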
onnx_diagnostic/reference/torch_ops/controlflow_ops.py
@@ -0,0 +1,121 @@
+from typing import Any, Dict, Optional
+import onnx
+import torch
+from . import OpRunKernel, OpRunTensor
+
+
+class OpRunControlFlow(OpRunKernel):
+    """Common ancestor for control flows."""
+
+    @classmethod
+    def has_subgraphs(cls) -> bool:
+        """Returns True if the kernel has subgraphs."""
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        assert (
+            parent is not None
+        ), f"parent must be specified for operator {self.__class__.__name__!r}"
+        for att in node.attribute:
+            if att.type == onnx.AttributeProto.GRAPH:
+                rt = parent.__class__(
+                    att.g,
+                    providers=parent.providers,
+                    opsets=parent.opsets,
+                    local_functions=parent.functions,
+                    verbose=parent.verbose,
+                    custom_kernels=parent.custom_kernels,
+                )
+                setattr(self, att.name, rt)
+
+
+class If_1(OpRunControlFlow):
+    "If"
+
+    def run(self, cond, context: Optional[Dict[str, Any]] = None):
+        rt = self.then_branch if cond.tensor.item() else self.else_branch  # type: ignore[attr-defined]
+        return rt.run_with_values(context=context)
+
+
+class Loop_16(OpRunControlFlow):
+    "Loop"
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, parent, verbose=verbose)
+        self.output_index = {n: i for i, n in enumerate(self.body.output_names)}
+        self.N = len(self.body.input_names) - 2
+        self.K = len(self.body.output_names) - self.N - 1
+
+    def run(self, M, cond, *args, context: Optional[Dict[str, Any]] = None):
+        if args:
+            v_initial = args[0]
+            args = args[1:]
+        else:
+            v_initial = None
+        assert M is None or hasattr(
+            M, "dtype"
+        ), f"M must be empty or an array but its type is {type(M)}."
+        body = self.body
+        loop_inputs = body.input_names
+        inputs = dict.fromkeys(loop_inputs)
+        if v_initial is not None:
+            inputs[loop_inputs[2]] = v_initial
+        cond_name = body.output_names[0]
+        if args:
+            begin = len(loop_inputs) - len(args)
+            all_inputs = loop_inputs[begin:]
+            for name, val in zip(all_inputs, args):
+                inputs[name] = val
+        if context is not None:
+            for a in context:
+                inputs[a] = context[a]
+
+        k_carried_away = [[] for i in range(self.K)]  # type: ignore
+        it = 0
+        while (cond is None or cond.tensor is None or cond.tensor.item()) and (
+            M is None or M.tensor is None or it < M.tensor.item()
+        ):
+            if len(body.input_names) > 0 and body.input_names[0] is not None:
+                inputs[body.input_names[0]] = OpRunTensor(
+                    torch.tensor(it, dtype=None if M is None else M.dtype)
+                )
+            if len(body.input_names) > 1 and body.input_names[1] is not None:
+                inputs[body.input_names[1]] = cond
+            outputs = list(
+                self.body.run_with_values(
+                    *[inputs[k] for k in self.body.input_names], context=context
+                )
+            )
+            if self.K > 0:
+                for k in range(self.K):
+                    k_carried_away[k].append(outputs[-self.K + k])
+            index_cond = self.output_index[cond_name]
+            cond = outputs[index_cond]
+            assert (
+                cond is not None
+            ), f"Condition {cond_name!r} returned by the subgraph cannot be None."
+            for i, o in zip(body.input_names[2:], body.output_names[1:]):
+                inputs[i] = outputs[self.output_index[o]]
+            it += 1
+
+        if it == 0:
+            outputs = [inputs[i] for i in body.input_names[2:]]
+        else:
+            outputs = outputs[1 : 1 + self.N]
+        outputs.extend([OpRunTensor(torch.cat(x, axis=0)) for x in k_carried_away])
+        while len(outputs) < len(self.body.output_names):
+            outputs.append(OpRunTensor(torch.empty(())))
+        return tuple(outputs)
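
The control-flow kernels turn every GRAPH attribute into a nested evaluator built with the parent's providers, opsets and local functions, stored under the attribute's name (then_branch, else_branch, body), which is why they require a parent evaluator and are normally only exercised through it. For Loop_16 the body's signature fixes the bookkeeping: with body inputs (iter_num, cond_in, v1..vN) and body outputs (cond_out, v1..vN, s1..sK), the kernel derives N and K as below and concatenates each of the K scan outputs along axis 0 once the loop exits. A worked example of that arithmetic (plain Python, no evaluator involved):

# body inputs : iter_num, cond_in, v1, v2   -> 4 names
# body outputs: cond_out, v1, v2, scan1     -> 4 names
input_names = ["iter_num", "cond_in", "v1", "v2"]
output_names = ["cond_out", "v1", "v2", "scan1"]
N = len(input_names) - 2        # 2 carried (loop-state) variables
K = len(output_names) - N - 1   # 1 scan output, concatenated on axis 0 at the end
assert (N, K) == (2, 1)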
onnx_diagnostic/reference/torch_ops/generator_ops.py
@@ -0,0 +1,36 @@
+from typing import Optional
+import onnx
+import torch
+from . import OpRunKernel, OpRunTensor
+
+
+class Range_11(OpRunKernel):
+    """Range"""
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        """
+        Returns True if the kernel needs a device to be efficiently initialized.
+        """
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        device: Optional[torch.device] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.device = device
+
+    def run(self, starts: OpRunTensor, limit: OpRunTensor, delta: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(
+            torch.arange(
+                starts.tensor,
+                limit.tensor,
+                delta.tensor,
+                dtype=starts.dtype,
+                device=self.device,
+            )
+        )
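
Range_11 is one of the kernels whose device_dependent() returns True, so the evaluator can hand it a target device at construction time and the generated tensor lands directly where the rest of the graph runs. Run in isolation, it behaves like any other kernel; a short sketch assuming only the signatures in the hunk above:

import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.generator_ops import Range_11

kernel = Range_11(
    oh.make_node("Range", ["start", "limit", "delta"], ["out"]),
    device=torch.device("cpu"),
)
out = kernel.run(
    OpRunTensor(torch.tensor(0)),
    OpRunTensor(torch.tensor(10)),
    OpRunTensor(torch.tensor(2)),
)
print(out.tensor)  # tensor([0, 2, 4, 6, 8])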