onnx-diagnostic 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnx_diagnostic/__init__.py +2 -2
  2. onnx_diagnostic/_command_lines_parser.py +39 -1
  3. onnx_diagnostic/api.py +15 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +14 -5
  5. onnx_diagnostic/ext_test_case.py +15 -1
  6. onnx_diagnostic/helpers/args_helper.py +1 -1
  7. onnx_diagnostic/helpers/graph_helper.py +386 -0
  8. onnx_diagnostic/helpers/helper.py +30 -5
  9. onnx_diagnostic/helpers/model_builder_helper.py +349 -0
  10. onnx_diagnostic/helpers/rt_helper.py +69 -1
  11. onnx_diagnostic/helpers/torch_helper.py +2 -0
  12. onnx_diagnostic/reference/__init__.py +1 -0
  13. onnx_diagnostic/reference/torch_evaluator.py +518 -0
  14. onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
  15. onnx_diagnostic/reference/torch_ops/_op_run.py +326 -0
  16. onnx_diagnostic/reference/torch_ops/access_ops.py +84 -0
  17. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  18. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +118 -0
  19. onnx_diagnostic/reference/torch_ops/generator_ops.py +35 -0
  20. onnx_diagnostic/reference/torch_ops/nn_ops.py +176 -0
  21. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  22. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  23. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  24. onnx_diagnostic/reference/torch_ops/shape_ops.py +120 -0
  25. onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
  26. onnx_diagnostic/tasks/__init__.py +22 -1
  27. onnx_diagnostic/tasks/image_classification.py +2 -2
  28. onnx_diagnostic/tasks/text_generation.py +3 -3
  29. onnx_diagnostic/torch_export_patches/eval/__init__.py +690 -0
  30. onnx_diagnostic/torch_export_patches/eval/model_cases.py +883 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +34 -1
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +6 -1
  33. onnx_diagnostic/torch_export_patches/patch_module_helper.py +148 -28
  34. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +91 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +117 -1
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  37. onnx_diagnostic/torch_models/test_helper.py +225 -22
  38. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  39. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/METADATA +1 -1
  40. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/RECORD +43 -24
  41. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/WHEEL +1 -1
  42. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/top_level.txt +0 -0
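Most of this release is a new torch-backed ONNX reference evaluator: torch_evaluator.py supplies the runtime and the new torch_ops package supplies per-operator kernels, several of which appear in the hunks below. A minimal sketch of how such an evaluator is typically driven; the class name TorchOnnxEvaluator, its constructor, and the session-style run call are assumptions inferred from the file names in this listing, not confirmed by the diff itself.

import onnx
import onnx.helper as oh
import torch

# A tiny one-node model: Y = Softmax(X, axis=-1).
model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Softmax", ["X"], ["Y"], axis=-1)],
        "demo",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [2, 3])],
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [2, 3])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

# Assumed API: a session-like evaluator backed by the torch_ops kernels shown below.
from onnx_diagnostic.reference import TorchOnnxEvaluator  # name assumed from torch_evaluator.py

sess = TorchOnnxEvaluator(model)
feeds = {"X": torch.rand(2, 3)}  # feeds may instead need numpy arrays, depending on the API
print(sess.run(None, feeds))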
onnx_diagnostic/reference/torch_ops/nn_ops.py (new file)
@@ -0,0 +1,176 @@
+ from typing import Optional, Tuple
+ import onnx
+ import torch
+ from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+ from . import OpRun, OpRunTensor
+
+
+ class AveragePool_11(OpRun):
+     "AveragePool"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
+         self.ceil_mode = bool(self.get_attribute_int(node, "ceil_mode", 0))
+         self.count_include_pad = bool(self.get_attribute_int(node, "count_include_pad", 0))
+         self.dilations = self.get_attribute_ints(node, "dilations", None)
+         self.kernel_shape: Tuple[int, ...] = (
+             self.get_attribute_ints(node, "kernel_shape") or tuple()
+         )
+         self.pads = self.get_attribute_ints(node, "pads", None)
+         self.strides = self.get_attribute_ints(node, "strides", None)
+
+     def run(self, x):
+         kernel_shape = self.kernel_shape
+         dilations = self.dilations or [1 for _ in x.shape[2:]]
+         strides = self.strides or [1 for _ in x.shape[2:]]
+         pads = self.pads or ([0 for _ in x.shape[2:]] * 2)
+         assert (
+             self.auto_pad == "NOTSET"
+         ), f"conv not implemented for auto_pad={self.auto_pad!r}"
+         assert len(set(pads)) == 1, f"conv not implemented for pads={pads}"
+         assert set(dilations) == {1}, f"conv not implemented for dilations={dilations}"
+         avg_pool = getattr(torch.nn.functional, f"avg_pool{len(kernel_shape)}d")
+         return OpRunTensor(
+             avg_pool(
+                 x.tensor,
+                 kernel_size=tuple(kernel_shape),
+                 stride=tuple(strides),
+                 padding=pads[0],
+                 ceil_mode=self.ceil_mode,
+                 count_include_pad=self.count_include_pad,
+                 # dilation=tuple(dilations),
+             )
+         )
+
+
+ class Conv_11(OpRun):
+     "Conv"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.auto_pad = self.get_attribute_string(node, "auto_pad", "NOTSET")
+         self.dilations = self.get_attribute_ints(node, "dilations", None)
+         self.group = self.get_attribute_int(node, "group", 1)
+         self.kernel_shape: Tuple[int, ...] = (
+             self.get_attribute_ints(node, "kernel_shape") or tuple()
+         )
+         self.pads = self.get_attribute_ints(node, "pads", None)
+         self.strides = self.get_attribute_ints(node, "strides", None)
+
+     def run(self, x, w, b=None):
+         kernel_shape = self.kernel_shape or w.shape[2:]
+         assert (
+             tuple(kernel_shape) == w.shape[-len(kernel_shape) :]
+         ), f"conv not implemented for kernel_shape={kernel_shape} and w.shape={w.shape}"
+         dilations = self.dilations or [1 for _ in x.shape[2:]]
+         strides = self.strides or [1 for _ in x.shape[2:]]
+
+         if self.auto_pad in {"SAME_LOWER", "SAME_UPPER"}:
+             head = []
+             tail = []
+             for i in range(len(x.shape) - 2):
+                 d = x.shape[i + 2]
+                 target_size = (d + strides[i] - 1) // strides[i]
+                 pad_needed = (target_size - 1) * strides[i] + kernel_shape[i] - d
+                 pad_head = (
+                     (pad_needed + 1) // 2 if self.auto_pad == "SAME_LOWER" else pad_needed // 2
+                 )
+                 pad_tail = pad_needed - pad_head
+                 head.append(pad_head)
+                 tail.append(pad_tail)
+             pads = head + tail
+         else:
+             pads = self.pads or ([0 for _ in x.shape[2:]] * 2)
+
+         assert len(set(pads)) == 1, (
+             f"conv not implemented for pads={pads}, "
+             f"auto_pad={self.auto_pad!r}, strides={strides}, "
+             f"x.shape={x.shape}, kernel_shape={kernel_shape}"
+         )
+
+         if b is None:
+             bias = None
+         else:
+             bias = b.tensor.squeeze()
+             if not bias.shape:
+                 bias = bias.unsqueeze(0)
+         return OpRunTensor(
+             torch.nn.functional.conv2d(
+                 x.tensor,
+                 w.tensor,
+                 bias=bias,
+                 stride=tuple(strides),
+                 padding=pads[0],
+                 dilation=tuple(dilations),
+                 groups=self.group,
+             )
+         )
+
+
+ class LayerNormalization_17(OpRun):
+     "LayerNormalization"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.axis = self.get_attribute_int(node, "axis", -1)
+         self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
+         self.stash_type = onnx_dtype_to_torch_dtype(
+             self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
+         )
+         self.compute_std = len(node.output) > 1
+
+     def run(self, x, scale, bias=None):
+         original_dtype = x.dtype
+         if self.stash_type == torch.float32 and x.tensor.dtype != torch.float64:
+             xt = x.tensor
+             res = torch.nn.functional.layer_norm(
+                 xt,
+                 xt.shape[self.axis :],
+                 weight=scale.tensor,
+                 bias=None if bias is None else bias.tensor,
+                 eps=self.epsilon,
+             )
+         else:
+             xt = x.tensor.to(self.stash_type)
+             res = torch.nn.functional.layer_norm(
+                 xt,
+                 xt.shape[self.axis :],
+                 weight=scale.tensor.to(self.stash_type),
+                 bias=None if bias is None else bias.tensor.to(self.stash_type),
+                 eps=self.epsilon,
+             )
+         if not self.compute_std:
+             return OpRunTensor(res.to(original_dtype))
+         axes = tuple(range(len(xt.shape)))[self.axis :]
+         var, mean = torch.var_mean(xt, dim=axes, correction=0, keepdim=False)
+         x_inv_std_dev = torch.reciprocal(torch.sqrt(var + self.epsilon))
+         return (
+             OpRunTensor(res.to(original_dtype)),
+             OpRunTensor(mean),
+             OpRunTensor(x_inv_std_dev),
+         )
+
+
+ class Softmax_13(OpRun):
+     "Softmax"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.axis = self.get_attribute_int(node, "axis", -1)
+         assert isinstance(self.axis, int), f"Unexpected value for attribute axis={self.axis!r}"
+         # this is out of spec
+         stash_type = self.get_attribute_int(node, "stash_type", None)
+         self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)
+
+     def run(self, data: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(
+             torch.nn.functional.softmax(data.tensor, dim=self.axis, dtype=self.stash_type)
+         )
+
+
+ class Tanh_6(OpRun):
+     "Tanh"
+
+     def run(self, data: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(torch.nn.functional.tanh(data.tensor))
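Each kernel can also be exercised on its own, which is convenient when a single operator disagrees with another runtime. A minimal sketch based only on the signatures visible above (a kernel is built from a NodeProto and exchanges OpRunTensor wrappers); the module path nn_ops.py is inferred from the file list at the top of this diff.

import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.nn_ops import Softmax_13

# The kernel reads its attributes straight from the NodeProto.
node = oh.make_node("Softmax", ["X"], ["Y"], axis=-1)
kernel = Softmax_13(node)

y = kernel.run(OpRunTensor(torch.randn(2, 5)))  # returns an OpRunTensor
print(y.tensor.sum(dim=-1))                     # every row sums to 1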
onnx_diagnostic/reference/torch_ops/other_ops.py (new file)
@@ -0,0 +1,106 @@
+ from typing import Optional
+ import onnx
+ import torch
+ from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+ from . import OpRun, OpRunTensor
+
+
+ class Cast_6(OpRun):
+     "Cast"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         to = self.get_attribute_int(node, "to", 0)
+         assert isinstance(to, int), f"Unexpected value for attribute to={to!r}"
+         self.to = onnx_dtype_to_torch_dtype(to)
+         self.saturate = self.get_attribute_int(node, "saturate", 1)
+         assert self.saturate == 1, f"saturate={self.saturate} not implemented for Cast"
+
+     def run(self, data: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(data.tensor.to(self.to))
+
+
+ class CastLike_15(OpRun):
+     "CastLike"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.saturate = self.get_attribute_int(node, "saturate", 1)
+         assert self.saturate == 1, f"saturate={self.saturate} not implemented for CastLike"
+
+     def run(self, data: OpRunTensor, like: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(data.tensor.to(like.tensor.dtype))
+
+
+ class Concat_1(OpRun):
+     "Concat"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         axis = self.get_attribute_int(node, "axis", 0)
+         assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
+         self.axis = axis
+
+     def run(self, *data: OpRunTensor) -> OpRunTensor:
+         assert data, f"No tensor to concatenate in node name {self.name!r}"
+         devices = [d.get_device() for d in data]
+         if len(set(devices)) == 1:
+             return OpRunTensor(torch.cat([t.tensor for t in data], axis=self.axis))
+         if (
+             data[0].dtype == torch.int64
+             and self.axis == 0
+             and max(d.tensor.ndim for d in data) == 1
+             and max(d.tensor.numel() for d in data) <= 8
+         ):
+             # This is a shape
+             return OpRunTensor(torch.cat([t.tensor.cpu() for t in data], axis=self.axis))
+         index = devices.index(max(devices))
+         device = data[index].tensor.device
+         return OpRunTensor(torch.cat([t.tensor.to(device) for t in data], axis=self.axis))
+
+
+ class NonZero_13(OpRun):
+     "NonZero"
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(torch.nonzero(x.tensor).T)
+
+
+ class Tile_6(OpRun):
+     "Tile"
+
+     def run(self, x: OpRunTensor, repeat: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(torch.tile(x.tensor, repeat.as_tuple_int))
+
+
+ class Transpose_1(OpRun):
+     "Transpose"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.perm = self.get_attribute_ints(node, "perm", None)
+
+     def run(self, data: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(torch.permute(data.tensor, self.perm))
+
+
+ class Trilu_14(OpRun):
+     "Trilu"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.upper = self.get_attribute_int(node, "upper", 1)
+
+     def run(self, data: OpRunTensor, k: Optional[OpRunTensor] = None) -> OpRunTensor:
+         diagonal = 0 if k is None else k.tensor.item()
+         if self.upper:
+             return OpRunTensor(torch.triu(data.tensor, diagonal=diagonal))
+         return OpRunTensor(torch.tril(data.tensor, diagonal=diagonal))
+
+
+ class Where_9(OpRun):
+     "Where"
+
+     def run(self, cond: OpRunTensor, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
+         tcond, tx, ty = self.same_device(cond.tensor, x.tensor, y.tensor)
+         return OpRunTensor(torch.where(tcond, tx, ty))
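As a quick check of the attribute and optional-input handling above, a minimal sketch running Trilu_14 directly, under the same assumptions about the package layout as before.

import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.other_ops import Trilu_14

# upper=0 selects torch.tril; the optional second input k shifts the diagonal.
kernel = Trilu_14(oh.make_node("Trilu", ["X", "k"], ["Y"], upper=0))

x = OpRunTensor(torch.arange(16, dtype=torch.float32).reshape(4, 4))
k = OpRunTensor(torch.tensor(1, dtype=torch.int64))
print(kernel.run(x, k).tensor)  # lower triangle plus the first super-diagonal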
onnx_diagnostic/reference/torch_ops/reduce_ops.py (new file)
@@ -0,0 +1,130 @@
+ from typing import Optional, Tuple
+ import onnx
+ import torch
+ from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+ from . import OpRun, OpRunTensor
+
+
+ class ReduceOp(OpRun):
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.keepdims = bool(self.get_attribute_int(node, "keepdims", 1))
+         self.noop_with_empty_axes = bool(
+             self.get_attribute_int(node, "noop_with_empty_axes", 0)
+         )
+         assert isinstance(
+             self.keepdims, bool
+         ), f"Unexpected value for attribute keepdims={self.keepdims!r}"
+         assert isinstance(self.noop_with_empty_axes, bool), (
+             f"Unexpected value for attribute "
+             f"noop_with_empty_axes={self.noop_with_empty_axes!r}"
+         )
+         assert (
+             not self.noop_with_empty_axes
+         ), f"Not implemented with noop_with_empty_axes={self.noop_with_empty_axes}"
+         # this is out of spec
+         stash_type = self.get_attribute_int(node, "stash_type", None)
+         self.stash_type = None if stash_type is None else onnx_dtype_to_torch_dtype(stash_type)
+
+
+ class ReduceOpAxes(ReduceOp):
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.axes: Tuple[int, ...] = self.get_attribute_ints(node, "axes") or tuple()
+
+
+ class ReduceMax_18(ReduceOp):
+     """ReduceMax"""
+
+     def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+         assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+         if axes is None:
+             assert (
+                 not self.keepdims
+             ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+             return OpRunTensor(x.tensor.max())
+         taxes = axes.as_tuple_int
+         if len(taxes) == 1:
+             t = x.tensor.max(taxes[0], keepdim=self.keepdims)
+             return OpRunTensor(t.values)
+         t = x.tensor
+         for a in reversed(taxes):
+             t = t.max(a, keepdim=self.keepdims).values
+         return OpRunTensor(t)
+
+
+ class ReduceMean_18(ReduceOp):
+     """ReduceMean"""
+
+     def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+         assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+         if axes is None:
+             assert (
+                 not self.keepdims
+             ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+             return OpRunTensor(torch.mean(x.tensor))
+         taxes = axes.as_tuple_int
+         if len(taxes) == 1:
+             t = x.tensor.mean(taxes[0], keepdim=self.keepdims)
+             return OpRunTensor(t)
+         t = x.tensor.mean(taxes, keepdim=self.keepdims)
+         return OpRunTensor(t)
+
+
+ class ReduceMin_17(ReduceOpAxes):
+     """ReduceMin"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+         axes = self.axes
+         if not axes:
+             assert (
+                 not self.keepdims
+             ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+             return OpRunTensor(x.tensor.min())
+         taxes = tuple(axes)
+         if len(taxes) == 1:
+             t = x.tensor.min(taxes[0], keepdim=self.keepdims)
+             return OpRunTensor(t.values)
+         t = x.tensor
+         for a in reversed(taxes):
+             t = t.min(a, keepdim=self.keepdims).values
+         return OpRunTensor(t)
+
+
+ class ReduceMin_18(ReduceOp):
+     """ReduceMin"""
+
+     def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+         assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+         if axes is None:
+             assert (
+                 not self.keepdims
+             ), f"axes is empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+             return OpRunTensor(torch.min(x.tensor))
+         taxes = axes.as_tuple_int
+         if len(taxes) == 1:
+             t = x.tensor.min(taxes[0], keepdim=self.keepdims)
+             return OpRunTensor(t.values)
+         t = x.tensor
+         for a in reversed(taxes):
+             t = t.min(a, keepdim=self.keepdims).values
+         return OpRunTensor(t)
+
+
+ class ReduceSum_13(ReduceOp):
+     """ReduceSum"""
+
+     def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+         assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
+         if axes is None:
+             assert (
+                 not self.keepdims
+             ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
+             return OpRunTensor(torch.sum(x.tensor))
+         taxes = axes.as_tuple_int
+         if len(taxes) == 1:
+             t = x.tensor.sum(taxes[0], keepdim=self.keepdims)
+             return OpRunTensor(t)
+         t = x.tensor.sum(taxes, keepdim=self.keepdims)
+         return OpRunTensor(t)
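In recent opsets the reduction axes arrive as a second input rather than an attribute, which is why run accepts an optional axes tensor above. A minimal sketch for ReduceSum_13, under the same layout assumptions as the previous examples.

import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.reduce_ops import ReduceSum_13

kernel = ReduceSum_13(oh.make_node("ReduceSum", ["X", "axes"], ["Y"], keepdims=1))

x = OpRunTensor(torch.ones(2, 3, 4))
axes = OpRunTensor(torch.tensor([1], dtype=torch.int64))
print(kernel.run(x, axes).tensor.shape)  # torch.Size([2, 1, 4]) because keepdims=1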
onnx_diagnostic/reference/torch_ops/sequence_ops.py (new file)
@@ -0,0 +1,65 @@
+ from typing import Optional
+ import onnx
+ import torch
+ from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
+ from . import OpRun, OpRunSequence, OpRunTensor
+
+
+ class OpRunOpSequence(OpRun):
+     "Ancestor for kernels using sequences."
+
+
+ class ConcatFromSequence_11(OpRunOpSequence):
+     "ConcatFromSequence"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         axis = self.get_attribute_int(node, "axis", None)
+         assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
+         self.axis = axis
+         self.new_axis = self.get_attribute_int(node, "new_axis", 0)
+
+     def run(self, input_sequence: OpRunSequence) -> OpRunTensor:
+         assert isinstance(
+             input_sequence, OpRunSequence
+         ), f"Unexpected type {type(input_sequence)} for input_sequence"
+         seq = input_sequence.sequence
+         if self.new_axis == 1:
+             if self.axis == -1:
+                 seq2 = [s.unsqueeze(len(s.shape)) for s in seq]
+                 res = torch.cat(seq2, axis=-1)
+             else:
+                 seq2 = [s.unsqueeze(self.axis) for s in seq]
+                 res = torch.cat(seq2, axis=self.axis)
+         else:
+             res = torch.cat(seq, axis=self.axis)
+         return OpRunTensor(res)
+
+
+ class SequenceEmpty_11(OpRunOpSequence):
+     "SequenceEmpty"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.dtype = onnx_dtype_to_torch_dtype(
+             self.get_attribute_int(node, "dtype", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
+         )
+
+     def run(self) -> OpRunSequence:
+         return OpRunSequence(dtype=self.dtype)
+
+
+ class SequenceInsert_11(OpRunOpSequence):
+     "SequenceInsert"
+
+     def run(
+         self,
+         input_sequence: OpRunSequence,
+         tensor: OpRunTensor,
+         position: Optional[OpRunTensor] = None,
+     ) -> OpRunSequence:
+         assert isinstance(input_sequence, OpRunSequence), (
+             f"Unexpected type {type(input_sequence)} for input_sequence: "
+             f"{input_sequence.string_type()}"
+         )
+         return input_sequence.insert_at(tensor, position)
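The three kernels above compose the usual build-then-concatenate pattern for sequences. A minimal sketch under the same layout assumptions; it additionally assumes insert_at with position=None appends at the end, as the ONNX SequenceInsert specification requires (that helper is defined outside this hunk).

import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.sequence_ops import (
    ConcatFromSequence_11,
    SequenceEmpty_11,
    SequenceInsert_11,
)

empty = SequenceEmpty_11(oh.make_node("SequenceEmpty", [], ["seq"]))
insert = SequenceInsert_11(oh.make_node("SequenceInsert", ["seq", "t"], ["seq2"]))
concat = ConcatFromSequence_11(oh.make_node("ConcatFromSequence", ["seq2"], ["Y"], axis=0))

seq = empty.run()                                     # empty float32 sequence
seq = insert.run(seq, OpRunTensor(torch.ones(1, 3)))  # append two tensors
seq = insert.run(seq, OpRunTensor(torch.zeros(1, 3)))
print(concat.run(seq).tensor.shape)                   # torch.Size([2, 3])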
onnx_diagnostic/reference/torch_ops/shape_ops.py (new file)
@@ -0,0 +1,120 @@
+ from typing import Optional, Tuple
+ import onnx
+ import torch
+ from . import OpRun, OpRunTensor
+
+
+ class ConstantOfShape_9(OpRun):
+     "ConstantOfShape"
+
+     @classmethod
+     def device_dependent(cls) -> bool:
+         """
+         Returns True if the kernel needs a device to be efficiently initialized.
+         """
+         return True
+
+     def __init__(
+         self,
+         node: onnx.NodeProto,
+         version: Optional[int] = None,
+         device: Optional[torch.device] = None,
+     ):
+         super().__init__(node, version)
+         value = self.get_attribute_tensor(node, "value")
+         if value is None:
+             value = torch.tensor([0], dtype=torch.float32)
+         self.dtype = value.dtype
+         self.device = device
+         self.value = value[0]
+
+     def run(self, shape: OpRunTensor) -> OpRunTensor:
+         # The device is unknown as shapes usually take place on CPU.
+         return OpRunTensor(
+             torch.full(
+                 shape.as_tuple_int, fill_value=self.value, dtype=self.dtype, device=self.device
+             )
+         )
+
+
+ class Expand_8(OpRun):
+     "Expand"
+
+     def run(self, data: OpRunTensor, shape: OpRunTensor) -> OpRunTensor:
+         ishape = tuple(-1 if i == 1 else i for i in shape.as_tuple_int)
+         return OpRunTensor(data.tensor.expand(ishape))
+
+
+ class Reshape_14(OpRun):
+     "Reshape"
+
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.allowzero = self.get_attribute_int(node, "allowzero", 0)
+
+     def run(self, data: OpRunTensor, shape: OpRunTensor) -> OpRunTensor:
+         ishape = shape.as_tuple_int
+         assert ishape is not None, f"Unexpected return for shape={shape!r}"
+         if not self.allowzero and 0 in ishape:
+             xshape = data.tensor.shape
+             new_shape = []
+             for i, s in enumerate(ishape):
+                 new_shape.append(xshape[i] if s == 0 else s)
+             return OpRunTensor(data.tensor.reshape(new_shape))
+         return OpRunTensor(data.tensor.reshape(ishape))
+
+
+ class Shape_15(OpRun):
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.start = self.get_attribute_int(node, "start", 0)
+         self.end = self.get_attribute_int(node, "end", None)
+
+     def run(self, data: OpRunTensor) -> OpRunTensor:
+         shape = data.shape
+         sh = shape[self.start :] if self.end is None else shape[self.start : self.end]
+         return OpRunTensor(torch.tensor(sh, dtype=torch.int64), is_constant=True)
+
+
+ class Split_18(OpRun):
+     def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
+         super().__init__(node, version)
+         self.axis = self.get_attribute_int(node, "axis", 0)
+         self.num_outputs = self.get_attribute_int(node, "num_outputs", None)
+
+     def run(
+         self, data: OpRunTensor, split: Optional[OpRunTensor] = None
+     ) -> Tuple[OpRunTensor, ...]:
+         if split is None:
+             assert isinstance(
+                 self.num_outputs, int
+             ), f"Incompatibilities: split is None and num_outputs={self.num_outputs}"
+             size = data.tensor.shape[self.axis]
+             split_size = (
+                 size // self.num_outputs
+                 if size % self.num_outputs == 0
+                 else size // self.num_outputs + 1
+             )
+             spl = torch.split(data.tensor, split_size, dim=self.axis)
+         else:
+             spl = torch.split(data.tensor, split.as_tuple_int, dim=self.axis)
+         return tuple(OpRunTensor(t) for t in spl)
+
+
+ class Squeeze_13(OpRun):
+     "Squeeze"
+
+     def run(self, data: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
+         if axes is None:
+             return OpRunTensor(data.tensor.squeeze())
+         return OpRunTensor(data.tensor.squeeze(axes.as_tuple_int))
+
+
+ class Unsqueeze_13(OpRun):
+     "Unsqueeze"
+
+     def run(self, data: OpRunTensor, axes: OpRunTensor) -> OpRunTensor:
+         t = data.tensor
+         for i in axes.as_tuple_int:
+             t = t.unsqueeze(i)
+         return OpRunTensor(t)
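Reshape_14 above follows the ONNX convention that a 0 in the target shape copies the corresponding input dimension unless allowzero=1. A short illustration under the same assumptions as the earlier examples.

import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.shape_ops import Reshape_14

kernel = Reshape_14(oh.make_node("Reshape", ["X", "shape"], ["Y"]))

x = OpRunTensor(torch.zeros(2, 3, 4))
shape = OpRunTensor(torch.tensor([0, -1], dtype=torch.int64))
print(kernel.run(x, shape).tensor.shape)  # torch.Size([2, 12]): 0 keeps dim 0, -1 infers the rest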
onnx_diagnostic/reference/torch_ops/unary_ops.py (new file)
@@ -0,0 +1,86 @@
+ import torch
+ from . import OpRun, OpRunTensor
+
+
+ class Abs_1(OpRun):
+     """Abs"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(torch.abs(x.tensor))
+
+
+ class Cos_1(OpRun):
+     """Cos"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(x.tensor.cos())
+
+
+ class Erf_9(OpRun):
+     """Erf"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(x.tensor.erf())
+
+
+ class Exp_1(OpRun):
+     """Exp"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(x.tensor.exp())
+
+
+ class Identity_1(OpRun):
+     "Identity"
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(x.tensor)
+
+
+ class Log_1(OpRun):
+     """Log"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(x.tensor.log())
+
+
+ class Neg_1(OpRun):
+     """Neg"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(-x.tensor)
+
+
+ class Not_1(OpRun):
+     """Not"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(~x.tensor)
+
+
+ class Reciprocal_1(OpRun):
+     """Reciprocal"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(1 / x.tensor)
+
+
+ class Sigmoid_6(OpRun):
+     """Sigmoid"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(torch.sigmoid(x.tensor))
+
+
+ class Sin_1(OpRun):
+     """Sin"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(x.tensor.sin())
+
+
+ class Sqrt_1(OpRun):
+     """Sqrt"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(x.tensor.sqrt())
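The unary kernels carry no attributes, so building one only needs the NodeProto for bookkeeping. A minimal sketch, assuming the base OpRun constructor takes the node (and an optional opset version) as the subclasses elsewhere in this diff suggest.

import onnx.helper as oh
import torch
from onnx_diagnostic.reference.torch_ops import OpRunTensor
from onnx_diagnostic.reference.torch_ops.unary_ops import Sqrt_1

kernel = Sqrt_1(oh.make_node("Sqrt", ["X"], ["Y"]))
print(kernel.run(OpRunTensor(torch.tensor([4.0, 9.0]))).tensor)  # tensor([2., 3.])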