emx_onnx_cgen-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/conv.py
@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+
+from ..codegen.c_emitter import ConvOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class ConvSpec:
+    batch: int
+    in_channels: int
+    out_channels: int
+    spatial_rank: int
+    in_spatial: tuple[int, ...]
+    out_spatial: tuple[int, ...]
+    kernel_shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    pads: tuple[int, ...]
+    dilations: tuple[int, ...]
+    group: int
+
+
+def resolve_conv_spec(graph: Graph, node: Node) -> ConvSpec:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError("Conv must have 2 or 3 inputs and 1 output")
+    supported_attrs = {
+        "auto_pad",
+        "dilations",
+        "group",
+        "kernel_shape",
+        "pads",
+        "strides",
+    }
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("Conv has unsupported attributes")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    weight_shape = _value_shape(graph, node.inputs[1], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError("Conv expects NCHW inputs with spatial dims")
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("Conv supports 1D/2D/3D inputs only")
+    if len(weight_shape) != spatial_rank + 2:
+        raise UnsupportedOpError("Conv weight rank must match spatial rank")
+    batch, in_channels = input_shape[0], input_shape[1]
+    in_spatial = input_shape[2:]
+    out_channels, weight_in_channels, *kernel_shape = weight_shape
+    kernel_shape = node.attrs.get("kernel_shape")
+    if kernel_shape is not None:
+        kernel_shape = tuple(int(value) for value in kernel_shape)
+        if len(kernel_shape) != spatial_rank:
+            raise UnsupportedOpError(
+                "Conv kernel_shape rank must match input spatial rank"
+            )
+        if kernel_shape != tuple(weight_shape[2:]):
+            raise ShapeInferenceError(
+                "Conv kernel_shape must match weights, "
+                f"got {kernel_shape} and {tuple(weight_shape[2:])}"
+            )
+    else:
+        kernel_shape = tuple(weight_shape[2:])
+    group = int(node.attrs.get("group", 1))
+    if group <= 0:
+        raise UnsupportedOpError("Conv expects group >= 1")
+    if in_channels % group != 0 or out_channels % group != 0:
+        raise ShapeInferenceError(
+            "Conv expects group to evenly divide in/out channels, "
+            f"got group={group}, in_channels={in_channels}, "
+            f"out_channels={out_channels}"
+        )
+    if weight_in_channels != in_channels // group:
+        raise ShapeInferenceError(
+            "Conv input channels must match weight channels, "
+            f"got {in_channels} and {weight_in_channels * group}"
+        )
+    if len(node.inputs) == 3:
+        bias_shape = _value_shape(graph, node.inputs[2], node)
+        if bias_shape != (out_channels,):
+            raise ShapeInferenceError(
+                f"Conv bias shape must be {(out_channels,)}, got {bias_shape}"
+            )
+    strides = tuple(
+        int(value) for value in node.attrs.get("strides", (1,) * spatial_rank)
+    )
+    if len(strides) != spatial_rank:
+        raise UnsupportedOpError("Conv stride rank mismatch")
+    dilations = tuple(
+        int(value) for value in node.attrs.get("dilations", (1,) * spatial_rank)
+    )
+    if len(dilations) != spatial_rank:
+        raise UnsupportedOpError("Conv dilation rank mismatch")
+    pads = tuple(
+        int(value)
+        for value in node.attrs.get("pads", (0,) * (2 * spatial_rank))
+    )
+    if len(pads) != 2 * spatial_rank:
+        raise UnsupportedOpError("Conv pads rank mismatch")
+    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8", errors="ignore")
+    if auto_pad in ("", "NOTSET"):
+        pad_begin = pads[:spatial_rank]
+        pad_end = pads[spatial_rank:]
+    elif auto_pad == "VALID":
+        pad_begin = (0,) * spatial_rank
+        pad_end = (0,) * spatial_rank
+    elif auto_pad in {"SAME_UPPER", "SAME_LOWER"}:
+        pad_begin = []
+        pad_end = []
+        for dim, stride, dilation, kernel in zip(
+            in_spatial, strides, dilations, kernel_shape
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            out_dim = math.ceil(dim / stride)
+            pad_needed = max(
+                0, (out_dim - 1) * stride + effective_kernel - dim
+            )
+            if auto_pad == "SAME_UPPER":
+                pad_start = pad_needed // 2
+            else:
+                pad_start = (pad_needed + 1) // 2
+            pad_begin.append(pad_start)
+            pad_end.append(pad_needed - pad_start)
+        pad_begin = tuple(pad_begin)
+        pad_end = tuple(pad_end)
+    else:
+        raise UnsupportedOpError("Conv has unsupported auto_pad mode")
+    out_spatial = []
+    for dim, stride, dilation, kernel, pad_start, pad_finish in zip(
+        in_spatial, strides, dilations, kernel_shape, pad_begin, pad_end
+    ):
+        effective_kernel = dilation * (kernel - 1) + 1
+        out_dim = (dim + pad_start + pad_finish - effective_kernel) // stride + 1
+        if out_dim < 0:
+            raise ShapeInferenceError("Conv output shape must be non-negative")
+        out_spatial.append(out_dim)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (batch, out_channels, *out_spatial)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "Conv output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    return ConvSpec(
+        batch=batch,
+        in_channels=in_channels,
+        out_channels=out_channels,
+        spatial_rank=spatial_rank,
+        in_spatial=in_spatial,
+        out_spatial=tuple(out_spatial),
+        kernel_shape=kernel_shape,
+        strides=strides,
+        pads=(*pad_begin, *pad_end),
+        dilations=dilations,
+        group=group,
+    )
+
+
+@register_lowering("Conv")
+def lower_conv(graph: Graph, node: Node) -> ConvOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError("Conv must have 2 or 3 inputs and 1 output")
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "Conv supports float16, float, and double inputs only"
+        )
+    spec = resolve_conv_spec(graph, node)
+    return ConvOp(
+        input0=node.inputs[0],
+        weights=node.inputs[1],
+        bias=node.inputs[2] if len(node.inputs) == 3 else None,
+        output=node.outputs[0],
+        batch=spec.batch,
+        in_channels=spec.in_channels,
+        out_channels=spec.out_channels,
+        spatial_rank=spec.spatial_rank,
+        in_spatial=spec.in_spatial,
+        out_spatial=spec.out_spatial,
+        kernel_shape=spec.kernel_shape,
+        strides=spec.strides,
+        pads=spec.pads,
+        dilations=spec.dilations,
+        group=spec.group,
+        dtype=op_dtype,
+    )
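The SAME_UPPER/SAME_LOWER branch above mirrors the ONNX auto_pad rule: the output spatial size targets ceil(dim / stride), and the padding needed to reach it is split between begin and end. A standalone sketch with hypothetical dimensions (a 5x5 input, 3x3 kernel, stride 2; values chosen only for illustration) traces the arithmetic:

import math

# Hypothetical 2D Conv: 5x5 spatial input, 3x3 kernel, stride 2, dilation 1,
# auto_pad=SAME_UPPER. Mirrors the loop in resolve_conv_spec above.
in_spatial = (5, 5)
kernel_shape = (3, 3)
strides = (2, 2)
dilations = (1, 1)

pad_begin, pad_end, out_spatial = [], [], []
for dim, stride, dilation, kernel in zip(in_spatial, strides, dilations, kernel_shape):
    effective_kernel = dilation * (kernel - 1) + 1
    out_dim = math.ceil(dim / stride)  # SAME_* targets ceil(dim / stride)
    pad_needed = max(0, (out_dim - 1) * stride + effective_kernel - dim)
    pad_start = pad_needed // 2        # SAME_UPPER puts the extra pixel at the end
    pad_finish = pad_needed - pad_start
    pad_begin.append(pad_start)
    pad_end.append(pad_finish)
    out_spatial.append((dim + pad_start + pad_finish - effective_kernel) // stride + 1)

print(pad_begin, pad_end, out_spatial)  # [1, 1] [1, 1] [3, 3]

For SAME_LOWER the only difference is that the odd pixel of padding goes to the front, which is the (pad_needed + 1) // 2 branch in the code above.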
emx_onnx_cgen/lowering/cumsum.py
@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import CumSumOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import value_dtype, value_shape
+from ..validation import ensure_output_shape_matches_input, normalize_axis
+from .registry import register_lowering
+
+
+_SUPPORTED_CUMSUM_DTYPES = {
+    ScalarType.F16,
+    ScalarType.F32,
+    ScalarType.F64,
+    ScalarType.I32,
+    ScalarType.I64,
+    ScalarType.U32,
+    ScalarType.U64,
+}
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+def _validate_static_shape(shape: tuple[int, ...], node: Node) -> None:
+    for dim in shape:
+        if dim < 0:
+            raise ShapeInferenceError(
+                f"{node.op_type} does not support dynamic dims"
+            )
+
+
+def _read_axis_initializer(
+    initializer: Initializer, node: Node
+) -> int:
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"{node.op_type} axis input must be int64 or int32"
+        )
+    axis_data = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    if axis_data.size != 1:
+        raise UnsupportedOpError(f"{node.op_type} axis input must be scalar")
+    return int(axis_data[0])
+
+
+@register_lowering("CumSum")
+def lower_cumsum(graph: Graph, node: Node) -> CumSumOp:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("CumSum must have 2 inputs and 1 output")
+    input_name = node.inputs[0]
+    axis_name = node.inputs[1]
+    if not input_name or not axis_name:
+        raise UnsupportedOpError("CumSum requires input and axis values")
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    _validate_static_shape(input_shape, node)
+    ensure_output_shape_matches_input(node, input_shape, output_shape)
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "CumSum expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if input_dtype not in _SUPPORTED_CUMSUM_DTYPES:
+        raise UnsupportedOpError(
+            f"CumSum does not support dtype {input_dtype.onnx_name}"
+        )
+    axis_initializer = _find_initializer(graph, axis_name)
+    axis_value = None
+    axis_input = None
+    axis_input_dtype = None
+    if axis_initializer is not None:
+        axis_value = normalize_axis(
+            _read_axis_initializer(axis_initializer, node),
+            input_shape,
+            node,
+        )
+    else:
+        axis_shape = value_shape(graph, axis_name, node)
+        if not _is_scalar_shape(axis_shape):
+            raise UnsupportedOpError("CumSum axis input must be scalar")
+        axis_input_dtype = value_dtype(graph, axis_name, node)
+        if axis_input_dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError(
+                "CumSum axis input must be int64 or int32"
+            )
+        axis_input = axis_name
+    exclusive = int(node.attrs.get("exclusive", 0))
+    reverse = int(node.attrs.get("reverse", 0))
+    if exclusive not in {0, 1}:
+        raise UnsupportedOpError("CumSum exclusive must be 0 or 1")
+    if reverse not in {0, 1}:
+        raise UnsupportedOpError("CumSum reverse must be 0 or 1")
+    return CumSumOp(
+        input0=input_name,
+        axis_input=axis_input,
+        axis_input_dtype=axis_input_dtype,
+        axis=axis_value,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+        exclusive=bool(exclusive),
+        reverse=bool(reverse),
+    )
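lower_cumsum only validates the exclusive and reverse attributes and forwards them to CumSumOp; the semantics they carry are those of the ONNX CumSum operator. A small NumPy reference of that behaviour (illustrative only, not the generated C, and not part of the package):

import numpy as np

def cumsum_reference(x: np.ndarray, axis: int, exclusive: bool = False, reverse: bool = False) -> np.ndarray:
    """NumPy model of ONNX CumSum semantics for the flags validated above."""
    if reverse:
        x = np.flip(x, axis=axis)
    out = np.cumsum(x, axis=axis)
    if exclusive:
        # Shift by one along `axis` so each element excludes itself from the sum.
        out = np.concatenate(
            [np.zeros_like(np.take(out, [0], axis=axis)),
             np.take(out, np.arange(out.shape[axis] - 1), axis=axis)],
            axis=axis,
        )
    if reverse:
        out = np.flip(out, axis=axis)
    return out

print(cumsum_reference(np.array([1, 2, 3, 4]), axis=0, exclusive=True, reverse=True))  # [9 7 4 0]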
emx_onnx_cgen/lowering/depth_space.py
@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import DepthToSpaceOp, SpaceToDepthOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..lowering.common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _blocksize(node: Node) -> int:
+    if "blocksize" not in node.attrs:
+        raise UnsupportedOpError(f"{node.op_type} requires blocksize attribute")
+    blocksize = int(node.attrs["blocksize"])
+    if blocksize <= 0:
+        raise UnsupportedOpError(
+            f"{node.op_type} blocksize must be > 0, got {blocksize}"
+        )
+    return blocksize
+
+
+@register_lowering("DepthToSpace")
+def lower_depth_to_space(graph: Graph, node: Node) -> DepthToSpaceOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("DepthToSpace must have 1 input and 1 output")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if len(input_shape) != 4 or len(output_shape) != 4:
+        raise UnsupportedOpError("DepthToSpace only supports 4D inputs")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "DepthToSpace expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    blocksize = _blocksize(node)
+    mode_attr = node.attrs.get("mode", "DCR")
+    if isinstance(mode_attr, bytes):
+        mode = mode_attr.decode()
+    else:
+        mode = str(mode_attr)
+    if mode not in {"DCR", "CRD"}:
+        raise UnsupportedOpError(
+            "DepthToSpace only supports mode DCR or CRD"
+        )
+    n, c, h, w = input_shape
+    if c % (blocksize * blocksize) != 0:
+        raise ShapeInferenceError(
+            "DepthToSpace input channels must be divisible by blocksize^2"
+        )
+    expected_shape = (
+        n,
+        c // (blocksize * blocksize),
+        h * blocksize,
+        w * blocksize,
+    )
+    if output_shape != expected_shape:
+        raise ShapeInferenceError(
+            "DepthToSpace output shape mismatch: "
+            f"expected {expected_shape}, got {output_shape}"
+        )
+    return DepthToSpaceOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        blocksize=blocksize,
+        mode=mode,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
+
+
+@register_lowering("SpaceToDepth")
+def lower_space_to_depth(graph: Graph, node: Node) -> SpaceToDepthOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("SpaceToDepth must have 1 input and 1 output")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if len(input_shape) != 4 or len(output_shape) != 4:
+        raise UnsupportedOpError("SpaceToDepth only supports 4D inputs")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "SpaceToDepth expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    blocksize = _blocksize(node)
+    n, c, h, w = input_shape
+    if h % blocksize != 0 or w % blocksize != 0:
+        raise ShapeInferenceError(
+            "SpaceToDepth spatial dims must be divisible by blocksize"
+        )
+    expected_shape = (
+        n,
+        c * blocksize * blocksize,
+        h // blocksize,
+        w // blocksize,
+    )
+    if output_shape != expected_shape:
+        raise ShapeInferenceError(
+            "SpaceToDepth output shape mismatch: "
+            f"expected {expected_shape}, got {output_shape}"
+        )
+    return SpaceToDepthOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        blocksize=blocksize,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
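The expected_shape checks above correspond to the standard DepthToSpace decomposition into reshape, transpose, reshape. A NumPy sketch of the DCR variant with made-up dimensions (illustrative only, not package code) shows where the (n, c / blocksize^2, h * blocksize, w * blocksize) shape comes from:

import numpy as np

def depth_to_space_dcr(x: np.ndarray, blocksize: int) -> np.ndarray:
    """NumPy model of DepthToSpace (DCR mode), matching the shape check above."""
    n, c, h, w = x.shape
    b = blocksize
    assert c % (b * b) == 0, "channels must be divisible by blocksize^2"
    tmp = x.reshape(n, b, b, c // (b * b), h, w)   # split channels into (b, b, c')
    tmp = tmp.transpose(0, 3, 4, 1, 5, 2)          # interleave the blocks with the spatial dims
    return tmp.reshape(n, c // (b * b), h * b, w * b)

x = np.arange(1 * 8 * 2 * 2).reshape(1, 8, 2, 2)
print(depth_to_space_dcr(x, 2).shape)  # (1, 2, 4, 4)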
emx_onnx_cgen/lowering/dropout.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import ReshapeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+def _is_value_used(graph: Graph, name: str) -> bool:
+    if any(value.name == name for value in graph.outputs):
+        return True
+    return any(name in node.inputs for node in graph.nodes)
+
+
+@register_lowering("Dropout")
+def lower_dropout(graph: Graph, node: Node) -> ReshapeOp:
+    if len(node.outputs) not in {1, 2} or len(node.inputs) != 1:
+        raise UnsupportedOpError(
+            "Dropout supports only the data input and 1 or 2 outputs"
+        )
+    if len(node.outputs) == 2 and _is_value_used(graph, node.outputs[1]):
+        raise UnsupportedOpError("Dropout mask output is not supported")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    if input_shape != output_shape:
+        raise ShapeInferenceError(
+            "Dropout output shape must match input shape, "
+            f"got {output_shape} for input {input_shape}"
+        )
+    input_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Dropout expects matching input/output dtypes, "
+            f"got {input_dtype} and {output_dtype}"
+        )
+    return ReshapeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+    )
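Dropout is lowered to a same-shape ReshapeOp, i.e. treated as the inference-time identity; this is only valid because the optional mask output is rejected whenever another node or a graph output consumes it. A one-line NumPy illustration of that identity assumption (not package code):

import numpy as np

# In ONNX inference mode, Dropout's data output equals its input (no scaling),
# which is why a same-shape copy/reshape is a faithful lowering.
x = np.random.rand(2, 3).astype(np.float32)
dropout_inference_output = x
assert np.array_equal(dropout_inference_output, x)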
emx_onnx_cgen/lowering/elementwise.py
@@ -0,0 +1,164 @@
+from __future__ import annotations
+
+from shared.scalar_functions import ScalarFunction
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ClipOp, UnaryOp
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..lowering.common import node_dtype, optional_name, value_dtype, value_shape
+from ..lowering.registry import register_lowering
+
+
+@register_lowering("Clip")
+def lower_clip(graph: Graph, node: Node) -> ClipOp:
+    if not node.inputs or len(node.outputs) != 1:
+        raise UnsupportedOpError("Clip must have 1 output")
+    input_name = node.inputs[0]
+    if not input_name:
+        raise UnsupportedOpError("Clip input must be provided")
+    min_name = optional_name(node.inputs, 1)
+    max_name = optional_name(node.inputs, 2)
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Clip expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if min_name is not None:
+        min_dtype = value_dtype(graph, min_name, node)
+        if min_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "Clip min dtype must match input dtype, "
+                f"got {min_dtype.onnx_name}"
+            )
+    if max_name is not None:
+        max_dtype = value_dtype(graph, max_name, node)
+        if max_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "Clip max dtype must match input dtype, "
+                f"got {max_dtype.onnx_name}"
+            )
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if input_shape != output_shape:
+        raise UnsupportedOpError("Clip input and output shapes must match")
+    min_shape = value_shape(graph, min_name, node) if min_name else None
+    max_shape = value_shape(graph, max_name, node) if max_name else None
+    return ClipOp(
+        input0=input_name,
+        input_min=min_name,
+        input_max=max_name,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        min_shape=min_shape,
+        max_shape=max_shape,
+        output_shape=output_shape,
+        dtype=input_dtype,
+    )
+
+
+@register_lowering("Celu")
+def lower_celu(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Celu must have 1 input and 1 output")
+    dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not dtype.is_float:
+        raise UnsupportedOpError("Celu only supports floating-point inputs")
+    alpha = float(node.attrs.get("alpha", 1.0))
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.CELU,
+        shape=output_shape,
+        dtype=dtype,
+        input_dtype=dtype,
+        params=(alpha,),
+    )
+
+
+@register_lowering("Swish")
+def lower_swish(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Swish must have 1 input and 1 output")
+    dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not dtype.is_float:
+        raise UnsupportedOpError("Swish only supports floating-point inputs")
+    alpha = float(node.attrs.get("alpha", 1.0))
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.SWISH,
+        shape=output_shape,
+        dtype=dtype,
+        input_dtype=dtype,
+        params=(alpha,),
+    )
+
+
+@register_lowering("Shrink")
+def lower_shrink(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Shrink must have 1 input and 1 output")
+    dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+    if not dtype.is_float:
+        raise UnsupportedOpError("Shrink only supports floating-point inputs")
+    bias = float(node.attrs.get("bias", 0.0))
+    lambd = float(node.attrs.get("lambd", 0.5))
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.SHRINK,
+        shape=output_shape,
+        dtype=dtype,
+        input_dtype=dtype,
+        params=(bias, lambd),
+    )
+
+
+@register_lowering("IsInf")
+def lower_isinf(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("IsInf must have 1 input and 1 output")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if not input_dtype.is_float:
+        raise UnsupportedOpError("IsInf only supports floating-point inputs")
+    if output_dtype != ScalarType.BOOL:
+        raise UnsupportedOpError("IsInf output must be bool")
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.ISINF,
+        shape=output_shape,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+        params=(),
+    )
+
+
+@register_lowering("IsNaN")
+def lower_isnan(graph: Graph, node: Node) -> UnaryOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("IsNaN must have 1 input and 1 output")
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if not input_dtype.is_float:
+        raise UnsupportedOpError("IsNaN only supports floating-point inputs")
+    if output_dtype != ScalarType.BOOL:
+        raise UnsupportedOpError("IsNaN output must be bool")
+    output_shape = value_shape(graph, node.outputs[0], node)
+    return UnaryOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        function=ScalarFunction.ISNAN,
+        shape=output_shape,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+        params=(),
+    )
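The Celu and Shrink lowerings above only forward alpha, bias, and lambd as params; the scalar math itself lives in shared.scalar_functions. For orientation, NumPy sketches of the ONNX definitions those kernels are expected to implement (hypothetical reference helpers, not package code):

import numpy as np

def celu_reference(x: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    """ONNX Celu: max(0, x) + min(0, alpha * (exp(x / alpha) - 1))."""
    return np.maximum(0.0, x) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))

def shrink_reference(x: np.ndarray, bias: float = 0.0, lambd: float = 0.5) -> np.ndarray:
    """ONNX Shrink: x - bias if x > lambd, x + bias if x < -lambd, else 0."""
    return np.where(x > lambd, x - bias, np.where(x < -lambd, x + bias, 0.0))

x = np.array([-2.0, -0.25, 0.25, 2.0], dtype=np.float32)
print(celu_reference(x, alpha=1.0))
print(shrink_reference(x, bias=0.0, lambd=0.5))  # [-2.  0.  0.  2.]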