onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,56 @@
1
+ from ._op_run import OpRunKernel, OpRunFunction, OpRunSequence, OpRunTensor, OpRunValue
2
+ from .access_ops import Gather_1, ScatterND_16, Slice_13
3
+ from .binary_ops import (
4
+ And_1,
5
+ Add_1,
6
+ Div_1,
7
+ Equal_1,
8
+ Greater_1,
9
+ GreaterOrEqual_1,
10
+ Less_1,
11
+ LessOrEqual_1,
12
+ MatMul_1,
13
+ Mul_1,
14
+ Or_1,
15
+ Pow_12,
16
+ Sub_1,
17
+ )
18
+ from .controlflow_ops import If_1, Loop_16
19
+ from .generator_ops import Range_11
20
+ from .nn_ops import AveragePool_11, Conv_11, LayerNormalization_17, Softmax_13, Tanh_6
21
+ from .other_ops import (
22
+ Cast_6,
23
+ CastLike_15,
24
+ NonZero_13,
25
+ Concat_1,
26
+ Tile_6,
27
+ Transpose_1,
28
+ Trilu_14,
29
+ Where_9,
30
+ )
31
+ from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_17, ReduceMin_18, ReduceSum_13
32
+ from .sequence_ops import ConcatFromSequence_11, SequenceEmpty_11, SequenceInsert_11
33
+ from .shape_ops import (
34
+ ConstantOfShape_9,
35
+ Expand_8,
36
+ Reshape_14,
37
+ Shape_15,
38
+ Squeeze_13,
39
+ Split_18,
40
+ Unsqueeze_13,
41
+ )
42
+ from .unary_ops import (
43
+ Abs_1,
44
+ Cos_1,
45
+ Erf_9,
46
+ Exp_1,
47
+ Identity_1,
48
+ IsNaN_9,
49
+ Log_1,
50
+ Neg_1,
51
+ Not_1,
52
+ Reciprocal_1,
53
+ Sigmoid_6,
54
+ Sin_1,
55
+ Sqrt_1,
56
+ )
@@ -0,0 +1,335 @@
1
+ from typing import Any, Dict, List, Optional, Union, Tuple
2
+ import onnx
3
+ import torch
4
+ from ...api import TensorLike
5
+ from ...helpers import string_type
6
+ from ...helpers.torch_helper import to_tensor
7
+
8
+
9
class OpRunValue(TensorLike):
    """Defines a value for the runtime, a tensor or a sequence."""

    # Restrict instances to the attributes used by the concrete subclasses
    # (OpRunTensor, OpRunSequence); avoids a per-instance __dict__.
    __slots__ = ("cached", "is_constant", "sequence", "tensor")

    @classmethod
    def is_sequence(cls) -> bool:
        "Tells if it is sequence."
        # Abstract: subclasses return False (tensor) or True (sequence).
        raise NotImplementedError("is_sequence must be overwritten.")
