onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,121 @@
1
+ from typing import Any, Dict, Optional
2
+ import onnx
3
+ import torch
4
+ from . import OpRunKernel, OpRunTensor
5
+
6
+
7
class OpRunControlFlow(OpRunKernel):
    """Common ancestor for control flows.

    Instantiates one nested evaluator per GRAPH attribute of the node
    (e.g. ``then_branch``/``else_branch`` for If, ``body`` for Loop) and
    stores it on ``self`` under the attribute's own name.
    """

    @classmethod
    def has_subgraphs(cls) -> bool:
        """Returns True if the kernel has subgraphs."""
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        # The parent evaluator is required: subgraphs share its configuration.
        assert (
            parent is not None
        ), f"parent must be specified for operator {self.__class__.__name__!r}"
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                # Build a nested evaluator of the same class as the parent,
                # sharing providers, opsets, local functions and custom kernels.
                rt = parent.__class__(
                    att.g,
                    providers=parent.providers,
                    opsets=parent.opsets,
                    local_functions=parent.functions,
                    verbose=parent.verbose,
                    custom_kernels=parent.custom_kernels,
                )
                # e.g. self.then_branch, self.else_branch, self.body
                setattr(self, att.name, rt)
37
+
38
+
39
class If_1(OpRunControlFlow):
    "If"

    def run(self, cond, context: Optional[Dict[str, Any]] = None):
        """Evaluates the subgraph selected by the boolean condition."""
        if cond.tensor.item():
            branch = self.then_branch  # type: ignore[attr-defined]
        else:
            branch = self.else_branch  # type: ignore[attr-defined]
        return branch.run_with_values(context=context)
45
+
46
+
47
class Loop_16(OpRunControlFlow):
    "Loop"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
        verbose: int = 0,
    ):
        super().__init__(node, version, parent, verbose=verbose)
        # Position of every output name of the body subgraph.
        self.output_index = {n: i for i, n in enumerate(self.body.output_names)}
        # N = number of loop-carried dependencies
        # (body inputs are: iteration_num, condition, then N carried values).
        self.N = len(self.body.input_names) - 2
        # K = number of scan outputs
        # (body outputs are: condition, N carried values, then K scan outputs).
        self.K = len(self.body.output_names) - self.N - 1

    def run(self, M, cond, *args, context: Optional[Dict[str, Any]] = None):
        """Runs the body subgraph at most ``M`` times while ``cond`` holds.

        :param M: maximum trip count (tensor-like) or None for no bound
        :param cond: initial loop condition or None
        :param args: initial values of the loop-carried dependencies
        :param context: extra named values made available to the subgraph
        :return: tuple of the final carried values followed by the scan
            outputs (concatenated along axis 0)
        """
        if args:
            v_initial = args[0]
            args = args[1:]
        else:
            v_initial = None
        assert M is None or hasattr(
            M, "dtype"
        ), f"M must be empty or an array but its type is {type(M)}."
        body = self.body
        loop_inputs = body.input_names
        inputs = dict.fromkeys(loop_inputs)
        if v_initial is not None:
            # Third body input is the first carried dependency.
            inputs[loop_inputs[2]] = v_initial
        cond_name = body.output_names[0]
        if args:
            # Remaining positional arguments fill the trailing body inputs.
            begin = len(loop_inputs) - len(args)
            all_inputs = loop_inputs[begin:]
            for name, val in zip(all_inputs, args):
                inputs[name] = val
        if context is not None:
            for a in context:
                inputs[a] = context[a]

        # One accumulator per scan output.
        k_carried_away = [[] for i in range(self.K)]  # type: ignore
        it = 0
        # A missing M or cond (or a None tensor) means "unbounded" /
        # "always true", matching the ONNX Loop specification.
        while (cond is None or cond.tensor is None or cond.tensor.item()) and (
            M is None or M.tensor is None or it < M.tensor.item()
        ):
            # Feed the iteration counter and the current condition to the
            # subgraph when it declares those inputs.
            if len(body.input_names) > 0 and body.input_names[0] is not None:
                inputs[body.input_names[0]] = OpRunTensor(
                    torch.tensor(it, dtype=None if M is None else M.dtype)
                )
            if len(body.input_names) > 1 and body.input_names[1] is not None:
                inputs[body.input_names[1]] = cond
            outputs = list(
                self.body.run_with_values(
                    *[inputs[k] for k in self.body.input_names], context=context
                )
            )
            if self.K > 0:
                # The last K outputs are scan outputs: accumulate them.
                for k in range(self.K):
                    k_carried_away[k].append(outputs[-self.K + k])
            index_cond = self.output_index[cond_name]
            cond = outputs[index_cond]
            assert (
                cond is not None
            ), f"Condition {cond_name!r} returned by the subgraph cannot be None."
            # Carried values produced by this iteration feed the next one.
            for i, o in zip(body.input_names[2:], body.output_names[1:]):
                inputs[i] = outputs[self.output_index[o]]
            it += 1

        if it == 0:
            # The body never ran: carried values are the initial ones.
            outputs = [inputs[i] for i in body.input_names[2:]]
        else:
            outputs = outputs[1 : 1 + self.N]
        # Scan outputs are concatenated along the first axis.
        outputs.extend([OpRunTensor(torch.cat(x, axis=0)) for x in k_carried_away])
        # Pad with empty tensors so the arity matches the body outputs.
        while len(outputs) < len(self.body.output_names):
            outputs.append(OpRunTensor(torch.empty(())))
        return tuple(outputs)
