onnx-diagnostic 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +18 -0
  3. onnx_diagnostic/api.py +15 -0
  4. onnx_diagnostic/ext_test_case.py +3 -1
  5. onnx_diagnostic/helpers/args_helper.py +1 -1
  6. onnx_diagnostic/helpers/helper.py +6 -5
  7. onnx_diagnostic/helpers/model_builder_helper.py +24 -8
  8. onnx_diagnostic/helpers/rt_helper.py +5 -1
  9. onnx_diagnostic/helpers/torch_helper.py +2 -0
  10. onnx_diagnostic/reference/__init__.py +1 -0
  11. onnx_diagnostic/reference/torch_evaluator.py +518 -0
  12. onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
  13. onnx_diagnostic/reference/torch_ops/_op_run.py +326 -0
  14. onnx_diagnostic/reference/torch_ops/access_ops.py +84 -0
  15. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  16. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +118 -0
  17. onnx_diagnostic/reference/torch_ops/generator_ops.py +35 -0
  18. onnx_diagnostic/reference/torch_ops/nn_ops.py +176 -0
  19. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  20. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  21. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  22. onnx_diagnostic/reference/torch_ops/shape_ops.py +120 -0
  23. onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
  24. onnx_diagnostic/tasks/__init__.py +22 -1
  25. onnx_diagnostic/tasks/image_classification.py +2 -2
  26. onnx_diagnostic/tasks/text_generation.py +3 -3
  27. onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
  28. onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
  29. onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
  31. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  32. onnx_diagnostic/torch_models/test_helper.py +115 -15
  33. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  34. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/METADATA +1 -1
  35. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/RECORD +38 -23
  36. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/WHEEL +1 -1
  37. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/top_level.txt +0 -0

onnx_diagnostic/reference/torch_ops/_op_run.py
@@ -0,0 +1,326 @@
+from typing import Any, List, Optional, Union, Tuple
+import onnx
+import torch
+from ...api import TensorLike
+from ...helpers import string_type
+from ...helpers.torch_helper import to_tensor
+
+
+class OpRunValue(TensorLike):
+    """Defines a value for the runtime, a tensor or a sequence."""
+
+    __slots__ = ("cached", "is_constant", "sequence", "tensor")
+
+    @classmethod
+    def is_sequence(cls) -> bool:
+        "Tells if it is a sequence."
+        raise NotImplementedError("is_sequence must be overwritten.")
+
+
+class OpRunTensor(OpRunValue):
+    """
+    Wrapper around a tensor.
+
+    :param tensor: torch.Tensor
+    :param is_constant: is it a constant
+    :param may_cpu: move the tensor to CPU if that is
+        more appropriate
+    """
+
+    def __init__(self, tensor, is_constant: bool = False, may_cpu: bool = False):
+        assert isinstance(tensor, torch.Tensor), (
+            f"Unexpected type {type(tensor)}, "
+            f"__name__={getattr(tensor, '__name__', 'no name')}"
+        )
+        assert tensor is None or tensor.numel() != 1 or tensor.item() != -666666
+        self.tensor = (
+            tensor.cpu()
+            if may_cpu
+            and len(tensor.shape) == 1
+            and tensor.numel() < 8
+            and tensor.dtype == torch.int64
+            and tensor.get_device() >= 0
+            else tensor
+        )
+        self.is_constant = is_constant
+        self.cached: Optional[Tuple[int, ...]] = None
+
+    @classmethod
+    def is_sequence(cls) -> bool:
+        "Tells if it is a sequence."
+        return False
+
+    def to(self, to: Any) -> "OpRunTensor":
+        "Changes the device."
+        return OpRunTensor(self.tensor.to(to))
+
+    def string_type(self) -> str:
+        "Returns information about the value as a string."
+        s = string_type(self.tensor, with_shape=True, with_min_max=True, with_device=True)
+        if self.is_constant:
+            return f"CST({s})"
+        return s
+
+    def __repr__(self) -> str:
+        "usual"
+        if self.is_constant:
+            return (
+                f"{self.__class__.__name__}"
+                f"({string_type(self.tensor, with_shape=True)}, is_constant=True)"
+            )
+        return f"{self.__class__.__name__}({string_type(self.tensor, with_shape=True)})"
+
+    @property
+    def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
+        "Returns either a tensor or a sequence."
+        return self.tensor
+
+    @property
+    def shape(self):
+        "shape (torch shape)"
+        return self.tensor.shape
+
+    @property
+    def dtype(self):
+        "dtype (torch dtype)"
+        return self.tensor.dtype
+
+    def _tensor_as_tuple_int(self) -> Tuple[int, ...]:
+        return tuple(map(int, self.tensor))
+
+    def numel(self) -> int:
+        "Returns the number of elements."
+        return 0 if self.tensor is None else self.tensor.numel()
+
+    def get_device(self) -> int:
+        "Returns the device id."
+        return -1 if self.tensor is None else self.tensor.get_device()
+
+    @property
+    def device(self):
+        "Returns the device."
+        return -1 if self.tensor is None else self.tensor.device
+
+    @property
+    def as_tuple_int(self) -> Tuple[int, ...]:
+        "value as int"
+        if self.is_constant:
+            if self.cached is None:
+                self.cached = self._tensor_as_tuple_int()
+            return self.cached
+        return self._tensor_as_tuple_int()
+
+    def copy(self) -> "OpRunTensor":
+        "Shallow copy."
+        return self.__class__(self.tensor)
+
+
+class OpRunSequence(OpRunValue):
+    """Defines a sequence."""
+
+    def __init__(
+        self, sequence: Optional[List[torch.Tensor]] = None, dtype: torch.dtype = torch.float32
+    ):
+        self.tensor = torch.tensor(-666666, dtype=dtype)
+        self.is_shape = False
+        self.sequence = sequence or []
+        self.cached: Optional[Tuple[int, ...]] = None
+        assert all(
+            isinstance(s, torch.Tensor) for s in self.sequence
+        ), f"Unexpected type in sequence {[type(s) for s in self.sequence]}"
+
+    @property
+    def dtype(self):
+        "dtype (torch dtype)"
+        return self.tensor.dtype
+
+    @property
+    def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
+        "Returns either a tensor or a sequence."
+        return self.sequence
+
+    @classmethod
+    def is_sequence(cls) -> bool:
+        "Tells if it is a sequence."
+        return True
+
+    def insert_at(
+        self, tensor: torch.Tensor, position: Optional[OpRunTensor] = None
+    ) -> "OpRunSequence":
+        "Inserts a value at a given position."
+        assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
+        new_seq = OpRunSequence()
+        seq = self.sequence.copy()
+        new_seq.sequence = seq
+        if position is None:
+            seq.append(tensor.tensor)
+        else:
+            seq.insert(int(position.tensor.item()), tensor.tensor)
+        return new_seq
+
+    def copy(self) -> "OpRunSequence":
+        "Shallow copy."
+        return self.__class__(self.sequence, dtype=self.dtype)
+
+    def string_type(self) -> str:
+        "Returns a string which can be printed."
+        return string_type(self.sequence, with_shape=True)
+
+
+class OpRun:
+    """
+    Main class. Every kernel should inherit from it.
+    It does not copy the proto.
+    """
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        """
+        Returns True if the kernel needs a device to be efficiently initialized.
+        """
+        return False
+
+    @classmethod
+    def has_subgraphs(cls) -> bool:
+        """Returns True if the kernel has subgraphs."""
+        return False
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        assert isinstance(
+            node, onnx.NodeProto
+        ), f"node must be a NodeProto but node is {type(node)}"
+        self.op_type = node.op_type
+        self.domain = node.domain
+        self.input = node.input
+        self.output = node.output
+        if version is None:
+            name = self.__class__.__name__.split("_")
+            assert (
+                len(name) == 2
+            ), f"Cannot guess version from name={self.__class__.__name__!r}"
+            version = int(name[1])
+        self.version = version
+        self.name = node.name
+
+    def __str__(self) -> str:
+        "usual"
+        if self.domain:
+            return (
+                f"{self.op_type}[{self.domain}]({', '.join(self.input)}) "
+                f"-> {', '.join(self.output)}"
+            )
+        return f"{self.op_type}({', '.join(self.input)}) -> {', '.join(self.output)}"
+
+    def run(
+        self, *args: Optional[OpRunValue]
+    ) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
+        "Kernel implementation."
+        raise NotImplementedError(
+            f"Method run is not implemented for kernel {self.__class__.__name__!r}"
+        )
+
+    def _find_attribute(self, node: onnx.NodeProto, name: str):
+        for att in node.attribute:
+            if att.name == name:
+                return att
+        return None
+
+    def get_attribute_float(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[float] = None
+    ) -> Optional[float]:
+        """
+        Returns an attribute as a float.
+
+        :param node: NodeProto
+        :param name: name
+        :param default_value: default_value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else float(att.f)
+
+    def get_attribute_int(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[int] = None
+    ) -> Optional[int]:
+        """
+        Returns an attribute as an int.
+
+        :param node: NodeProto
+        :param name: name
+        :param default_value: default_value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else int(att.i)
+
+    def get_attribute_ints(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[Tuple[int, ...]] = None
+    ) -> Optional[Tuple[int, ...]]:
+        """
+        Returns an attribute as a tuple of ints.
+
+        :param node: NodeProto
+        :param name: name
+        :param default_value: default_value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else tuple(map(int, att.ints))
+
+    def get_attribute_string(
+        self, node: onnx.NodeProto, name: str, default_value: Optional[str] = None
+    ) -> Optional[str]:
+        """
+        Returns an attribute as a string.
+
+        :param node: NodeProto
+        :param name: name
+        :param default_value: default_value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        return default_value if att is None else att.s.decode("utf-8")
+
+    def get_attribute_tensor(self, node: onnx.NodeProto, name: str) -> Optional[torch.Tensor]:
+        """
+        Returns an attribute as a torch tensor.
+
+        :param node: NodeProto
+        :param name: name
+        :param default_value: default_value
+        :return: value
+        """
+        att = self._find_attribute(node, name)
+        if att is None:
+            return None
+        return to_tensor(att.t)
+
+    def same_device(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        """Puts all tensors on the same device."""
+        devices = [t.get_device() for t in tensors]
+        if len(set(devices)) == 1:
+            return tuple(tensors)
+        index = devices.index(max(devices))
+        device = tensors[index].device
+        return tuple(t.to(device) for t in tensors)
+
+
+class OpRunFunction(OpRun):
+    """
+    Defines a kernel based on a local function.
+    """
+
+    def __init__(
+        self,
+        runtime: "onnx_diagnostic.reference.TorchOnnxEvaluator",  # noqa: F821
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+    ):
+        super().__init__(node, version)
+        self.runtime = runtime
+        self.input_names = runtime.input_names
+
+    def run(
+        self, *args: Optional[OpRunValue]
+    ) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
+        return self.runtime.run_with_values(*args)
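
The file above (onnx_diagnostic/reference/torch_ops/_op_run.py) provides the building blocks of the new torch-based runtime: kernels derive from OpRun and exchange OpRunTensor/OpRunSequence wrappers. As an illustration only, here is a minimal sketch of a custom kernel written against these classes; it assumes they are re-exported by onnx_diagnostic.reference.torch_ops, as the other new files in this diff do with "from . import OpRun, OpRunTensor".

import onnx
import torch
from onnx_diagnostic.reference.torch_ops import OpRun, OpRunTensor

class Neg_1(OpRun):
    "Neg"

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # kernels receive OpRunTensor wrappers and return new ones
        return OpRunTensor(-x.tensor)

node = onnx.helper.make_node("Neg", ["X"], ["Y"])
kernel = Neg_1(node)  # the opset version (1) is inferred from the "_1" suffix
print(kernel.run(OpRunTensor(torch.tensor([1.0, -2.0]))).tensor)  # tensor([-1., 2.])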

onnx_diagnostic/reference/torch_ops/access_ops.py
@@ -0,0 +1,84 @@
+from typing import Optional
+import onnx
+import torch
+from . import OpRun, OpRunTensor
+
+
+class Gather_1(OpRun):
+    "Gather"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        axis = self.get_attribute_int(node, "axis", 0)
+        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
+        self.axis = axis
+
+    def run(self, x, indices):
+        if indices.tensor.numel() == 0:
+            return OpRunTensor(torch.empty((0,), dtype=x.tensor.dtype, device=x.tensor.device))
+        ind = [slice(0, s) for s in x.shape]
+        ind[self.axis] = indices.tensor
+        return OpRunTensor(x.tensor[tuple(ind)])
+
+
+class ScatterND_16(OpRun):
+    "ScatterND"
+
+    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+        super().__init__(node, version)
+        self.reduction = self.get_attribute_string(node, "reduction", "none")
+
+    def run(
+        self, data: OpRunTensor, indices: OpRunTensor, updates: OpRunTensor
+    ) -> OpRunTensor:
+        # This implementation is not efficient.
+        grids = torch.meshgrid(*[torch.arange(s) for s in indices.shape[:-1]], indexing="ij")
+        stacked = torch.stack(grids, dim=-1)
+        index = stacked.reshape(-1, len(indices.shape) - 1)
+        output = data.tensor.clone()
+        for i in index:
+            if self.reduction == "add":
+                output[indices.tensor[i]] += updates.tensor[i]
+            elif self.reduction == "mul":
+                output[indices.tensor[i]] *= updates.tensor[i]
+            elif self.reduction == "max":
+                output[indices.tensor[i]] = torch.maximum(
+                    output[indices.tensor[i]], updates.tensor[i]
+                )
+            elif self.reduction == "min":
+                output[indices.tensor[i]] = torch.minimum(
+                    output[indices.tensor[i]], updates.tensor[i]
+                )
+            else:
+                output[indices.tensor[i]] = updates.tensor[i]
+        return OpRunTensor(output)
+
+
+class Slice_13(OpRun):
+    "Slice"
+
+    def run(
+        self,
+        data: OpRunTensor,
+        starts: OpRunTensor,
+        ends: OpRunTensor,
+        axes: Optional[OpRunTensor] = None,
+        steps: Optional[OpRunTensor] = None,
+    ) -> OpRunTensor:
+        if axes is None:
+            if steps is None:
+                slices = [slice(s, e) for s, e in zip(starts.tensor, ends.tensor)]
+            else:
+                slices = [
+                    slice(s, e, d) for s, e, d in zip(starts.tensor, ends.tensor, steps.tensor)
+                ]
+        else:
+            if steps is None:
+                slices = [slice(0, a) for a in data.shape]
+                for s, e, a in zip(starts.tensor, ends.tensor, axes.tensor):
+                    slices[a] = slice(s, e)
+            else:
+                slices = [slice(0, a) for a in data.shape]
+                for s, e, a, d in zip(starts.tensor, ends.tensor, axes.tensor, steps.tensor):
+                    slices[a] = slice(s, e, d)
+        return OpRunTensor(data.tensor[tuple(slices)])
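
For illustration, a hedged sketch of exercising the Slice_13 kernel above directly, assuming the module path onnx_diagnostic/reference/torch_ops/access_ops.py listed in this diff:

import onnx
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.access_ops import Slice_13

node = onnx.helper.make_node("Slice", ["data", "starts", "ends"], ["out"])
kernel = Slice_13(node)
data = OpRunTensor(torch.arange(12).reshape(3, 4))
out = kernel.run(data, OpRunTensor(torch.tensor([1])), OpRunTensor(torch.tensor([3])))
print(out.tensor)  # rows 1 and 2 of data, since neither axes nor steps are given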

onnx_diagnostic/reference/torch_ops/binary_ops.py
@@ -0,0 +1,108 @@
+import torch
+from . import OpRun, OpRunTensor
+
+
+class OpRunBinary(OpRun):
+    "Binary Op"
+
+    def run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        if x.get_device() != y.get_device():
+            if x.get_device() >= 0:
+                y = y.to(x.device)
+            else:
+                x = x.to(y.device)
+        return self._run(x, y)
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        raise NotImplementedError(f"Operator {self.__class__.__name__!r} is not complete.")
+
+
+class And_1(OpRunBinary):
+    """And"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor & y.tensor)
+
+
+class Add_1(OpRunBinary):
+    """Add"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor + y.tensor)
+
+
+class Div_1(OpRunBinary):
+    """Div"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor / y.tensor)
+
+
+class Equal_1(OpRunBinary):
+    """Equal"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor == y.tensor)
+
+
+class Greater_1(OpRunBinary):
+    """Greater"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor > y.tensor)
+
+
+class GreaterOrEqual_1(OpRunBinary):
+    """GreaterOrEqual"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor >= y.tensor)
+
+
+class Less_1(OpRunBinary):
+    """Less"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor < y.tensor)
+
+
+class LessOrEqual_1(OpRunBinary):
+    """LessOrEqual"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor <= y.tensor)
+
+
+class MatMul_1(OpRunBinary):
+    """MatMul"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor @ y.tensor)
+
+
+class Mul_1(OpRunBinary):
+    """Mul"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor * y.tensor)
+
+
+class Or_1(OpRunBinary):
+    """Or"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor | y.tensor)
+
+
+class Pow_12(OpRunBinary):
+    """Pow"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(torch.pow(x.tensor, y.tensor))
+
+
+class Sub_1(OpRunBinary):
+    """Sub"""
+
+    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor - y.tensor)
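
These binary kernels share OpRunBinary.run, which first aligns devices (the operand on CPU is moved to the other operand's device) and then dispatches to _run. A small sketch, under the same import assumptions as above:

import onnx
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.binary_ops import Add_1

node = onnx.helper.make_node("Add", ["A", "B"], ["C"])
add = Add_1(node)
res = add.run(OpRunTensor(torch.tensor([1.0, 2.0])), OpRunTensor(torch.tensor([3.0, 4.0])))
print(res.tensor)  # tensor([4., 6.])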

onnx_diagnostic/reference/torch_ops/controlflow_ops.py
@@ -0,0 +1,118 @@
+from typing import Any, Dict, Optional
+import onnx
+import torch
+from . import OpRun, OpRunTensor
+
+
+class OpRunControlFlow(OpRun):
+    """Common ancestor for control flows."""
+
+    @classmethod
+    def has_subgraphs(cls) -> bool:
+        """Returns True if the kernel has subgraphs."""
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
+    ):
+        super().__init__(node, version)
+        assert (
+            parent is not None
+        ), f"parent must be specified for operator {self.__class__.__name__!r}"
+        for att in node.attribute:
+            if att.type == onnx.AttributeProto.GRAPH:
+                rt = parent.__class__(
+                    att.g,
+                    providers=parent.providers,
+                    opsets=parent.opsets,
+                    local_functions=parent.functions,
+                    verbose=parent.verbose,
+                )
+                setattr(self, att.name, rt)
+
+
+class If_1(OpRunControlFlow):
+    "If"
+
+    def run(self, cond, context: Optional[Dict[str, Any]] = None):
+        rt = self.then_branch if cond.tensor.item() else self.else_branch  # type: ignore[attr-defined]
+        return rt.run_with_values(context=context)
+
+
+class Loop_16(OpRunControlFlow):
+    "Loop"
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
+    ):
+        super().__init__(node, version, parent)
+        self.output_index = {n: i for i, n in enumerate(self.body.output_names)}
+        self.N = len(self.body.input_names) - 2
+        self.K = len(self.body.output_names) - self.N - 1
+
+    def run(self, M, cond, *args, context: Optional[Dict[str, Any]] = None):
+        if args:
+            v_initial = args[0]
+            args = args[1:]
+        else:
+            v_initial = None
+        assert M is None or hasattr(
+            M, "dtype"
+        ), f"M must be empty or an array but its type is {type(M)}."
+        body = self.body
+        loop_inputs = body.input_names
+        inputs = dict.fromkeys(loop_inputs)
+        if v_initial is not None:
+            inputs[loop_inputs[2]] = v_initial
+        cond_name = body.output_names[0]
+        if args:
+            begin = len(loop_inputs) - len(args)
+            all_inputs = loop_inputs[begin:]
+            for name, val in zip(all_inputs, args):
+                inputs[name] = val
+        if context is not None:
+            for a in context:
+                inputs[a] = context[a]
+
+        k_carried_away = [[] for i in range(self.K)]  # type: ignore
+        it = 0
+        while (cond is None or cond.tensor is None or cond.tensor.item()) and (
+            M is None or M.tensor is None or it < M.tensor.item()
+        ):
+            if len(body.input_names) > 0 and body.input_names[0] is not None:
+                inputs[body.input_names[0]] = OpRunTensor(
+                    torch.tensor(it, dtype=None if M is None else M.dtype)
+                )
+            if len(body.input_names) > 1 and body.input_names[1] is not None:
+                inputs[body.input_names[1]] = cond
+            outputs = list(
+                self.body.run_with_values(
+                    *[inputs[k] for k in self.body.input_names], context=context
+                )
+            )
+            if self.K > 0:
+                for k in range(self.K):
+                    k_carried_away[k].append(outputs[-self.K + k])
+            index_cond = self.output_index[cond_name]
+            cond = outputs[index_cond]
+            assert (
+                cond is not None
+            ), f"Condition {cond_name!r} returned by the subgraph cannot be None."
+            for i, o in zip(body.input_names[2:], body.output_names[1:]):
+                inputs[i] = outputs[self.output_index[o]]
+            it += 1
+
+        if it == 0:
+            outputs = [inputs[i] for i in body.input_names[2:]]
+        else:
+            outputs = outputs[1 : 1 + self.N]
+            outputs.extend([OpRunTensor(torch.cat(x, axis=0)) for x in k_carried_away])
+        while len(outputs) < len(self.body.output_names):
+            outputs.append(OpRunTensor(torch.empty(())))
+        return tuple(outputs)

onnx_diagnostic/reference/torch_ops/generator_ops.py
@@ -0,0 +1,35 @@
+from typing import Optional
+import onnx
+import torch
+from . import OpRun, OpRunTensor
+
+
+class Range_11(OpRun):
+    """Range"""
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        """
+        Returns True if the kernel needs a device to be efficiently initialized.
+        """
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version: Optional[int] = None,
+        device: Optional[torch.device] = None,
+    ):
+        super().__init__(node, version)
+        self.device = device
+
+    def run(self, starts: OpRunTensor, limit: OpRunTensor, delta: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(
+            torch.arange(
+                starts.tensor,
+                limit.tensor,
+                delta.tensor,
+                dtype=starts.dtype,
+                device=self.device,
+            )
+        )
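
Range_11 reports device_dependent() as True, so the evaluator is expected to pass the target device at construction time and torch.arange allocates directly on it. A hedged sketch, under the same import assumptions as above:

import onnx
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.generator_ops import Range_11

node = onnx.helper.make_node("Range", ["start", "limit", "delta"], ["out"])
rng = Range_11(node, device=torch.device("cpu"))
out = rng.run(OpRunTensor(torch.tensor(0)), OpRunTensor(torch.tensor(10)), OpRunTensor(torch.tensor(2)))
print(out.tensor)  # tensor([0, 2, 4, 6, 8])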