18
+
19
+
20
class OpRunTensor(OpRunValue):
    """
    Wrapper around a tensor.

    :param tensor: torch.Tensor
    :param is_constant: is it a constant
    :param may_cpu: change the device the tensor is if
        more appropriate
    """

    def __init__(self, tensor, is_constant: bool = False, may_cpu: bool = False):
        assert isinstance(tensor, torch.Tensor), (
            f"Unexpected type {type(tensor)}, "
            f"__name__={getattr(tensor, '__name__', 'no name')}"
        )
        # -666666 is the placeholder scalar used by OpRunSequence; it must
        # never appear here as a genuine one-element tensor.
        assert tensor is None or tensor.numel() != 1 or tensor.item() != -666666
        # Small 1-D int64 tensors sitting on an accelerator (typically
        # shapes) are cheaper to keep on CPU.
        move_to_cpu = (
            may_cpu
            and len(tensor.shape) == 1
            and tensor.numel() < 8
            and tensor.dtype == torch.int64
            and tensor.get_device() >= 0
        )
        self.tensor = tensor.cpu() if move_to_cpu else tensor
        self.is_constant = is_constant
        self.cached: Optional[Tuple[int, ...]] = None

    @classmethod
    def is_sequence(cls) -> bool:
        "Tells if it is sequence."
        return False

    def to(self, to: Any) -> "OpRunTensor":
        "Changes the device."
        return OpRunTensor(self.tensor.to(to))

    def string_type(self) -> str:
        "Returns information about the value as a string."
        text = string_type(
            self.tensor, with_shape=True, with_min_max=True, with_device=True
        )
        return f"CST({text})" if self.is_constant else text

    def __repr__(self) -> str:
        "usual"
        inner = string_type(self.tensor, with_shape=True)
        if not self.is_constant:
            return f"{self.__class__.__name__}({inner})"
        return f"{self.__class__.__name__}({inner}, is_constant=True)"

    @property
    def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
        "Returns either a tensor or a sequence."
        return self.tensor

    @property
    def shape(self):
        "shape (torch shape)"
        return self.tensor.shape

    @property
    def dtype(self):
        "dtype (torch dtype)"
        return self.tensor.dtype

    def _tensor_as_tuple_int(self) -> Tuple[int, ...]:
        # Interprets a 1-D tensor as a tuple of python ints.
        return tuple(int(value) for value in self.tensor)

    def numel(self) -> int:
        "Returns the number of elements."
        if self.tensor is None:
            return 0
        return self.tensor.numel()

    def get_device(self) -> int:
        "Returns the device id."
        if self.tensor is None:
            return -1
        return self.tensor.get_device()

    @property
    def device(self):
        "Returns the device."
        if self.tensor is None:
            return -1
        return self.tensor.device

    @property
    def as_tuple_int(self) -> Tuple[int, ...]:
        "value as int"
        if not self.is_constant:
            return self._tensor_as_tuple_int()
        # Constants never change: cache the conversion.
        if self.cached is None:
            self.cached = self._tensor_as_tuple_int()
        return self.cached

    def copy(self) -> "OpRunTensor":
        "Shallow copy."
        return self.__class__(self.tensor)
116
+
117
+
118
class OpRunSequence(OpRunValue):
    """Defines a sequence."""

    def __init__(
        self,
        sequence: Optional[List[torch.Tensor]] = None,
        dtype: torch.dtype = torch.float32,
    ):
        # The placeholder tensor only carries the dtype; -666666 marks it as
        # "not a real value" (see the assertion in OpRunTensor.__init__).
        self.tensor = torch.tensor(-666666, dtype=dtype)
        self.is_shape = False
        self.sequence = sequence if sequence else []
        self.cached: Optional[Tuple[int, ...]] = None
        assert all(
            isinstance(s, torch.Tensor) for s in self.sequence
        ), f"Unexpected type in sequence {[type(s) for s in self.sequence]}"

    @property
    def dtype(self):
        "dtype (torch dtype)"
        return self.tensor.dtype

    @property
    def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
        "Returns either a tensor or a sequence."
        return self.sequence

    @classmethod
    def is_sequence(cls) -> bool:
        "Tells if it is sequence."
        return True

    def insert_at(
        self, tensor: torch.Tensor, position: Optional[OpRunTensor] = None
    ) -> "OpRunSequence":
        "Inserts a value at a given position."
        assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
        # Work on a shallow copy so this sequence stays untouched.
        items = self.sequence.copy()
        if position is None:
            items.append(tensor.tensor)
        else:
            items.insert(int(position.tensor.item()), tensor.tensor)
        result = OpRunSequence()
        result.sequence = items
        return result

    def copy(self) -> "OpRunSequence":
        "Shallow copy."
        return self.__class__(self.sequence, dtype=self.dtype)

    def string_type(self) -> str:
        "Returns a string which can be printed."
        return string_type(self.sequence, with_shape=True)