@@ -0,0 +1,36 @@
1
+ from typing import Optional
2
+ import onnx
3
+ import torch
4
+ from . import OpRunKernel, OpRunTensor
5
+
6
+
7
class Range_11(OpRunKernel):
    """Range"""

    @classmethod
    def device_dependent(cls) -> bool:
        """
        Returns True if the kernel needs a device to be efficiently initialized.
        """
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        device: Optional[torch.device] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        # Target device for the generated tensor.
        self.device = device

    def run(self, starts: OpRunTensor, limit: OpRunTensor, delta: OpRunTensor) -> OpRunTensor:
        """Builds ``arange(start, limit, delta)`` with the start's dtype."""
        rng = torch.arange(
            starts.tensor,
            limit.tensor,
            delta.tensor,
            dtype=starts.dtype,
            device=self.device,
        )
        return OpRunTensor(rng)
@@ -0,0 +1,196 @@
1
+ from typing import Optional, Tuple
2
+ import onnx
3
+ import torch
4
+ from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
5
+ from . import OpRunKernel, OpRunTensor
6
+
7
+
8
class AveragePool_11(OpRunKernel):
    "AveragePool"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
        self.ceil_mode = bool(self.get_attribute_int(node, "ceil_mode", 0))
        self.count_include_pad = bool(self.get_attribute_int(node, "count_include_pad", 0))
        self.dilations = self.get_attribute_ints(node, "dilations", None)
        self.kernel_shape: Tuple[int, ...] = (
            self.get_attribute_ints(node, "kernel_shape") or tuple()
        )
        self.pads = self.get_attribute_ints(node, "pads", None)
        self.strides = self.get_attribute_ints(node, "strides", None)

    def run(self, x):
        """Average pooling through ``torch.nn.functional.avg_pool{1,2,3}d``."""
        kernel_shape = self.kernel_shape
        n_spatial = len(x.shape[2:])
        # Defaults per the ONNX spec: unit dilations/strides, zero pads.
        dilations = self.dilations or [1] * n_spatial
        strides = self.strides or [1] * n_spatial
        pads = self.pads or [0] * (n_spatial * 2)
        # Only the combinations torch can express directly are supported.
        assert (
            self.auto_pad == "NOTSET"
        ), f"conv not implemented for auto_pad={self.auto_pad!r}"
        assert len(set(pads)) == 1, f"conv not implemented for pads={pads}"
        assert set(dilations) == {1}, f"conv not implemented for dilations={dilations}"
        # Dispatch on the spatial rank (avg_pool1d / avg_pool2d / avg_pool3d).
        pool = getattr(torch.nn.functional, f"avg_pool{len(kernel_shape)}d")
        pooled = pool(
            x.tensor,
            kernel_size=tuple(kernel_shape),
            stride=tuple(strides),
            padding=pads[0],
            ceil_mode=self.ceil_mode,
            count_include_pad=self.count_include_pad,
            # dilation=tuple(dilations),
        )
        return OpRunTensor(pooled)
