emx-onnx-cgen 0.2.0 (emx_onnx_cgen-0.2.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of emx-onnx-cgen has been flagged as potentially problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import SoftmaxCrossEntropyLossOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import shape_product as _shape_product
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@register_lowering("SoftmaxCrossEntropyLoss")
+def lower_softmax_cross_entropy_loss(
+    graph: Graph, node: Node
+) -> SoftmaxCrossEntropyLossOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) not in {1, 2}:
+        raise UnsupportedOpError(
+            "SoftmaxCrossEntropyLoss must have 2 or 3 inputs and 1 or 2 outputs"
+        )
+    input_name = node.inputs[0]
+    target_name = node.inputs[1]
+    weight_name = node.inputs[2] if len(node.inputs) > 2 else None
+    input_dtype = _value_dtype(graph, input_name, node)
+    if not input_dtype.is_float:
+        raise UnsupportedOpError(
+            "SoftmaxCrossEntropyLoss supports float16, float, and double inputs only"
+        )
+    output_name = node.outputs[0]
+    output_dtype = _value_dtype(graph, output_name, node)
+    if output_dtype != input_dtype:
+        raise UnsupportedOpError(
+            "SoftmaxCrossEntropyLoss output dtype must match input dtype"
+        )
+    log_prob_name = node.outputs[1] if len(node.outputs) > 1 else None
+    if log_prob_name is not None:
+        log_prob_dtype = _value_dtype(graph, log_prob_name, node)
+        if log_prob_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "SoftmaxCrossEntropyLoss log_prob dtype must match input dtype"
+            )
+    target_dtype = _value_dtype(graph, target_name, node)
+    if target_dtype not in {ScalarType.I32, ScalarType.I64}:
+        raise UnsupportedOpError(
+            "SoftmaxCrossEntropyLoss target must be int32 or int64"
+        )
+    weight_dtype = None
+    weight_shape: tuple[int, ...] | None = None
+    if weight_name is not None:
+        weight_dtype = _value_dtype(graph, weight_name, node)
+        if weight_dtype != input_dtype:
+            raise UnsupportedOpError(
+                "SoftmaxCrossEntropyLoss weight dtype must match input dtype"
+            )
+    input_shape = _value_shape(graph, input_name, node)
+    target_shape = _value_shape(graph, target_name, node)
+    output_shape = _value_shape(graph, output_name, node)
+    if len(input_shape) < 2:
+        raise ShapeInferenceError("SoftmaxCrossEntropyLoss input must be at least 2D")
+    if len(target_shape) != len(input_shape) - 1:
+        raise ShapeInferenceError(
+            "SoftmaxCrossEntropyLoss target rank must be input rank - 1"
+        )
+    if input_shape[0] != target_shape[0]:
+        raise ShapeInferenceError(
+            "SoftmaxCrossEntropyLoss target batch dimension must match input"
+        )
+    if input_shape[2:] != target_shape[1:]:
+        raise ShapeInferenceError(
+            "SoftmaxCrossEntropyLoss target spatial dimensions must match input"
+        )
+    if weight_name is not None:
+        weight_shape = _value_shape(graph, weight_name, node)
+        if len(weight_shape) != 1 or weight_shape[0] != input_shape[1]:
+            raise ShapeInferenceError(
+                "SoftmaxCrossEntropyLoss weight must have shape (C,)"
+            )
+    if log_prob_name is not None:
+        log_prob_shape = _value_shape(graph, log_prob_name, node)
+        if log_prob_shape != input_shape:
+            raise ShapeInferenceError(
+                "SoftmaxCrossEntropyLoss log_prob output must match input shape"
+            )
+    reduction = node.attrs.get("reduction", "mean")
+    if isinstance(reduction, bytes):
+        reduction = reduction.decode("utf-8")
+    if reduction not in {"none", "mean", "sum"}:
+        raise UnsupportedOpError(
+            "SoftmaxCrossEntropyLoss reduction must be none, mean, or sum"
+        )
+    if reduction == "none":
+        if output_shape != target_shape:
+            raise ShapeInferenceError(
+                "SoftmaxCrossEntropyLoss output must match target shape "
+                "when reduction is none"
+            )
+    else:
+        if output_shape not in {(), (1,)}:
+            raise ShapeInferenceError(
+                "SoftmaxCrossEntropyLoss output must be scalar when reduced"
+            )
+    n = input_shape[0]
+    c = input_shape[1]
+    d = _shape_product(input_shape[2:]) if len(input_shape) > 2 else 1
+    ignore_index = node.attrs.get("ignore_index")
+    if ignore_index is not None:
+        ignore_index = int(ignore_index)
+    return SoftmaxCrossEntropyLossOp(
+        input0=input_name,
+        target=target_name,
+        weight=weight_name,
+        output=output_name,
+        log_prob=log_prob_name,
+        input_shape=input_shape,
+        target_shape=target_shape,
+        output_shape=output_shape,
+        log_prob_shape=input_shape if log_prob_name is not None else None,
+        n=n,
+        c=c,
+        d=d,
+        reduction=reduction,
+        ignore_index=ignore_index,
+        input_dtype=input_dtype,
+        weight_dtype=weight_dtype,
+        weight_shape=weight_shape,
+        dtype=input_dtype,
+        target_dtype=target_dtype,
+    )
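A minimal sketch (not part of the package) of the shape bookkeeping the lowering above performs: an input of shape (N, C, d1, ..., dk) is reduced to the n, c, d fields, and the expected loss output shape depends on the reduction mode. The shapes used here are made-up example values.

from math import prod

input_shape = (2, 5, 4, 3)                          # (N, C, d1, d2)
target_shape = input_shape[:1] + input_shape[2:]    # (N, d1, d2) = (2, 4, 3)

n, c = input_shape[0], input_shape[1]
d = prod(input_shape[2:]) if len(input_shape) > 2 else 1

# Expected loss output shape per reduction mode, mirroring the checks above.
expected_output = {
    "none": target_shape,  # per-sample loss
    "mean": (),            # scalar (the lowering also accepts (1,))
    "sum": (),
}

print(n, c, d)                   # 2 5 12
print(expected_output["none"])   # (2, 4, 3)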
emx_onnx_cgen/lowering/split.py
@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import SplitOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import optional_name, value_dtype, value_shape
+from ..validation import normalize_axis
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_split_sizes(graph: Graph, name: str, node: Node) -> list[int] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"{node.op_type} split input must be int64 or int32"
+        )
+    if len(initializer.type.shape) != 1:
+        raise UnsupportedOpError(
+            f"{node.op_type} split input must be a 1D tensor"
+        )
+    values = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    if values.size == 0:
+        raise ShapeInferenceError(
+            f"{node.op_type} split input cannot be empty"
+        )
+    return [int(value) for value in values]
+
+
+def _validate_static_dims(shape: tuple[int, ...], node: Node) -> None:
+    if any(dim < 0 for dim in shape):
+        raise ShapeInferenceError(
+            f"{node.op_type} does not support dynamic dims"
+        )
+
+
+def _normalize_num_outputs(node: Node, output_count: int) -> int:
+    num_outputs_attr = node.attrs.get("num_outputs")
+    if num_outputs_attr is None:
+        return output_count
+    num_outputs = int(num_outputs_attr)
+    if num_outputs <= 0:
+        raise UnsupportedOpError("Split num_outputs must be positive")
+    if output_count != num_outputs:
+        raise ShapeInferenceError(
+            f"Split expects {num_outputs} outputs, got {output_count}"
+        )
+    return num_outputs
+
+
+@register_lowering("Split")
+def lower_split(graph: Graph, node: Node) -> SplitOp:
+    if len(node.inputs) < 1 or len(node.outputs) < 1:
+        raise UnsupportedOpError("Split must have at least 1 input and 1 output")
+    if len(node.inputs) > 2:
+        raise UnsupportedOpError("Split supports up to 2 inputs")
+    input_name = node.inputs[0]
+    if not input_name:
+        raise UnsupportedOpError("Split input must be provided")
+    input_shape = value_shape(graph, input_name, node)
+    _validate_static_dims(input_shape, node)
+    axis = normalize_axis(int(node.attrs.get("axis", 0)), input_shape, node)
+    output_shapes = [
+        value_shape(graph, output, node) for output in node.outputs
+    ]
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtypes = {value_dtype(graph, output, node) for output in node.outputs}
+    if output_dtypes != {input_dtype}:
+        dtype_names = ", ".join(
+            dtype.onnx_name for dtype in sorted(output_dtypes, key=str)
+        )
+        raise UnsupportedOpError(
+            f"Split expects matching dtypes, got {dtype_names}"
+        )
+    split_name = optional_name(node.inputs, 1)
+    if split_name is not None and "num_outputs" in node.attrs:
+        raise UnsupportedOpError(
+            "Split cannot specify both split input and num_outputs"
+        )
+    if split_name is not None:
+        split_sizes = _read_split_sizes(graph, split_name, node)
+        if split_sizes is None:
+            split_shape, split_dtype = value_shape(
+                graph, split_name, node
+            ), value_dtype(graph, split_name, node)
+            if split_dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    f"{node.op_type} split input must be int64 or int32"
+                )
+            if len(split_shape) != 1:
+                raise UnsupportedOpError(
+                    f"{node.op_type} split input must be a 1D tensor"
+                )
+            if split_shape[0] != len(node.outputs):
+                raise ShapeInferenceError(
+                    f"Split expects {len(node.outputs)} outputs, got {split_shape[0]}"
+                )
+            split_sizes = [shape[axis] for shape in output_shapes]
+        if len(split_sizes) != len(node.outputs):
+            raise ShapeInferenceError(
+                f"Split expects {len(split_sizes)} outputs, got {len(node.outputs)}"
+            )
+        if any(size < 0 for size in split_sizes):
+            raise ShapeInferenceError("Split sizes must be non-negative")
+        if sum(split_sizes) != input_shape[axis]:
+            raise ShapeInferenceError(
+                "Split sizes must sum to the axis dimension"
+            )
+    else:
+        num_outputs = _normalize_num_outputs(node, len(node.outputs))
+        axis_dim = input_shape[axis]
+        base = axis_dim // num_outputs
+        remainder = axis_dim % num_outputs
+        split_sizes = [base + 1] * remainder + [base] * (
+            num_outputs - remainder
+        )
+    computed_shapes: list[tuple[int, ...]] = []
+    for size, output_shape in zip(split_sizes, output_shapes):
+        if size < 0:
+            raise ShapeInferenceError("Split output size must be non-negative")
+        shape = list(input_shape)
+        shape[axis] = size
+        computed_shape = tuple(shape)
+        if output_shape != computed_shape:
+            raise ShapeInferenceError(
+                f"Split output shape must be {computed_shape}, got {output_shape}"
+            )
+        computed_shapes.append(computed_shape)
+    return SplitOp(
+        input0=input_name,
+        outputs=tuple(node.outputs),
+        input_shape=input_shape,
+        output_shapes=tuple(computed_shapes),
+        axis=axis,
+        split_sizes=tuple(split_sizes),
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+    )
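A minimal sketch (not part of the package) of the default split-size rule used above when no split input is given: the axis dimension is divided evenly across num_outputs and the remainder is spread over the leading chunks. Example values only.

input_shape = (2, 7, 3)
axis, num_outputs = 1, 3

axis_dim = input_shape[axis]
base, remainder = divmod(axis_dim, num_outputs)
split_sizes = [base + 1] * remainder + [base] * (num_outputs - remainder)

# Per-output shapes: identical to the input except along the split axis.
output_shapes = [
    input_shape[:axis] + (size,) + input_shape[axis + 1:] for size in split_sizes
]

print(split_sizes)    # [3, 2, 2]
print(output_shapes)  # [(2, 3, 3), (2, 2, 3), (2, 2, 3)]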
emx_onnx_cgen/lowering/squeeze.py
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import ReshapeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from .registry import register_lowering
+
+
+def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
+    try:
+        return graph.find_value(name).type.shape
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing shape for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
+    try:
+        return graph.find_value(name).type.dtype
+    except KeyError as exc:
+        raise ShapeInferenceError(
+            f"Missing dtype for value '{name}' in op {node.op_type}. "
+            "Hint: run ONNX shape inference or export with static shapes."
+        ) from exc
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _validate_shape(shape: tuple[int, ...], node: Node, label: str) -> None:
+    for dim in shape:
+        if dim < 0:
+            raise ShapeInferenceError(
+                f"{node.op_type} does not support dynamic dims in {label}"
+            )
+
+
+def _normalize_axes(
+    axes: list[int], input_rank: int, node: Node
+) -> tuple[int, ...]:
+    normalized: list[int] = []
+    for axis in axes:
+        if axis < 0:
+            axis += input_rank
+        if axis < 0 or axis >= input_rank:
+            raise ShapeInferenceError(
+                f"{node.op_type} axis {axis} is out of range for rank {input_rank}"
+            )
+        normalized.append(axis)
+    if len(set(normalized)) != len(normalized):
+        raise ShapeInferenceError(f"{node.op_type} axes must be unique")
+    return tuple(sorted(normalized))
+
+
+def _resolve_axes(graph: Graph, node: Node) -> tuple[int, ...] | None:
+    axes_attr = node.attrs.get("axes")
+    axes_values: list[int] | None = None
+    if len(node.inputs) == 2:
+        axes_initializer = _find_initializer(graph, node.inputs[1])
+        if axes_initializer is not None:
+            if axes_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    "Squeeze axes input must be int64 or int32, "
+                    f"got {axes_initializer.type.dtype.onnx_name}"
+                )
+            axes_values = [int(value) for value in axes_initializer.data.reshape(-1)]
+    elif axes_attr is not None:
+        axes_values = [int(value) for value in axes_attr]
+    if axes_values is None:
+        return None
+    return tuple(axes_values)
+
+
+def _expected_output_shape(
+    input_shape: tuple[int, ...], axes: tuple[int, ...]
+) -> tuple[int, ...]:
+    axis_set = set(axes)
+    return tuple(
+        dim for index, dim in enumerate(input_shape) if index not in axis_set
+    )
+
+
+def _validate_output_shape_for_unknown_axes(
+    input_shape: tuple[int, ...], output_shape: tuple[int, ...], node: Node
+) -> None:
+    output_index = 0
+    for dim in input_shape:
+        if output_index < len(output_shape) and dim == output_shape[output_index]:
+            output_index += 1
+            continue
+        if dim != 1:
+            raise ShapeInferenceError(
+                "Squeeze output shape must remove only dimensions of size 1"
+            )
+    if output_index != len(output_shape):
+        raise ShapeInferenceError(
+            "Squeeze output shape must preserve input order while removing size-1 axes"
+        )
+
+
+@register_lowering("Squeeze")
+def lower_squeeze(graph: Graph, node: Node) -> ReshapeOp:
+    if len(node.outputs) != 1 or len(node.inputs) not in {1, 2}:
+        raise UnsupportedOpError("Squeeze must have 1 or 2 inputs and 1 output")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    _validate_shape(input_shape, node, "input")
+    _validate_shape(output_shape, node, "output")
+    input_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Squeeze expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    axes = _resolve_axes(graph, node)
+    if axes is None:
+        if len(node.inputs) == 2:
+            axes_dtype = _value_dtype(graph, node.inputs[1], node)
+            if axes_dtype not in {ScalarType.I64, ScalarType.I32}:
+                raise UnsupportedOpError(
+                    "Squeeze axes input must be int64 or int32, "
+                    f"got {axes_dtype.onnx_name}"
+                )
+            _validate_output_shape_for_unknown_axes(input_shape, output_shape, node)
+        else:
+            expected_shape = tuple(dim for dim in input_shape if dim != 1)
+            if expected_shape != output_shape:
+                raise ShapeInferenceError(
+                    "Squeeze output shape must be "
+                    f"{expected_shape}, got {output_shape}"
+                )
+    else:
+        normalized_axes = _normalize_axes(list(axes), len(input_shape), node)
+        for axis in normalized_axes:
+            if input_shape[axis] != 1:
+                raise ShapeInferenceError(
+                    "Squeeze axes must target dimensions of size 1"
+                )
+        expected_shape = _expected_output_shape(input_shape, normalized_axes)
+        if expected_shape != output_shape:
+            raise ShapeInferenceError(
+                "Squeeze output shape must be "
+                f"{expected_shape}, got {output_shape}"
+            )
+    return ReshapeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype=input_dtype,
+        input_dtype=input_dtype,
+    )
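A minimal sketch (not part of the package) of the two Squeeze shape rules validated above, with made-up example shapes: when no axes are given, every size-1 dimension is dropped; with explicit axes, only those axes are dropped and each must have size 1.

input_shape = (1, 3, 1, 5)

# No axes provided: squeeze all size-1 dimensions.
no_axes_shape = tuple(dim for dim in input_shape if dim != 1)          # (3, 5)

# Explicit axes (already normalized to non-negative): squeeze only those.
axes = (0,)
axis_set = set(axes)
with_axes_shape = tuple(
    dim for index, dim in enumerate(input_shape) if index not in axis_set
)                                                                      # (3, 1, 5)

print(no_axes_shape, with_axes_shape)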
emx_onnx_cgen/lowering/tile.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import TileOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_repeats(graph: Graph, name: str, node: Node) -> tuple[int, ...] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError("Tile repeats input must be int64 or int32")
+    if len(initializer.type.shape) != 1:
+        raise UnsupportedOpError("Tile repeats input must be a 1D tensor")
+    values = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    return tuple(int(value) for value in values)
+
+
+def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
+    strides: list[int] = []
+    stride = 1
+    for dim in reversed(shape):
+        strides.append(stride)
+        stride *= dim
+    return tuple(reversed(strides))
+
+
+@register_lowering("Tile")
+def lower_tile(graph: Graph, node: Node) -> TileOp:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Tile must have 2 inputs and 1 output")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Tile expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    repeats = _read_repeats(graph, node.inputs[1], node)
+    if repeats is None:
+        raise UnsupportedOpError("Tile repeats input must be a constant initializer")
+    if len(repeats) != len(input_shape):
+        raise ShapeInferenceError(
+            "Tile repeats must have the same rank as input shape"
+        )
+    if any(value < 0 for value in repeats):
+        raise UnsupportedOpError("Tile repeats must be non-negative")
+    expected_shape = tuple(
+        int(dim) * int(repeat) for dim, repeat in zip(input_shape, repeats)
+    )
+    if output_shape and output_shape != expected_shape:
+        raise ShapeInferenceError(
+            "Tile output shape mismatch: "
+            f"expected {expected_shape}, got {output_shape}"
+        )
+    return TileOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=expected_shape,
+        repeats=repeats,
+        input_strides=_compute_strides(input_shape),
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
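A minimal sketch (not part of the package) of the Tile shape and stride arithmetic mirrored from the lowering above; the shape and repeats are made-up example values.

input_shape = (2, 3)
repeats = (2, 4)

# Output dims are input dims multiplied element-wise by the repeats.
expected_shape = tuple(dim * rep for dim, rep in zip(input_shape, repeats))  # (4, 12)

# Row-major strides of the input, as _compute_strides builds them.
strides, stride = [], 1
for dim in reversed(input_shape):
    strides.append(stride)
    stride *= dim
input_strides = tuple(reversed(strides))                                     # (3, 1)

print(expected_shape, input_strides)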
emx_onnx_cgen/lowering/transpose.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from ..codegen.c_emitter import TransposeOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@register_lowering("Transpose")
+def lower_transpose(graph: Graph, node: Node) -> TransposeOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Transpose must have 1 input and 1 output")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    perm = node.attrs.get("perm")
+    if perm is None:
+        perm = tuple(reversed(range(len(input_shape))))
+    else:
+        perm = tuple(int(axis) for axis in perm)
+    if len(perm) != len(input_shape):
+        raise ShapeInferenceError(
+            "Transpose perm must match input rank, "
+            f"got perm {perm} for shape {input_shape}"
+        )
+    if set(perm) != set(range(len(input_shape))):
+        raise UnsupportedOpError(
+            f"Transpose perm must be a permutation, got {perm}"
+        )
+    expected_shape = tuple(input_shape[axis] for axis in perm)
+    if output_shape != expected_shape:
+        raise ShapeInferenceError(
+            "Transpose output shape must match permuted input shape, "
+            f"expected {expected_shape}, got {output_shape}"
+        )
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    return TransposeOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        perm=perm,
+        input_shape=input_shape,
+        output_shape=output_shape,
+        dtype=op_dtype,
+        input_dtype=op_dtype,
+    )
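A minimal sketch (not part of the package) of the default permutation and the permuted output shape that the lowering above validates; the input shape is a made-up example.

input_shape = (2, 3, 4)

# When the perm attribute is absent, the axes are simply reversed.
perm = tuple(reversed(range(len(input_shape))))              # (2, 1, 0)
expected_shape = tuple(input_shape[axis] for axis in perm)   # (4, 3, 2)

print(perm, expected_shape)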