emx-onnx-cgen 0.2.0 (emx_onnx_cgen-0.2.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen has been flagged as possibly problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0

emx_onnx_cgen/lowering/expand.py
@@ -0,0 +1,151 @@
+ from __future__ import annotations
+
+ import numpy as np
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import ExpandOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Initializer, Node
+ from ..lowering.common import value_dtype, value_shape
+ from .registry import register_lowering
+
+
+ def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+     for initializer in graph.initializers:
+         if initializer.name == name:
+             return initializer
+     return None
+
+
+ def _read_shape_values(graph: Graph, name: str, node: Node) -> list[int] | None:
+     initializer = _find_initializer(graph, name)
+     if initializer is None:
+         return None
+     if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+         raise UnsupportedOpError(
+             f"{node.op_type} shape input must be int64 or int32"
+         )
+     if len(initializer.type.shape) != 1:
+         raise UnsupportedOpError(
+             f"{node.op_type} shape input must be a 1D tensor"
+         )
+     values = np.array(initializer.data, dtype=np.int64).reshape(-1)
+     if values.size == 0:
+         raise ShapeInferenceError(
+             f"{node.op_type} shape input cannot be empty"
+         )
+     return [int(value) for value in values]
+
+
+ def _validate_shape_input(graph: Graph, name: str, node: Node) -> None:
+     dtype = value_dtype(graph, name, node)
+     if dtype not in {ScalarType.I64, ScalarType.I32}:
+         raise UnsupportedOpError(
+             f"{node.op_type} shape input must be int64 or int32"
+         )
+     shape = value_shape(graph, name, node)
+     if len(shape) != 1:
+         raise UnsupportedOpError(
+             f"{node.op_type} shape input must be a 1D tensor"
+         )
+     if shape[0] <= 0:
+         raise ShapeInferenceError(
+             f"{node.op_type} shape input cannot be empty"
+         )
+
+
+ def _validate_static_dims(shape: tuple[int, ...], node: Node) -> None:
+     if any(dim < 0 for dim in shape):
+         raise ShapeInferenceError(
+             f"{node.op_type} does not support dynamic dims"
+         )
+
+
+ def _broadcast_shape(
+     input_shape: tuple[int, ...], shape_values: list[int], node: Node
+ ) -> tuple[int, ...]:
+     _validate_static_dims(input_shape, node)
+     for dim in shape_values:
+         if dim < 0:
+             raise ShapeInferenceError(
+                 f"{node.op_type} does not support dynamic dims"
+             )
+     output_rank = max(len(input_shape), len(shape_values))
+     input_padded = (1,) * (output_rank - len(input_shape)) + input_shape
+     shape_padded = (1,) * (output_rank - len(shape_values)) + tuple(shape_values)
+     result: list[int] = []
+     for input_dim, shape_dim in zip(input_padded, shape_padded):
+         if input_dim == 1:
+             result.append(shape_dim)
+         elif shape_dim == 1:
+             result.append(input_dim)
+         elif input_dim == shape_dim:
+             result.append(input_dim)
+         else:
+             raise ShapeInferenceError(
+                 f"{node.op_type} input shape {input_shape} is not "
+                 f"broadcastable to {shape_values}"
+             )
+     return tuple(result)
+
+
+ def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
+     strides: list[int] = []
+     stride = 1
+     for dim in reversed(shape):
+         strides.append(stride)
+         stride *= dim
+     return tuple(reversed(strides))
+
+
+ @register_lowering("Expand")
+ def lower_expand(graph: Graph, node: Node) -> ExpandOp:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Expand must have 2 inputs and 1 output")
+     input_shape = value_shape(graph, node.inputs[0], node)
+     output_shape = value_shape(graph, node.outputs[0], node)
+     input_dtype = value_dtype(graph, node.inputs[0], node)
+     output_dtype = value_dtype(graph, node.outputs[0], node)
+     if input_dtype != output_dtype:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects matching input/output dtypes, "
+             f"got {input_dtype} and {output_dtype}"
+         )
+     shape_values = _read_shape_values(graph, node.inputs[1], node)
+     if shape_values is not None:
+         expected_output_shape = _broadcast_shape(input_shape, shape_values, node)
+         _validate_static_dims(expected_output_shape, node)
+         if output_shape and output_shape != expected_output_shape:
+             raise ShapeInferenceError(
+                 f"{node.op_type} output shape must be {expected_output_shape}, "
+                 f"got {output_shape}"
+             )
+     else:
+         _validate_shape_input(graph, node.inputs[1], node)
+         if not output_shape:
+             raise ShapeInferenceError(
+                 f"{node.op_type} output shape must be specified"
+             )
+         expected_output_shape = _broadcast_shape(
+             input_shape, list(output_shape), node
+         )
+         if expected_output_shape != output_shape:
+             raise ShapeInferenceError(
+                 f"{node.op_type} output shape must be {expected_output_shape}, "
+                 f"got {output_shape}"
+             )
+     input_shape_padded = (
+         (1,) * (len(expected_output_shape) - len(input_shape)) + input_shape
+     )
+     input_strides = _compute_strides(input_shape_padded)
+     return ExpandOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         input_shape=input_shape,
+         output_shape=expected_output_shape,
+         input_shape_padded=input_shape_padded,
+         input_strides=input_strides,
+         dtype=input_dtype,
+         input_dtype=input_dtype,
+     )
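
For reference, the rule `_broadcast_shape` enforces is numpy-style multidirectional broadcasting: shapes are right-aligned and a dimension of 1 stretches to the other operand's dimension. A minimal standalone sketch (the `broadcast_shape` helper is illustrative, not part of the package):

import numpy as np

def broadcast_shape(input_shape, shape_values):
    # Right-align both shapes by padding with 1s, then merge dimension-wise.
    rank = max(len(input_shape), len(shape_values))
    a = (1,) * (rank - len(input_shape)) + tuple(input_shape)
    b = (1,) * (rank - len(shape_values)) + tuple(shape_values)
    out = []
    for x, y in zip(a, b):
        if x == 1 or y == 1 or x == y:
            out.append(max(x, y))
        else:
            raise ValueError(f"{input_shape} not broadcastable to {shape_values}")
    return tuple(out)

# Expand(X, shape) behaves like X * ones(shape); numpy agrees on the result shape.
assert broadcast_shape((3, 1), [2, 1, 4]) == (2, 3, 4)
assert (np.zeros((3, 1)) + np.zeros((2, 1, 4))).shape == (2, 3, 4)
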
emx_onnx_cgen/lowering/eye_like.py
@@ -0,0 +1,43 @@
+ from __future__ import annotations
+
+ from ..codegen.c_emitter import EyeLikeOp
+ from ..dtypes import scalar_type_from_onnx
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import value_dtype, value_shape
+ from .registry import register_lowering
+
+
+ @register_lowering("EyeLike")
+ def lower_eye_like(graph: Graph, node: Node) -> EyeLikeOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("EyeLike must have 1 input and 1 output")
+     input_shape = value_shape(graph, node.inputs[0], node)
+     output_shape = value_shape(graph, node.outputs[0], node)
+     if input_shape != output_shape:
+         raise ShapeInferenceError("EyeLike input and output shapes must match")
+     if len(output_shape) < 2:
+         raise UnsupportedOpError("EyeLike expects input rank >= 2")
+     input_dtype = value_dtype(graph, node.inputs[0], node)
+     output_dtype = value_dtype(graph, node.outputs[0], node)
+     dtype_attr = node.attrs.get("dtype")
+     if dtype_attr is not None:
+         target_dtype = scalar_type_from_onnx(int(dtype_attr))
+         if target_dtype is None:
+             raise UnsupportedOpError(
+                 f"EyeLike dtype {dtype_attr} is not supported"
+             )
+         if output_dtype != target_dtype:
+             raise UnsupportedOpError(
+                 "EyeLike output dtype must match dtype attribute, "
+                 f"got {output_dtype.onnx_name} and {target_dtype.onnx_name}"
+             )
+     k = int(node.attrs.get("k", 0))
+     return EyeLikeOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         output_shape=output_shape,
+         k=k,
+         dtype=output_dtype,
+         input_dtype=input_dtype,
+     )
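
The semantics being lowered: the output matches the input's shape and holds ones on the k-th diagonal. A quick numpy sketch of the 2D case (the rank > 2 case the lowering accepts is left to the emitter and not illustrated here):

import numpy as np

def eye_like_2d(x: np.ndarray, k: int = 0) -> np.ndarray:
    # np.eye supports rectangular outputs and a diagonal offset k directly.
    rows, cols = x.shape
    return np.eye(rows, cols, k=k, dtype=x.dtype)

x = np.zeros((3, 5), dtype=np.float32)
print(eye_like_2d(x, k=1))  # ones on the first superdiagonal of a 3x5 matrix
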
emx_onnx_cgen/lowering/flatten.py
@@ -0,0 +1,60 @@
+ from __future__ import annotations
+
+ from ..codegen.c_emitter import ReshapeOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import shape_product, value_dtype, value_shape
+ from .registry import register_lowering
+
+
+ def _normalize_axis(axis: int, rank: int) -> int:
+     if axis < 0:
+         axis += rank
+     if axis < 0 or axis > rank:
+         raise UnsupportedOpError("Flatten axis must be within input rank")
+     return axis
+
+
+ def _flatten_output_shape(
+     input_shape: tuple[int, ...], axis: int
+ ) -> tuple[int, int]:
+     rank = len(input_shape)
+     axis = _normalize_axis(axis, rank)
+     if rank == 0:
+         return (1, 1)
+     for dim in input_shape:
+         if dim < 0:
+             raise ShapeInferenceError("Dynamic dims are not supported")
+     first = shape_product(input_shape[:axis]) if axis else 1
+     second = shape_product(input_shape[axis:]) if axis < rank else 1
+     return (first, second)
+
+
+ @register_lowering("Flatten")
+ def lower_flatten(graph: Graph, node: Node) -> ReshapeOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Flatten must have 1 input and 1 output")
+     input_shape = value_shape(graph, node.inputs[0], node)
+     input_dtype = value_dtype(graph, node.inputs[0], node)
+     output_dtype = value_dtype(graph, node.outputs[0], node)
+     if input_dtype != output_dtype:
+         raise UnsupportedOpError(
+             "Flatten expects matching input/output dtypes, "
+             f"got {input_dtype} and {output_dtype}"
+         )
+     axis = int(node.attrs.get("axis", 1))
+     output_shape = _flatten_output_shape(input_shape, axis)
+     expected_shape = value_shape(graph, node.outputs[0], node)
+     if expected_shape and output_shape != expected_shape:
+         raise ShapeInferenceError(
+             "Flatten output shape must be "
+             f"{output_shape}, got {expected_shape}"
+         )
+     return ReshapeOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         input_shape=input_shape,
+         output_shape=output_shape,
+         dtype=input_dtype,
+         input_dtype=input_dtype,
+     )
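
Flatten always produces a 2D result: dimensions before `axis` collapse into the first output dim and the rest into the second, with the empty product defined as 1. A small sketch of the same arithmetic (`flatten_output_shape` is an illustrative stand-in for `_flatten_output_shape`):

from math import prod

def flatten_output_shape(input_shape, axis=1):
    if axis < 0:
        axis += len(input_shape)
    # prod(()) == 1, which covers axis == 0 and axis == rank.
    return (prod(input_shape[:axis]), prod(input_shape[axis:]))

assert flatten_output_shape((2, 3, 4, 5), axis=2) == (6, 20)
assert flatten_output_shape((2, 3, 4, 5), axis=0) == (1, 120)
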
emx_onnx_cgen/lowering/gather.py
@@ -0,0 +1,48 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import GatherOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..validation import normalize_axis
+ from .common import value_dtype as _value_dtype
+ from .common import value_shape as _value_shape
+ from .registry import register_lowering
+
+
+ @register_lowering("Gather")
+ def lower_gather(graph: Graph, node: Node) -> GatherOp:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Gather must have 2 inputs and 1 output")
+     data_name, indices_name = node.inputs
+     data_shape = _value_shape(graph, data_name, node)
+     indices_shape = _value_shape(graph, indices_name, node)
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     axis = normalize_axis(int(node.attrs.get("axis", 0)), data_shape, node)
+     expected_output_shape = (
+         data_shape[:axis] + indices_shape + data_shape[axis + 1 :]
+     )
+     if output_shape != expected_output_shape:
+         raise ShapeInferenceError(
+             "Gather output shape must be "
+             f"{expected_output_shape}, got {output_shape}"
+         )
+     op_dtype = _value_dtype(graph, data_name, node)
+     indices_dtype = _value_dtype(graph, indices_name, node)
+     if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+         raise UnsupportedOpError(
+             "Gather indices must be int32 or int64, "
+             f"got {indices_dtype.onnx_name}"
+         )
+     return GatherOp(
+         data=data_name,
+         indices=indices_name,
+         output=node.outputs[0],
+         axis=axis,
+         data_shape=data_shape,
+         indices_shape=indices_shape,
+         output_shape=output_shape,
+         dtype=op_dtype,
+         indices_dtype=indices_dtype,
+     )
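
The expected output shape computed above is the standard Gather rule: replace the gathered axis of `data` with the whole `indices` shape. numpy's `np.take` implements the same contract, which makes it a convenient cross-check:

import numpy as np

data = np.arange(2 * 3 * 4).reshape(2, 3, 4)
indices = np.array([[2, 1], [0, 3]])  # indices may have any shape
axis = 2

out = np.take(data, indices, axis=axis)
expected = data.shape[:axis] + indices.shape + data.shape[axis + 1:]
assert out.shape == expected == (2, 3, 2, 2)
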
emx_onnx_cgen/lowering/gather_elements.py
@@ -0,0 +1,60 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import GatherElementsOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..validation import normalize_axis
+ from .common import value_dtype as _value_dtype
+ from .common import value_shape as _value_shape
+ from .registry import register_lowering
+
+
+ @register_lowering("GatherElements")
+ def lower_gather_elements(graph: Graph, node: Node) -> GatherElementsOp:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError("GatherElements must have 2 inputs and 1 output")
+     data_name, indices_name = node.inputs
+     data_shape = _value_shape(graph, data_name, node)
+     indices_shape = _value_shape(graph, indices_name, node)
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     if len(data_shape) != len(indices_shape):
+         raise ShapeInferenceError(
+             "GatherElements inputs must have matching ranks, "
+             f"got {data_shape} and {indices_shape}"
+         )
+     if output_shape != indices_shape:
+         raise ShapeInferenceError(
+             "GatherElements output shape must match indices shape, "
+             f"got {output_shape} and {indices_shape}"
+         )
+     axis = normalize_axis(int(node.attrs.get("axis", 0)), data_shape, node)
+     for dim_index, (data_dim, index_dim) in enumerate(
+         zip(data_shape, indices_shape)
+     ):
+         if dim_index == axis:
+             continue
+         if data_dim != index_dim:
+             raise ShapeInferenceError(
+                 "GatherElements inputs must match on non-axis dimensions, "
+                 f"got {data_shape} and {indices_shape}"
+             )
+     op_dtype = _value_dtype(graph, data_name, node)
+     indices_dtype = _value_dtype(graph, indices_name, node)
+     if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+         raise UnsupportedOpError(
+             "GatherElements indices must be int32 or int64, "
+             f"got {indices_dtype.onnx_name}"
+         )
+     return GatherElementsOp(
+         data=data_name,
+         indices=indices_name,
+         output=node.outputs[0],
+         axis=axis,
+         data_shape=data_shape,
+         indices_shape=indices_shape,
+         output_shape=output_shape,
+         dtype=op_dtype,
+         indices_dtype=indices_dtype,
+     )
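
GatherElements is the ONNX counterpart of numpy's `take_along_axis` (and PyTorch's `gather`): `indices` has the same rank as `data`, the output takes its shape from `indices`, and one element is picked along `axis` per output position. A minimal cross-check:

import numpy as np

data = np.array([[1, 2], [3, 4]])
indices = np.array([[0, 0], [1, 0]])  # same rank as data

out = np.take_along_axis(data, indices, axis=1)
assert out.shape == indices.shape
print(out)  # [[1 1], [4 3]]: out[i, j] = data[i, indices[i, j]]
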
emx_onnx_cgen/lowering/gemm.py
@@ -0,0 +1,139 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import GemmOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import node_dtype as _node_dtype
+ from .common import value_shape as _value_shape
+ from .registry import register_lowering
+
+
+ @dataclass(frozen=True)
+ class GemmSpec:
+     m: int
+     n: int
+     k: int
+     alpha: float | int
+     beta: float | int
+     trans_a: bool
+     trans_b: bool
+     c_shape: tuple[int, ...] | None
+
+
+ def resolve_gemm_spec(graph: Graph, node: Node, dtype: ScalarType) -> GemmSpec:
+     if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+         raise UnsupportedOpError("Gemm must have 2 or 3 inputs and 1 output")
+     alpha, beta, trans_a, trans_b = _resolve_gemm_attrs(node, dtype)
+     input0_shape = _value_shape(graph, node.inputs[0], node)
+     input1_shape = _value_shape(graph, node.inputs[1], node)
+     if len(input0_shape) != 2 or len(input1_shape) != 2:
+         raise UnsupportedOpError(
+             "Gemm supports 2D inputs only, "
+             f"got {input0_shape} x {input1_shape}"
+         )
+     if trans_a:
+         m, k_left = input0_shape[1], input0_shape[0]
+     else:
+         m, k_left = input0_shape
+     if trans_b:
+         n, k_right = input1_shape[0], input1_shape[1]
+     else:
+         k_right, n = input1_shape
+     if k_left != k_right:
+         raise ShapeInferenceError(
+             f"Gemm inner dimensions must match, got {k_left} and {k_right}"
+         )
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     if output_shape != (m, n):
+         raise ShapeInferenceError(
+             f"Gemm output shape must be {(m, n)}, got {output_shape}"
+         )
+     c_shape = None
+     if len(node.inputs) == 3:
+         bias_shape = _value_shape(graph, node.inputs[2], node)
+         c_shape = validate_gemm_bias_shape((m, n), bias_shape, node)
+     return GemmSpec(
+         m=m,
+         n=n,
+         k=k_left,
+         alpha=alpha,
+         beta=beta,
+         trans_a=trans_a,
+         trans_b=trans_b,
+         c_shape=c_shape,
+     )
+
+
+ def _resolve_gemm_attrs(
+     node: Node, dtype: ScalarType
+ ) -> tuple[float | int, float | int, bool, bool]:
+     alpha = float(node.attrs.get("alpha", 1.0))
+     beta = float(node.attrs.get("beta", 1.0))
+     trans_a = int(node.attrs.get("transA", 0))
+     trans_b = int(node.attrs.get("transB", 0))
+     if trans_a not in {0, 1} or trans_b not in {0, 1}:
+         raise UnsupportedOpError(
+             "Gemm only supports transA/transB values of 0 or 1"
+         )
+     if dtype == ScalarType.BOOL:
+         raise UnsupportedOpError("Gemm supports numeric inputs only")
+     if not dtype.is_float:
+         alpha_int = int(alpha)
+         beta_int = int(beta)
+         if alpha != alpha_int or beta != beta_int:
+             raise UnsupportedOpError(
+                 "Gemm alpha and beta must be integers for non-float inputs"
+             )
+         alpha = alpha_int
+         beta = beta_int
+     return alpha, beta, bool(trans_a), bool(trans_b)
+
+
+ def validate_gemm_bias_shape(
+     output_shape: tuple[int, int], bias_shape: tuple[int, ...], node: Node
+ ) -> tuple[int, ...]:
+     if len(bias_shape) == 0:
+         return bias_shape
+     if len(bias_shape) == 1:
+         if bias_shape[0] not in {1, output_shape[1]}:
+             raise ShapeInferenceError(
+                 "Gemm bias input must be broadcastable to output shape, "
+                 f"got {bias_shape} vs {output_shape}"
+             )
+         return bias_shape
+     if len(bias_shape) == 2:
+         m, n = output_shape
+         if bias_shape[0] not in {1, m} or bias_shape[1] not in {1, n}:
+             raise ShapeInferenceError(
+                 "Gemm bias input must be broadcastable to output shape, "
+                 f"got {bias_shape} vs {output_shape}"
+             )
+         return bias_shape
+     raise ShapeInferenceError(
+         f"Gemm bias input must be rank 1 or 2, got {bias_shape}"
+     )
+
+
+ @register_lowering("Gemm")
+ def lower_gemm(graph: Graph, node: Node) -> GemmOp:
+     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+     spec = resolve_gemm_spec(graph, node, op_dtype)
+     return GemmOp(
+         input_a=node.inputs[0],
+         input_b=node.inputs[1],
+         input_c=node.inputs[2] if len(node.inputs) == 3 else None,
+         output=node.outputs[0],
+         m=spec.m,
+         n=spec.n,
+         k=spec.k,
+         trans_a=spec.trans_a,
+         trans_b=spec.trans_b,
+         alpha=spec.alpha,
+         beta=spec.beta,
+         c_shape=spec.c_shape,
+         dtype=op_dtype,
+     )
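
The spec resolved above corresponds to the usual Gemm definition, Y = alpha * op(A) @ op(B) + beta * C, where op applies the optional transpose and C broadcasts against (M, N). A reference sketch in numpy (illustrative only; the package itself emits C code for this):

import numpy as np

def gemm(a, b, c=None, alpha=1.0, beta=1.0, trans_a=False, trans_b=False):
    a = a.T if trans_a else a
    b = b.T if trans_b else b
    y = alpha * (a @ b)
    if c is not None:
        y = y + beta * c  # C may be scalar, (N,), (1, N), (M, 1), or (M, N)
    return y

a = np.random.rand(4, 2)  # transA=1, so A is stored as (K, M) with K=4, M=2
b = np.random.rand(4, 3)  # (K, N)
c = np.random.rand(3)     # rank-1 bias, broadcast over the M rows
assert gemm(a, b, c, trans_a=True).shape == (2, 3)
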
emx_onnx_cgen/lowering/grid_sample.py
@@ -0,0 +1,149 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import GridSampleOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import value_dtype, value_shape
+ from .registry import register_lowering
+
+ _SUPPORTED_MODES = {"linear", "nearest", "cubic"}
+ _SUPPORTED_PADDING_MODES = {"zeros", "border", "reflection"}
+
+
+ @dataclass(frozen=True)
+ class _GridSampleShapes:
+     input_shape: tuple[int, ...]
+     grid_shape: tuple[int, ...]
+     output_shape: tuple[int, ...]
+     spatial_rank: int
+
+
+ def _decode_attr(value: object, default: str) -> str:
+     if value is None:
+         return default
+     if isinstance(value, bytes):
+         return value.decode("utf-8", errors="ignore")
+     if isinstance(value, str):
+         return value
+     return str(value)
+
+
+ def _resolve_shapes(graph: Graph, node: Node) -> _GridSampleShapes:
+     input_shape = value_shape(graph, node.inputs[0], node)
+     grid_shape = value_shape(graph, node.inputs[1], node)
+     output_shape = value_shape(graph, node.outputs[0], node)
+     if len(input_shape) < 3:
+         raise ShapeInferenceError(
+             "GridSample expects input rank of at least 3"
+         )
+     spatial_rank = len(input_shape) - 2
+     if any(dim < 0 for dim in (*input_shape, *grid_shape, *output_shape)):
+         raise ShapeInferenceError(
+             "GridSample requires static, non-negative shapes"
+         )
+     return _GridSampleShapes(
+         input_shape=input_shape,
+         grid_shape=grid_shape,
+         output_shape=output_shape,
+         spatial_rank=spatial_rank,
+     )
+
+
+ def _validate_shapes(shapes: _GridSampleShapes) -> None:
+     input_shape = shapes.input_shape
+     grid_shape = shapes.grid_shape
+     output_shape = shapes.output_shape
+     spatial_rank = shapes.spatial_rank
+     if len(grid_shape) != spatial_rank + 2:
+         raise ShapeInferenceError(
+             "GridSample expects grid rank to match input spatial rank"
+         )
+     if len(output_shape) != spatial_rank + 2:
+         raise ShapeInferenceError(
+             "GridSample expects output rank to match input spatial rank"
+         )
+     if grid_shape[0] != input_shape[0]:
+         raise ShapeInferenceError("GridSample expects matching batch dimension")
+     if grid_shape[-1] != spatial_rank:
+         raise ShapeInferenceError(
+             "GridSample expects grid last dimension to match spatial rank"
+         )
+     expected_output = (
+         input_shape[0],
+         input_shape[1],
+         *grid_shape[1:-1],
+     )
+     if output_shape != expected_output:
+         raise ShapeInferenceError(
+             "GridSample output shape must be "
+             f"{expected_output}, got {output_shape}"
+         )
+
+
+ def _validate_dtypes(
+     graph: Graph, node: Node
+ ) -> tuple[ScalarType, ScalarType]:
+     input_dtype = value_dtype(graph, node.inputs[0], node)
+     grid_dtype = value_dtype(graph, node.inputs[1], node)
+     output_dtype = value_dtype(graph, node.outputs[0], node)
+     if input_dtype != output_dtype:
+         raise UnsupportedOpError(
+             "GridSample expects matching input/output dtypes, got "
+             f"{input_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     if not input_dtype.is_float:
+         raise UnsupportedOpError(
+             "GridSample currently supports floating-point inputs only"
+         )
+     if not grid_dtype.is_float:
+         raise UnsupportedOpError("GridSample expects floating-point grid")
+     return input_dtype, grid_dtype
+
+
+ @register_lowering("GridSample")
+ def lower_grid_sample(graph: Graph, node: Node) -> GridSampleOp:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             "GridSample expects 2 inputs (X, grid) and 1 output"
+         )
+     shapes = _resolve_shapes(graph, node)
+     _validate_shapes(shapes)
+     mode = _decode_attr(node.attrs.get("mode"), "linear")
+     padding_mode = _decode_attr(node.attrs.get("padding_mode"), "zeros")
+     align_corners = int(node.attrs.get("align_corners", 0))
+     if mode not in _SUPPORTED_MODES:
+         raise UnsupportedOpError(
+             f"GridSample mode {mode!r} is not supported"
+         )
+     if padding_mode not in _SUPPORTED_PADDING_MODES:
+         raise UnsupportedOpError(
+             "GridSample padding_mode "
+             f"{padding_mode!r} is not supported"
+         )
+     if align_corners not in {0, 1}:
+         raise UnsupportedOpError("GridSample align_corners must be 0 or 1")
+     input_dtype, grid_dtype = _validate_dtypes(graph, node)
+     if shapes.spatial_rank > 3:
+         raise UnsupportedOpError(
+             "GridSample supports up to 3 spatial dimensions"
+         )
+     return GridSampleOp(
+         input0=node.inputs[0],
+         grid=node.inputs[1],
+         output=node.outputs[0],
+         input_shape=shapes.input_shape,
+         grid_shape=shapes.grid_shape,
+         output_shape=shapes.output_shape,
+         spatial_rank=shapes.spatial_rank,
+         input_spatial=shapes.input_shape[2:],
+         output_spatial=shapes.output_shape[2:],
+         mode=mode,
+         padding_mode=padding_mode,
+         align_corners=bool(align_corners),
+         dtype=input_dtype,
+         grid_dtype=grid_dtype,
+     )
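
The shape contract `_validate_shapes` checks: for input (N, C, D1..Dr) and grid (N, S1..Sr, r), the output must be (N, C, S1..Sr), with matching batch dims and a grid whose last dimension equals the spatial rank r. A concrete 2D instance of that arithmetic:

# r = 2: a (2, 3, 8, 8) input sampled through a (2, 5, 6, 2) grid
# must produce a (2, 3, 5, 6) output.
input_shape = (2, 3, 8, 8)
grid_shape = (2, 5, 6, 2)
expected_output = (input_shape[0], input_shape[1], *grid_shape[1:-1])
assert expected_output == (2, 3, 5, 6)
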