50
+
51
+
52
class Conv_11(OpRunKernel):
    "Conv"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
        self.dilations = self.get_attribute_ints(node, "dilations", None)
        self.group = self.get_attribute_int(node, "group", 1)
        self.kernel_shape: Tuple[int, ...] = (
            self.get_attribute_ints(node, "kernel_shape") or tuple()
        )
        self.pads = self.get_attribute_ints(node, "pads", None)
        self.strides = self.get_attribute_ints(node, "strides", None)

    def run(self, x, w, b=None):
        """Convolution implemented with ``torch.nn.functional.conv2d``.

        NOTE(review): the final call is always ``conv2d`` even though the
        attribute defaults are computed for any spatial rank — presumably
        only 2-D convolutions reach this kernel; confirm against callers.
        """
        # Default kernel shape comes from the weight tensor.
        kernel_shape = self.kernel_shape or w.shape[2:]
        assert (
            tuple(kernel_shape) == w.shape[-len(kernel_shape) :]
        ), f"conv not implemented for kernel_shape={kernel_shape} and w.shape={w.shape}"
        dilations = self.dilations or [1 for _ in x.shape[2:]]
        strides = self.strides or [1 for _ in x.shape[2:]]

        if self.auto_pad in {"SAME_LOWER", "SAME_UPPER"}:
            # Compute the padding so the output spatial size is
            # ceil(input_size / stride); SAME_LOWER puts the extra pad first.
            head = []
            tail = []
            for i in range(len(x.shape) - 2):
                d = x.shape[i + 2]
                target_size = (d + strides[i] - 1) // strides[i]
                pad_needed = (target_size - 1) * strides[i] + kernel_shape[i] - d
                pad_head = (
                    (pad_needed + 1) // 2 if self.auto_pad == "SAME_LOWER" else pad_needed // 2
                )
                pad_tail = pad_needed - pad_head
                head.append(pad_head)
                tail.append(pad_tail)
            pads = head + tail
        else:
            pads = self.pads or ([0 for _ in x.shape[2:]] * 2)

        # torch's conv2d takes a single symmetric padding value here.
        assert len(set(pads)) == 1, (
            f"conv not implemented for pads={pads}, "
            f"auto_pad={self.auto_pad!r}, strides={strides}, "
            f"x.shape={x.shape}, kernel_shape={kernel_shape}"
        )

        if b is None:
            bias = None
        else:
            # conv2d expects a 1-D bias; squeeze then restore one axis if
            # the squeeze produced a scalar.
            bias = b.tensor.squeeze()
            if not bias.shape:
                bias = bias.unsqueeze(0)
        return OpRunTensor(
            torch.nn.functional.conv2d(
                x.tensor,
                w.tensor,
                bias=bias,
                stride=tuple(strides),
                padding=pads[0],
                dilation=tuple(dilations),
                groups=self.group,
            )
        )
119
+
120
+
121
class LayerNormalization_17(OpRunKernel):
    "LayerNormalization"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.axis = self.get_attribute_int(node, "axis", -1)
        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
        # stash_type selects the precision used for the internal computation.
        self.stash_type = onnx_dtype_to_torch_dtype(
            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
        )
        # Mean and InvStdDev are only computed when the node declares them.
        self.compute_std = len(node.output) > 1

    def run(self, x, scale, bias=None):
        """Layer normalization over the axes ``[axis:]``.

        :param x: input tensor
        :param scale: scale tensor applied after normalization
        :param bias: optional bias tensor
        :return: the normalized tensor, or a tuple
            (normalized, mean, inverse standard deviation) when the node
            has more than one output
        """
        original_dtype = x.dtype
        if self.stash_type == torch.float32 and x.tensor.dtype != torch.float64:
            # Fast path: float32 stash on a non-double input needs no cast.
            xt = x.tensor
            res = torch.nn.functional.layer_norm(
                xt,
                xt.shape[self.axis :],
                weight=scale.tensor,
                bias=None if bias is None else bias.tensor,
                eps=self.epsilon,
            )
        else:
            xt = x.tensor.to(self.stash_type)
            res = torch.nn.functional.layer_norm(
                xt,
                xt.shape[self.axis :],
                weight=scale.tensor.to(self.stash_type),
                bias=None if bias is None else bias.tensor.to(self.stash_type),
                eps=self.epsilon,
            )
        if not self.compute_std:
            return OpRunTensor(res.to(original_dtype))
        axes = tuple(range(len(xt.shape)))[self.axis :]
        # Fix: the previous code unpacked ``torch.var`` (which returns a
        # single tensor) into two values. ``torch.var_mean`` returns
        # (var, mean); ``unbiased=False`` gives the population variance
        # that layer normalization uses.
        var, mean = torch.var_mean(xt, dim=axes, unbiased=False, keepdim=False)
        x_inv_std_dev = torch.reciprocal(torch.sqrt(var + self.epsilon))
        return (
            OpRunTensor(res.to(original_dtype)),
            OpRunTensor(mean),
            OpRunTensor(x_inv_std_dev),
        )