168
+
169
+
170
class OpRunKernel:
    """
    Main class. Every kernel should inherit from it.
    It does not copy the proto.
    """

    @classmethod
    def device_dependent(cls) -> bool:
        """
        Returns True if the kernel needs a device to be efficiently initialized.
        """
        return False

    @classmethod
    def has_subgraphs(cls) -> bool:
        """Returns True if the kernel has subgraphs."""
        return False

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
        custom_kernels: Optional[Dict[Tuple[str, str], type]] = None,
    ):
        assert isinstance(
            node, onnx.NodeProto
        ), f"node must be a NodeProto but node is {type(node)}"
        self.op_type = node.op_type
        self.domain = node.domain
        self.input = node.input
        self.output = node.output
        self.verbose = verbose
        self.custom_kernels = custom_kernels
        if version is None:
            # Kernel classes follow the naming scheme ``<OpType>_<opset>``
            # (e.g. ``Add_1``); recover the opset version from the name.
            name = self.__class__.__name__.split("_")
            assert (
                len(name) == 2
            ), f"Cannot guess version from name={self.__class__.__name__!r}"
            version = int(name[1])
        self.version = version
        self.name = node.name

    def __str__(self) -> str:
        "usual"
        if self.domain:
            return (
                f"{self.op_type}[{self.domain}]({', '.join(self.input)}) "
                f"-> {', '.join(self.output)}"
            )
        return f"{self.op_type}({', '.join(self.input)}) -> {', '.join(self.output)}"

    def run(
        self, *args: Optional[OpRunValue]
    ) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
        "Kernel implementation."
        raise NotImplementedError(
            f"Method run is not implemented for kernel {self.__class__.__name__!r}"
        )

    def _find_attribute(self, node: onnx.NodeProto, name: str):
        # Linear scan is fine: nodes carry only a handful of attributes.
        for att in node.attribute:
            if att.name == name:
                return att
        return None

    def get_attribute_float(
        self, node: onnx.NodeProto, name: str, default_value: Optional[float] = None
    ) -> Optional[float]:
        """
        Returns an attribute as a float.

        :param node: NodeProto
        :param name: name
        :param default_value: default_value
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else float(att.f)

    def get_attribute_int(
        self, node: onnx.NodeProto, name: str, default_value: Optional[int] = None
    ) -> Optional[int]:
        """
        Returns an attribute as an int.

        :param node: NodeProto
        :param name: name
        :param default_value: default_value
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else int(att.i)

    def get_attribute_ints(
        self, node: onnx.NodeProto, name: str, default_value: Optional[Tuple[int, ...]] = None
    ) -> Optional[Tuple[int, ...]]:
        """
        Returns an attribute as a tuple of ints.

        :param node: NodeProto
        :param name: name
        :param default_value: default_value
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else tuple(map(int, att.ints))

    def get_attribute_string(
        self, node: onnx.NodeProto, name: str, default_value: Optional[str] = None
    ) -> Optional[str]:
        """
        Returns an attribute as a string.

        :param node: NodeProto
        :param name: name
        :param default_value: default_value
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else att.s.decode("utf-8")

    def get_attribute_tensor(self, node: onnx.NodeProto, name: str) -> Optional[torch.Tensor]:
        """
        Returns an attribute as a torch tensor.

        :param node: NodeProto
        :param name: name
        :return: value, or None when the attribute is missing
        """
        att = self._find_attribute(node, name)
        if att is None:
            return None
        return to_tensor(att.t)

    def same_device(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Puts all tensors on the same device."""
        devices = [t.get_device() for t in tensors]
        if len(set(devices)) == 1:
            return tuple(tensors)
        # get_device() is -1 for CPU and >= 0 for accelerators, so max()
        # prefers an accelerator device when one is present.
        index = devices.index(max(devices))
        device = tensors[index].device
        return tuple(t.to(device) for t in tensors)
314
+
315
+
316
class OpRunFunction(OpRunKernel):
    """
    Defines a kernel based on a local function.
    """

    def __init__(
        self,
        runtime: "onnx_diagnostic.reference.TorchOnnxEvaluator",  # noqa: F821
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        # Evaluator built over the function body; invoked on every run.
        self.runtime = runtime
        self.input_names = runtime.input_names

    def run(
        self, *args: Optional[OpRunValue]
    ) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
        # Delegates execution of the node to the wrapped evaluator.
        return self.runtime.run_with_values(*args)
@@ -0,0 +1,94 @@
1
+ from typing import Optional
2
+ import onnx
3
+ import torch
4
+ from . import OpRunKernel, OpRunTensor
5
+
6
+
7
class Gather_1(OpRunKernel):
    "Gather"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        axis = self.get_attribute_int(node, "axis", 0)
        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
        self.axis = axis

    def run(self, x, indices):
        """Gathers entries of ``x`` along ``self.axis`` selected by ``indices``."""
        if indices.tensor.numel() == 0:
            # Empty indices: per the ONNX Gather spec the output shape is the
            # data shape with the indices shape spliced in at ``axis``.
            # FIX: this branch previously returned a bare torch.Tensor of
            # shape (0,) instead of an OpRunTensor with the correct shape.
            axis = self.axis if self.axis >= 0 else self.axis + len(x.shape)
            shape = (*x.shape[:axis], *indices.shape, *x.shape[axis + 1 :])
            return OpRunTensor(
                torch.empty(shape, dtype=x.tensor.dtype, device=x.tensor.device)
            )
        # Advanced indexing with a single index tensor placed at ``axis``
        # implements Gather directly.
        ind = [slice(0, s) for s in x.shape]
        ind[self.axis] = indices.tensor
        return OpRunTensor(x.tensor[tuple(ind)])
27
+
28
+
29
class ScatterND_16(OpRunKernel):
    "ScatterND"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        # One of "none", "add", "mul", "max", "min" (ONNX ScatterND-16).
        self.reduction = self.get_attribute_string(node, "reduction", "none")

    def run(
        self, data: OpRunTensor, indices: OpRunTensor, updates: OpRunTensor
    ) -> OpRunTensor:
        # This implementation is not efficient: it loops over every index
        # tuple in python.
        grids = torch.meshgrid(*[torch.arange(s) for s in indices.shape[:-1]], indexing="ij")
        stacked = torch.stack(grids, dim=-1)
        index = stacked.reshape(-1, len(indices.shape) - 1)
        output = data.tensor.clone()
        for i in index:
            # FIX: ``output[indices.tensor[i]]`` used a 1-D tensor as a single
            # advanced index, which selects rows along dim 0 only. ScatterND
            # requires the last dimension of ``indices`` to index the leading
            # dimensions of ``data`` jointly, i.e. tuple indexing.
            pos = tuple(i.tolist())  # position within indices/updates
            loc = tuple(indices.tensor[pos].tolist())  # target location in output
            upd = updates.tensor[pos]
            if self.reduction == "add":
                output[loc] += upd
            elif self.reduction == "mul":
                output[loc] *= upd
            elif self.reduction == "max":
                output[loc] = torch.maximum(output[loc], upd)
            elif self.reduction == "min":
                output[loc] = torch.minimum(output[loc], upd)
            else:
                output[loc] = upd
        return OpRunTensor(output)
65
+
66
+
67
class Slice_13(OpRunKernel):
    "Slice"

    def run(
        self,
        data: OpRunTensor,
        starts: OpRunTensor,
        ends: OpRunTensor,
        axes: Optional[OpRunTensor] = None,
        steps: Optional[OpRunTensor] = None,
    ) -> OpRunTensor:
        """Implements ONNX Slice by building one python slice per axis."""
        if axes is None:
            # Without axes, starts/ends(/steps) apply to the leading
            # dimensions in order.
            if steps is None:
                sl = [slice(b, e) for b, e in zip(starts.tensor, ends.tensor)]
            else:
                sl = [
                    slice(b, e, st)
                    for b, e, st in zip(starts.tensor, ends.tensor, steps.tensor)
                ]
            return OpRunTensor(data.tensor[tuple(sl)])
        # With axes, start from full slices and override the listed axes.
        sl = [slice(0, dim) for dim in data.shape]
        if steps is None:
            for b, e, ax in zip(starts.tensor, ends.tensor, axes.tensor):
                sl[ax] = slice(b, e)
        else:
            for b, e, ax, st in zip(
                starts.tensor, ends.tensor, axes.tensor, steps.tensor
            ):
                sl[ax] = slice(b, e, st)
        return OpRunTensor(data.tensor[tuple(sl)])
@@ -0,0 +1,108 @@
1
+ import torch
2
+ from . import OpRunKernel, OpRunTensor
3
+
4
+
5
class OpRunBinary(OpRunKernel):
    "Binary Op"

    def run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # Reconcile devices before applying the operator: when the operands
        # disagree, the CPU operand (get_device() == -1) is moved to the
        # accelerator operand's device.
        if x.get_device() != y.get_device():
            if x.get_device() >= 0:
                y = y.to(x.device)
            else:
                x = x.to(y.device)
        return self._run(x, y)

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # Concrete binary kernels override this with the actual computation.
        raise NotImplementedError(f"Operator {self.__class__.__name__!r} is not complete.")
18
+
19
+
20
class And_1(OpRunBinary):
    """And"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.bitwise_and is the functional form of ``&``.
        return OpRunTensor(torch.bitwise_and(x.tensor, y.tensor))
25
+
26
+
27
class Add_1(OpRunBinary):
    """Add"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.add is the functional form of ``+``.
        return OpRunTensor(torch.add(x.tensor, y.tensor))
32
+
33
+
34
class Div_1(OpRunBinary):
    """Div"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.div defaults to true division, matching ``/``.
        return OpRunTensor(torch.div(x.tensor, y.tensor))
39
+
40
+
41
class Equal_1(OpRunBinary):
    """Equal"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.eq is the functional form of ``==``.
        return OpRunTensor(torch.eq(x.tensor, y.tensor))
46
+
47
+
48
class Greater_1(OpRunBinary):
    """Greater"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.gt is the functional form of ``>``.
        return OpRunTensor(torch.gt(x.tensor, y.tensor))
53
+
54
+
55
class GreaterOrEqual_1(OpRunBinary):
    """GreaterOrEqual"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.ge is the functional form of ``>=``.
        return OpRunTensor(torch.ge(x.tensor, y.tensor))
60
+
61
+
62
class Less_1(OpRunBinary):
    """Less"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.lt is the functional form of ``<``.
        return OpRunTensor(torch.lt(x.tensor, y.tensor))
67
+
68
+
69
class LessOrEqual_1(OpRunBinary):
    """LessOrEqual"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.le is the functional form of ``<=``.
        return OpRunTensor(torch.le(x.tensor, y.tensor))
74
+
75
+
76
class MatMul_1(OpRunBinary):
    """MatMul"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.matmul is the functional form of ``@``.
        return OpRunTensor(torch.matmul(x.tensor, y.tensor))
81
+
82
+
83
class Mul_1(OpRunBinary):
    """Mul"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.mul is the functional form of ``*``.
        return OpRunTensor(torch.mul(x.tensor, y.tensor))
88
+
89
+
90
class Or_1(OpRunBinary):
    """Or"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.bitwise_or is the functional form of ``|``.
        return OpRunTensor(torch.bitwise_or(x.tensor, y.tensor))
95
+
96
+
97
class Pow_12(OpRunBinary):
    """Pow"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # ``**`` dispatches to the same kernel as torch.pow.
        return OpRunTensor(x.tensor ** y.tensor)
102
+
103
+
104
class Sub_1(OpRunBinary):
    """Sub"""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        # torch.sub is the functional form of ``-``.
        return OpRunTensor(torch.sub(x.tensor, y.tensor))