168
+
169
+
170
class Softmax_13(OpRunKernel):
    "Softmax"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.axis = self.get_attribute_int(node, "axis", -1)
        assert isinstance(self.axis, int), f"Unexpected value for attribute axis={self.axis!r}"
        # this is out of spec
        stash_type = self.get_attribute_int(node, "stash_type", None)
        self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)

    def run(self, data: OpRunTensor) -> OpRunTensor:
        """Softmax along ``self.axis``, optionally in stash_type precision."""
        softmaxed = torch.nn.functional.softmax(
            data.tensor, dim=self.axis, dtype=self.stash_type
        )
        return OpRunTensor(softmaxed)
190
+
191
+
192
class Tanh_6(OpRunKernel):
    "Tanh"

    def run(self, data: OpRunTensor) -> OpRunTensor:
        """Elementwise hyperbolic tangent."""
        return OpRunTensor(data.tensor.tanh())
@@ -0,0 +1,106 @@
1
+ from typing import Optional
2
+ import onnx
3
+ import torch
4
+ from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
5
+ from . import OpRunKernel, OpRunTensor
6
+
7
+
8
class Cast_6(OpRunKernel):
    "Cast"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        to = self.get_attribute_int(node, "to", 0)
        assert isinstance(to, int), f"Unexpected value for attribute to={to!r}"
        # Target dtype, converted from the ONNX element type.
        self.to = onnx_dtype_to_torch_dtype(to)
        self.saturate = self.get_attribute_int(node, "saturate", 1)
        assert self.saturate == 1, f"saturate={self.saturate} not implemented for Cast"

    def run(self, data: OpRunTensor) -> OpRunTensor:
        """Casts the input to the configured dtype."""
        converted = data.tensor.to(self.to)
        return OpRunTensor(converted)
21
+
22
+
23
class CastLike_15(OpRunKernel):
    "CastLike"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        # NOTE(review): saturate presumably relates to float8 casts in the
        # ONNX spec; only the default value 1 is supported here.
        self.saturate = self.get_attribute_int(node, "saturate", 1)
        assert self.saturate == 1, f"saturate={self.saturate} not implemented for CastLike"

    def run(self, data: OpRunTensor, like: OpRunTensor) -> OpRunTensor:
        """Casts ``data`` to the dtype of ``like``."""
        return OpRunTensor(data.tensor.to(like.tensor.dtype))
33
+
34
+
35
class Concat_1(OpRunKernel):
    "Concat"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        axis = self.get_attribute_int(node, "axis", 0)
        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
        # Axis along which the inputs are concatenated.
        self.axis = axis

    def run(self, *data: OpRunTensor) -> OpRunTensor:
        """Concatenates the inputs, reconciling devices when they differ."""
        assert data, f"No tensor to concatenate in node name {self.name!r}"
        devices = [d.get_device() for d in data]
        if len(set(devices)) == 1:
            # All inputs already live on the same device.
            return OpRunTensor(torch.cat([t.tensor for t in data], axis=self.axis))
        if (
            data[0].dtype == torch.int64
            and self.axis == 0
            and max(d.tensor.ndim for d in data) == 1
            and max(d.tensor.numel() for d in data) <= 8
        ):
            # This is a shape
            # (small 1-D int64 tensors): concatenate on CPU.
            return OpRunTensor(torch.cat([t.tensor.cpu() for t in data], axis=self.axis))
        # Mixed devices: move everything to the device with the highest index.
        index = devices.index(max(devices))
        device = data[index].tensor.device
        return OpRunTensor(torch.cat([t.tensor.to(device) for t in data], axis=self.axis))
60
+
61
+
62
class NonZero_13(OpRunKernel):
    "NonZero"

    def run(self, x: OpRunTensor) -> OpRunTensor:
        """Indices of non-zero elements, transposed to (rank, count) as in ONNX."""
        indices = torch.nonzero(x.tensor)
        return OpRunTensor(indices.transpose(0, 1))
67
+
68
+
69
class Tile_6(OpRunKernel):
    "Tile"

    def run(self, x: OpRunTensor, repeat: OpRunTensor) -> OpRunTensor:
        """Repeats the tensor along each axis as given by ``repeat``."""
        reps = repeat.as_tuple_int
        return OpRunTensor(x.tensor.tile(reps))
74
+
75
+
76
class Transpose_1(OpRunKernel):
    "Transpose"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        # perm may be absent (None): the ONNX default reverses all dimensions.
        self.perm = self.get_attribute_ints(node, "perm", None)

    def run(self, data: OpRunTensor) -> OpRunTensor:
        """Permutes the input's dimensions.

        Fix: when the ``perm`` attribute is missing, the previous code passed
        ``None`` straight to ``torch.permute`` and failed; per the ONNX spec
        the dimensions must then be reversed.
        """
        perm = self.perm
        if perm is None:
            perm = tuple(range(data.tensor.ndim - 1, -1, -1))
        return OpRunTensor(torch.permute(data.tensor, perm))
85
+
86
+
87
class Trilu_14(OpRunKernel):
    "Trilu"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        # upper=1 keeps the upper triangle, upper=0 the lower one.
        self.upper = self.get_attribute_int(node, "upper", 1)

    def run(self, data: OpRunTensor, k: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Upper or lower triangular part, offset by the optional diagonal ``k``."""
        diagonal = 0 if k is None else k.tensor.item()
        fn = torch.triu if self.upper else torch.tril
        return OpRunTensor(fn(data.tensor, diagonal=diagonal))
99
+
100
+
101
class Where_9(OpRunKernel):
    "Where"

    def run(self, cond: OpRunTensor, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        """Elementwise selection: ``x`` where the condition holds, else ``y``."""
        c, a, b = self.same_device(cond.tensor, x.tensor, y.tensor)
        return OpRunTensor(torch.where(c, a, b))
@@ -0,0 +1,130 @@
1
+ from typing import Optional, Tuple
2
+ import onnx
3
+ import torch
4
+ from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
5
+ from . import OpRunKernel, OpRunTensor
6
+
7
+
8
class ReduceOp(OpRunKernel):
    """Common ancestor for the Reduce* kernels: parses shared attributes."""

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        # keepdims: keep reduced dimensions with size 1 (ONNX default 1).
        self.keepdims = bool(self.get_attribute_int(node, "keepdims", 1))
        self.noop_with_empty_axes = bool(
            self.get_attribute_int(node, "noop_with_empty_axes", 0)
        )
        assert isinstance(
            self.keepdims, bool
        ), f"Unexpected value for attribute keepdims={self.keepdims!r}"
        assert isinstance(self.noop_with_empty_axes, bool), (
            f"Unexpected value for attribute "
            f"noop_with_empty_axes={self.noop_with_empty_axes!r}"
        )
        # Only the default behavior (reduce everything when axes is empty)
        # is implemented.
        assert (
            not self.noop_with_empty_axes
        ), f"Not implemented with noop_with_empty_axes={self.noop_with_empty_axes}"
        # this is out of spec
        stash_type = self.get_attribute_int(node, "stash_type", None)
        self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)
28
+
29
+
30
class ReduceOpAxes(ReduceOp):
    """Reduce kernels whose axes come from a node attribute (older opsets)."""

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        fetched = self.get_attribute_ints(node, "axes")
        self.axes: Tuple[int, ...] = fetched or tuple()
34
+
35
+
36
class ReduceMax_18(ReduceOp):
    """ReduceMax"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Maximum over the requested axes (all elements when axes is None)."""
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(x.tensor.max())
        dims = axes.as_tuple_int
        if len(dims) == 1:
            return OpRunTensor(x.tensor.max(dims[0], keepdim=self.keepdims).values)
        # torch's max reduces one dimension at a time; walk them right to left
        # so earlier indices stay valid.
        result = x.tensor
        for dim in reversed(dims):
            result = result.max(dim, keepdim=self.keepdims).values
        return OpRunTensor(result)
54
+
55
+
56
class ReduceMean_18(ReduceOp):
    """ReduceMean"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Mean over the requested axes (all elements when axes is None)."""
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(torch.mean(x.tensor))
        dims = axes.as_tuple_int
        # torch.mean accepts either a single dim or a tuple of dims.
        reduced = (
            x.tensor.mean(dims[0], keepdim=self.keepdims)
            if len(dims) == 1
            else x.tensor.mean(dims, keepdim=self.keepdims)
        )
        return OpRunTensor(reduced)
72
+
73
+
74
class ReduceMin_17(ReduceOpAxes):
    """ReduceMin"""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        """Minimum over the axes given by the node attribute (opset < 18)."""
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        axes = self.axes
        if not axes:
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(x.tensor.min())
        taxes = tuple(axes)
        if len(taxes) == 1:
            return OpRunTensor(x.tensor.min(taxes[0], keepdim=self.keepdims).values)
        # Reduce one dimension at a time, right to left, to keep indices valid.
        result = x.tensor
        for dim in reversed(taxes):
            result = result.min(dim, keepdim=self.keepdims).values
        return OpRunTensor(result)
93
+
94
+
95
class ReduceMin_18(ReduceOp):
    """ReduceMin"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Minimum over the requested axes (all elements when axes is None)."""
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            assert (
                not self.keepdims
            ), f"axes is empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(torch.min(x.tensor))
        dims = axes.as_tuple_int
        if len(dims) == 1:
            return OpRunTensor(x.tensor.min(dims[0], keepdim=self.keepdims).values)
        # Reduce one dimension at a time, right to left, to keep indices valid.
        result = x.tensor
        for dim in reversed(dims):
            result = result.min(dim, keepdim=self.keepdims).values
        return OpRunTensor(result)
113
+
114
+
115
class ReduceSum_13(ReduceOp):
    """ReduceSum"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Sum over the requested axes (all elements when axes is None)."""
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(torch.sum(x.tensor))
        dims = axes.as_tuple_int
        # torch.sum accepts either a single dim or a tuple of dims.
        if len(dims) == 1:
            return OpRunTensor(x.tensor.sum(dims[0], keepdim=self.keepdims))
        return OpRunTensor(x.tensor.sum(dims, keepdim=self.keepdims))
@@ -0,0 +1,65 @@
1
+ from typing import Optional
2
+ import onnx
3
+ import torch
4
+ from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
5
+ from . import OpRunKernel, OpRunSequence, OpRunTensor
6
+
7
+
8
class OpRunOpSequence(OpRunKernel):
    "Ancestor for kernels using sequences."
10
+
11
+
12
class ConcatFromSequence_11(OpRunOpSequence):
    "ConcatFromSequence"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        axis = self.get_attribute_int(node, "axis", None)
        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
        self.axis = axis
        # new_axis=1 stacks the tensors along a new axis instead of
        # concatenating along an existing one.
        self.new_axis = self.get_attribute_int(node, "new_axis", 0)

    def run(self, input_sequence: OpRunSequence) -> OpRunTensor:
        """Concatenates (or stacks when ``new_axis=1``) a sequence's tensors."""
        assert isinstance(
            input_sequence, OpRunSequence
        ), f"Unexpected type {type(input_sequence)} for input_sequence"
        seq = input_sequence.sequence
        if self.new_axis == 1:
            # Fix: the previous code called ``s.expand(self.axis)``, which
            # interprets ``axis`` as a target *size*, not a dimension index.
            # Inserting a new axis then concatenating is exactly torch.stack;
            # it also covers axis == -1 (new axis appended last).
            res = torch.stack(seq, dim=self.axis)
        else:
            res = torch.cat(seq, axis=self.axis)
        return OpRunTensor(res)
37
+
38
+
39
class SequenceEmpty_11(OpRunOpSequence):
    "SequenceEmpty"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        # The dtype attribute defaults to FLOAT and is converted to torch.
        self.dtype = onnx_dtype_to_torch_dtype(
            self.get_attribute_int(node, "dtype", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
        )

    def run(self) -> OpRunSequence:
        """Returns a new empty sequence of the configured dtype."""
        return OpRunSequence(dtype=self.dtype)
50
+
51
+
52
class SequenceInsert_11(OpRunOpSequence):
    "SequenceInsert"

    def run(
        self,
        input_sequence: OpRunSequence,
        tensor: OpRunTensor,
        position: Optional[OpRunTensor] = None,
    ) -> OpRunSequence:
        """Returns the sequence with ``tensor`` inserted at ``position``.

        NOTE(review): the None-position case (ONNX default: append) is
        delegated to ``insert_at`` — confirm that helper handles it.
        """
        assert isinstance(input_sequence, OpRunSequence), (
            f"Unexpected type {type(input_sequence)} for input_sequence: "
            f"{input_sequence.string_type()}"
        )
        return input_sequence.insert_at(tensor, position)