emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (94)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +50 -23
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +30 -387
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/op_base.py +193 -0
  10. emx_onnx_cgen/ir/op_context.py +65 -0
  11. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  13. emx_onnx_cgen/ir/ops/misc.py +421 -0
  14. emx_onnx_cgen/ir/ops/nn.py +580 -0
  15. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  16. emx_onnx_cgen/lowering/__init__.py +79 -1
  17. emx_onnx_cgen/lowering/adagrad.py +114 -0
  18. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  19. emx_onnx_cgen/lowering/attention.py +1 -1
  20. emx_onnx_cgen/lowering/average_pool.py +1 -1
  21. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  22. emx_onnx_cgen/lowering/cast.py +1 -1
  23. emx_onnx_cgen/lowering/common.py +36 -18
  24. emx_onnx_cgen/lowering/concat.py +1 -1
  25. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  26. emx_onnx_cgen/lowering/conv.py +1 -1
  27. emx_onnx_cgen/lowering/conv_transpose.py +1 -1
  28. emx_onnx_cgen/lowering/cumsum.py +1 -1
  29. emx_onnx_cgen/lowering/depth_space.py +1 -1
  30. emx_onnx_cgen/lowering/dropout.py +1 -1
  31. emx_onnx_cgen/lowering/einsum.py +1 -1
  32. emx_onnx_cgen/lowering/elementwise.py +152 -4
  33. emx_onnx_cgen/lowering/expand.py +1 -1
  34. emx_onnx_cgen/lowering/eye_like.py +1 -1
  35. emx_onnx_cgen/lowering/flatten.py +1 -1
  36. emx_onnx_cgen/lowering/gather.py +1 -1
  37. emx_onnx_cgen/lowering/gather_elements.py +1 -1
  38. emx_onnx_cgen/lowering/gather_nd.py +1 -1
  39. emx_onnx_cgen/lowering/gemm.py +1 -1
  40. emx_onnx_cgen/lowering/global_max_pool.py +1 -1
  41. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  42. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  43. emx_onnx_cgen/lowering/hardmax.py +1 -1
  44. emx_onnx_cgen/lowering/identity.py +1 -1
  45. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  46. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/logsoftmax.py +1 -1
  48. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  49. emx_onnx_cgen/lowering/lp_pool.py +1 -1
  50. emx_onnx_cgen/lowering/lrn.py +1 -1
  51. emx_onnx_cgen/lowering/lstm.py +1 -1
  52. emx_onnx_cgen/lowering/matmul.py +1 -1
  53. emx_onnx_cgen/lowering/maxpool.py +1 -1
  54. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  55. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
  56. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  57. emx_onnx_cgen/lowering/nonzero.py +1 -1
  58. emx_onnx_cgen/lowering/one_hot.py +1 -1
  59. emx_onnx_cgen/lowering/pad.py +1 -1
  60. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  61. emx_onnx_cgen/lowering/quantize_linear.py +1 -1
  62. emx_onnx_cgen/lowering/range.py +1 -1
  63. emx_onnx_cgen/lowering/reduce.py +1 -1
  64. emx_onnx_cgen/lowering/registry.py +24 -5
  65. emx_onnx_cgen/lowering/reshape.py +1 -1
  66. emx_onnx_cgen/lowering/resize.py +1 -1
  67. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  68. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  69. emx_onnx_cgen/lowering/scatter_nd.py +1 -1
  70. emx_onnx_cgen/lowering/shape.py +6 -25
  71. emx_onnx_cgen/lowering/size.py +1 -1
  72. emx_onnx_cgen/lowering/slice.py +1 -1
  73. emx_onnx_cgen/lowering/softmax.py +1 -1
  74. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  75. emx_onnx_cgen/lowering/split.py +1 -1
  76. emx_onnx_cgen/lowering/squeeze.py +1 -1
  77. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  78. emx_onnx_cgen/lowering/tile.py +1 -1
  79. emx_onnx_cgen/lowering/topk.py +25 -7
  80. emx_onnx_cgen/lowering/transpose.py +1 -1
  81. emx_onnx_cgen/lowering/trilu.py +1 -1
  82. emx_onnx_cgen/lowering/unsqueeze.py +1 -1
  83. emx_onnx_cgen/lowering/variadic.py +1 -1
  84. emx_onnx_cgen/lowering/where.py +1 -1
  85. emx_onnx_cgen/runtime/evaluator.py +325 -1
  86. emx_onnx_cgen/verification.py +9 -39
  87. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +8 -7
  88. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  89. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  90. shared/scalar_functions.py +11 -0
  91. shared/ulp.py +17 -0
  92. emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
  93. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  94. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/reshape.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ReshapeOp
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .common import value_shape as resolved_value_shape

emx_onnx_cgen/lowering/resize.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ResizeOp
+from ..ir.ops import ResizeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .registry import register_lowering

emx_onnx_cgen/lowering/rms_normalization.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import RMSNormalizationOp
+from ..ir.ops import RMSNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
emx_onnx_cgen/lowering/rotary_embedding.py (new file)
@@ -0,0 +1,165 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import RotaryEmbeddingOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import optional_name, value_dtype, value_shape
+from .registry import register_lowering
+
+
+@dataclass(frozen=True)
+class RotaryEmbeddingSpec:
+    batch: int
+    seq_len: int
+    num_heads: int
+    head_size: int
+    rotary_dim: int
+    rotary_dim_half: int
+    input_rank: int
+
+
+def _resolve_rotary_spec(
+    graph: Graph, node: Node, dtype: ScalarType
+) -> RotaryEmbeddingSpec:
+    if not dtype.is_float:
+        raise UnsupportedOpError("Unsupported op RotaryEmbedding")
+    if len(node.inputs) < 3 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Unsupported op RotaryEmbedding")
+    input_shape = value_shape(graph, node.inputs[0], node)
+    input_rank = len(input_shape)
+    if input_rank not in {3, 4}:
+        raise ShapeInferenceError("RotaryEmbedding expects 3D or 4D input")
+    if input_rank == 3:
+        num_heads_attr = node.attrs.get("num_heads")
+        if num_heads_attr is None:
+            raise UnsupportedOpError(
+                "RotaryEmbedding num_heads attribute is required for 3D inputs"
+            )
+        num_heads = int(num_heads_attr)
+        if num_heads <= 0:
+            raise ShapeInferenceError("RotaryEmbedding num_heads must be > 0")
+        batch, seq_len, hidden_size = input_shape
+        if hidden_size % num_heads != 0:
+            raise ShapeInferenceError(
+                "RotaryEmbedding hidden size must be divisible by num_heads"
+            )
+        head_size = hidden_size // num_heads
+    else:
+        batch, num_heads, seq_len, head_size = input_shape
+        num_heads_attr = node.attrs.get("num_heads")
+        if num_heads_attr is not None and int(num_heads_attr) != num_heads:
+            raise ShapeInferenceError(
+                "RotaryEmbedding num_heads must match input head dimension"
+            )
+    if head_size % 2 != 0:
+        raise ShapeInferenceError("RotaryEmbedding head size must be even")
+    rotary_dim = int(node.attrs.get("rotary_embedding_dim", 0))
+    if rotary_dim == 0:
+        rotary_dim = head_size
+    if rotary_dim < 0 or rotary_dim > head_size:
+        raise ShapeInferenceError(
+            "RotaryEmbedding rotary_embedding_dim must be in [0, head_size]"
+        )
+    if rotary_dim % 2 != 0:
+        raise ShapeInferenceError(
+            "RotaryEmbedding rotary_embedding_dim must be even"
+        )
+    rotary_dim_half = rotary_dim // 2
+    return RotaryEmbeddingSpec(
+        batch=batch,
+        seq_len=seq_len,
+        num_heads=num_heads,
+        head_size=head_size,
+        rotary_dim=rotary_dim,
+        rotary_dim_half=rotary_dim_half,
+        input_rank=input_rank,
+    )
+
+
+@register_lowering("RotaryEmbedding")
+def lower_rotary_embedding(graph: Graph, node: Node) -> RotaryEmbeddingOp:
+    input_name = node.inputs[0]
+    cos_name = node.inputs[1]
+    sin_name = node.inputs[2]
+    position_ids = optional_name(node.inputs, 3)
+    dtype = value_dtype(graph, input_name, node)
+    cos_dtype = value_dtype(graph, cos_name, node)
+    sin_dtype = value_dtype(graph, sin_name, node)
+    if cos_dtype != dtype or sin_dtype != dtype:
+        raise ShapeInferenceError(
+            "RotaryEmbedding inputs must share the same dtype"
+        )
+    spec = _resolve_rotary_spec(graph, node, dtype)
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if output_shape != input_shape:
+        raise ShapeInferenceError(
+            "RotaryEmbedding output shape must match input shape"
+        )
+    cos_shape = value_shape(graph, cos_name, node)
+    sin_shape = value_shape(graph, sin_name, node)
+    if cos_shape != sin_shape:
+        raise ShapeInferenceError(
+            "RotaryEmbedding cos/sin cache shapes must match"
+        )
+    position_shape = None
+    position_dtype = None
+    if position_ids is not None:
+        position_shape = value_shape(graph, position_ids, node)
+        if position_shape != (spec.batch, spec.seq_len):
+            raise ShapeInferenceError(
+                "RotaryEmbedding position_ids must match [batch, seq_len]"
+            )
+        position_dtype = value_dtype(graph, position_ids, node)
+        if not position_dtype.is_integer:
+            raise ShapeInferenceError(
+                "RotaryEmbedding position_ids must be an integer tensor"
+            )
+        if len(cos_shape) != 2:
+            raise ShapeInferenceError(
+                "RotaryEmbedding expects 2D sin/cos caches with position_ids"
+            )
+        if cos_shape[1] != spec.rotary_dim_half:
+            raise ShapeInferenceError(
+                "RotaryEmbedding cos/sin cache last dim must match rotary_dim/2"
+            )
+    else:
+        if len(cos_shape) != 3:
+            raise ShapeInferenceError(
+                "RotaryEmbedding expects 3D sin/cos caches without position_ids"
+            )
+        if cos_shape != (
+            spec.batch,
+            spec.seq_len,
+            spec.rotary_dim_half,
+        ):
+            raise ShapeInferenceError(
+                "RotaryEmbedding sin/cos cache shape must be "
+                "[batch, seq_len, rotary_dim/2]"
+            )
+    interleaved = bool(int(node.attrs.get("interleaved", 0)))
+    return RotaryEmbeddingOp(
+        input0=input_name,
+        cos_cache=cos_name,
+        sin_cache=sin_name,
+        position_ids=position_ids,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        cos_shape=cos_shape,
+        sin_shape=sin_shape,
+        position_ids_shape=position_shape,
+        dtype=dtype,
+        position_ids_dtype=position_dtype,
+        rotary_dim=spec.rotary_dim,
+        rotary_dim_half=spec.rotary_dim_half,
+        head_size=spec.head_size,
+        num_heads=spec.num_heads,
+        seq_len=spec.seq_len,
+        batch=spec.batch,
+        input_rank=spec.input_rank,
+        interleaved=interleaved,
+    )
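
Note: the new lowering above only validates shapes and attributes and packs them into a RotaryEmbeddingOp; the rotation itself is generated later by the C emitter. For orientation, a rough NumPy sketch of the computation the op describes, based on the ONNX RotaryEmbedding definition rather than code from this package (the function and argument names below are illustrative):

import numpy as np

def rotary_reference(x, cos, sin, rotary_dim, interleaved=False):
    # x: [batch, num_heads, seq_len, head_size]; cos/sin: [batch, seq_len, rotary_dim // 2]
    # (the no-position_ids layout checked above). Dims beyond rotary_dim pass through.
    out = x.copy()
    half = rotary_dim // 2
    rot = x[..., :rotary_dim]
    if interleaved:
        x1, x2 = rot[..., 0:rotary_dim:2], rot[..., 1:rotary_dim:2]  # even/odd pairs
    else:
        x1, x2 = rot[..., :half], rot[..., half:]                    # first/second half
    c = cos[:, None, :, :]  # broadcast the caches over the head dimension
    s = sin[:, None, :, :]
    if interleaved:
        out[..., 0:rotary_dim:2] = x1 * c - x2 * s
        out[..., 1:rotary_dim:2] = x1 * s + x2 * c
    else:
        out[..., :half] = x1 * c - x2 * s
        out[..., half:rotary_dim] = x1 * s + x2 * c
    return out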
emx_onnx_cgen/lowering/scatter_nd.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ScatterNDOp
+from ..ir.ops import ScatterNDOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype, value_shape

emx_onnx_cgen/lowering/shape.py
@@ -2,32 +2,13 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ShapeOp
+from ..ir.ops import ShapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
+from .common import value_dtype, value_shape
 from .registry import register_lowering
 
 
-def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
-    try:
-        return graph.find_value(name).type.shape
-    except KeyError as exc:
-        raise ShapeInferenceError(
-            f"Missing shape for value '{name}' in op {node.op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
-
-
-def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
-    try:
-        return graph.find_value(name).type.dtype
-    except KeyError as exc:
-        raise ShapeInferenceError(
-            f"Missing dtype for value '{name}' in op {node.op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
-
-
 def _normalize_slice_bounds(
     rank: int, *, start: int | None, end: int | None
 ) -> tuple[int, int]:
@@ -46,14 +27,14 @@ def _normalize_slice_bounds(
 def lower_shape(graph: Graph, node: Node) -> ShapeOp:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
         raise UnsupportedOpError("Shape must have 1 input and 1 output")
-    input_shape = _value_shape(graph, node.inputs[0], node)
-    output_shape = _value_shape(graph, node.outputs[0], node)
+    input_shape = value_shape(graph, node.inputs[0], node)
+    output_shape = value_shape(graph, node.outputs[0], node)
     if len(output_shape) != 1:
         raise ShapeInferenceError("Shape output must be 1D")
     if output_shape[0] < 0:
         raise ShapeInferenceError("Shape output length must be non-negative")
-    input_dtype = _value_dtype(graph, node.inputs[0], node)
-    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    input_dtype = value_dtype(graph, node.inputs[0], node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
     if output_dtype != ScalarType.I64:
         raise UnsupportedOpError("Shape output dtype must be int64")
     start = node.attrs.get("start")

emx_onnx_cgen/lowering/size.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import SizeOp
+from ..ir.ops import SizeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import shape_product, value_dtype, value_shape

emx_onnx_cgen/lowering/slice.py
@@ -6,7 +6,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import SliceOp
+from ..ir.ops import SliceOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape

emx_onnx_cgen/lowering/softmax.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import SoftmaxOp
+from ..ir.ops import SoftmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype

emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import SoftmaxCrossEntropyLossOp
+from ..ir.ops import SoftmaxCrossEntropyLossOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import shape_product as _shape_product

emx_onnx_cgen/lowering/split.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import SplitOp
+from ..ir.ops import SplitOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import optional_name, value_dtype, value_shape

emx_onnx_cgen/lowering/squeeze.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ReshapeOp
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/tensor_scatter.py (new file)
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import TensorScatterOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..validation import normalize_axis
+from .common import optional_name, value_dtype, value_shape
+from .registry import register_lowering
+
+_ALLOWED_MODES = {"linear", "circular"}
+
+
+@register_lowering("TensorScatter")
+def lower_tensor_scatter(graph: Graph, node: Node) -> TensorScatterOp:
+    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "TensorScatter must have 2 or 3 inputs and 1 output"
+        )
+    past_cache_name = node.inputs[0]
+    update_name = node.inputs[1]
+    write_indices_name = optional_name(node.inputs, 2)
+    output_name = node.outputs[0]
+    past_cache_shape = value_shape(graph, past_cache_name, node)
+    update_shape = value_shape(graph, update_name, node)
+    output_shape = value_shape(graph, output_name, node)
+    if output_shape != past_cache_shape:
+        raise ShapeInferenceError(
+            "TensorScatter output shape must match past_cache shape, "
+            f"got {output_shape} vs {past_cache_shape}"
+        )
+    if len(update_shape) != len(past_cache_shape):
+        raise ShapeInferenceError(
+            "TensorScatter update shape rank must match past_cache rank, "
+            f"got {len(update_shape)} vs {len(past_cache_shape)}"
+        )
+    axis = normalize_axis(int(node.attrs.get("axis", -2)), past_cache_shape, node)
+    if axis == 0:
+        raise UnsupportedOpError(
+            "TensorScatter axis cannot be 0 (batch dimension)"
+        )
+    for dim_index, (past_dim, update_dim) in enumerate(
+        zip(past_cache_shape, update_shape)
+    ):
+        if dim_index == axis:
+            if update_dim > past_dim:
+                raise ShapeInferenceError(
+                    "TensorScatter update sequence length must be <= "
+                    "past_cache sequence length, "
+                    f"got {update_dim} vs {past_dim}"
+                )
+        elif update_dim != past_dim:
+            raise ShapeInferenceError(
+                "TensorScatter update shape must match past_cache shape "
+                f"outside axis {axis}, got {update_shape} vs {past_cache_shape}"
+            )
+    mode = node.attrs.get("mode", "linear")
+    if isinstance(mode, bytes):
+        mode = mode.decode("utf-8")
+    if mode not in _ALLOWED_MODES:
+        raise UnsupportedOpError(
+            "TensorScatter mode must be one of "
+            f"{sorted(_ALLOWED_MODES)}, got {mode}"
+        )
+    dtype = value_dtype(graph, past_cache_name, node)
+    update_dtype = value_dtype(graph, update_name, node)
+    output_dtype = value_dtype(graph, output_name, node)
+    if update_dtype != dtype or output_dtype != dtype:
+        raise UnsupportedOpError(
+            "TensorScatter expects past_cache, update, and output "
+            "to share the same dtype, "
+            f"got {dtype.onnx_name}, {update_dtype.onnx_name}, "
+            f"{output_dtype.onnx_name}"
+        )
+    write_indices_shape = None
+    write_indices_dtype = None
+    if write_indices_name is not None:
+        write_indices_shape = value_shape(graph, write_indices_name, node)
+        if len(write_indices_shape) != 1:
+            raise ShapeInferenceError(
+                "TensorScatter write_indices must be a 1D tensor"
+            )
+        if write_indices_shape[0] != past_cache_shape[0]:
+            raise ShapeInferenceError(
+                "TensorScatter write_indices length must match batch size, "
+                f"got {write_indices_shape[0]} vs {past_cache_shape[0]}"
+            )
+        write_indices_dtype = value_dtype(
+            graph, write_indices_name, node
+        )
+        if write_indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+            raise UnsupportedOpError(
+                "TensorScatter write_indices must be int32 or int64, "
+                f"got {write_indices_dtype.onnx_name}"
+            )
+    return TensorScatterOp(
+        past_cache=past_cache_name,
+        update=update_name,
+        write_indices=write_indices_name,
+        output=output_name,
+        past_cache_shape=past_cache_shape,
+        update_shape=update_shape,
+        output_shape=output_shape,
+        write_indices_shape=write_indices_shape,
+        axis=axis,
+        mode=mode,
+        dtype=dtype,
+        write_indices_dtype=write_indices_dtype,
+    )
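
Note: as with RotaryEmbedding, this lowering is validation plus op construction; the scatter itself is emitted later. A minimal NumPy sketch of the semantics being targeted, assuming the ONNX TensorScatter behavior of writing an update into a per-batch cache starting at write_indices (illustrative names, not this package's API):

import numpy as np

def tensor_scatter_reference(past_cache, update, write_indices=None, axis=-2, mode="linear"):
    # past_cache and update share a shape except along `axis`; write_indices is a
    # per-batch start offset (defaults to 0); "circular" wraps around the cache.
    out = past_cache.copy()
    axis = axis % past_cache.ndim
    batch, max_len = past_cache.shape[0], past_cache.shape[axis]
    upd_len = update.shape[axis]
    for b in range(batch):
        start = 0 if write_indices is None else int(write_indices[b])
        for j in range(upd_len):
            pos = (start + j) % max_len if mode == "circular" else start + j
            dst = [slice(None)] * past_cache.ndim
            src = [slice(None)] * past_cache.ndim
            dst[0] = src[0] = b
            dst[axis], src[axis] = pos, j
            out[tuple(dst)] = update[tuple(src)]
    return out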
emx_onnx_cgen/lowering/tile.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import TileOp
+from ..ir.ops import TileOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape

emx_onnx_cgen/lowering/topk.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import TopKOp
+from ..ir.ops import TopKOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import shape_product, value_dtype, value_shape
@@ -19,12 +19,10 @@ def _find_initializer(graph: Graph, name: str) -> Initializer | None:
     return None
 
 
-def _read_k(graph: Graph, name: str, node: Node) -> int:
+def _read_k(graph: Graph, name: str, node: Node) -> int | None:
     initializer = _find_initializer(graph, name)
     if initializer is None:
-        raise UnsupportedOpError(
-            f"{node.op_type} k input must be a constant initializer"
-        )
+        return None
     if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
         raise UnsupportedOpError(
             f"{node.op_type} k input must be int64 or int32"
@@ -63,6 +61,28 @@ def lower_topk(graph: Graph, node: Node) -> TopKOp:
     axis = normalize_axis(axis, input_shape, node)
     k = _read_k(graph, k_name, node)
     axis_dim = input_shape[axis]
+    values_shape = value_shape(graph, output_values, node)
+    indices_shape = value_shape(graph, output_indices, node)
+    if values_shape != indices_shape:
+        raise ShapeInferenceError(
+            f"{node.op_type} values and indices output shapes must match, "
+            f"got {values_shape} and {indices_shape}"
+        )
+    if k is None:
+        k_shape = value_shape(graph, k_name, node)
+        if len(k_shape) != 1 or k_shape[0] != 1:
+            raise ShapeInferenceError(
+                f"{node.op_type} k input must be a 1-element tensor"
+            )
+        if axis >= len(values_shape):
+            raise ShapeInferenceError(
+                f"{node.op_type} axis {axis} exceeds output rank {len(values_shape)}"
+            )
+        k = values_shape[axis]
+    if k <= 0:
+        raise ShapeInferenceError(
+            f"{node.op_type} k must be a positive value, got {k}"
+        )
     if k > axis_dim:
         raise ShapeInferenceError(
             f"{node.op_type} k {k} exceeds axis dimension {axis_dim}"
@@ -70,12 +90,10 @@ def lower_topk(graph: Graph, node: Node) -> TopKOp:
     output_shape_expected = list(input_shape)
     output_shape_expected[axis] = k
     output_shape = tuple(output_shape_expected)
-    values_shape = value_shape(graph, output_values, node)
     if values_shape != output_shape:
         raise ShapeInferenceError(
            f"{node.op_type} values output shape must be {output_shape}, got {values_shape}"
         )
-    indices_shape = value_shape(graph, output_indices, node)
     if indices_shape != output_shape:
         raise ShapeInferenceError(
             f"{node.op_type} indices output shape must be {output_shape}, got {indices_shape}"
emx_onnx_cgen/lowering/transpose.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import TransposeOp
+from ..ir.ops import TransposeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype

emx_onnx_cgen/lowering/trilu.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import TriluOp
+from ..ir.ops import TriluOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import optional_name, value_dtype, value_shape

emx_onnx_cgen/lowering/unsqueeze.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ReshapeOp
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .registry import register_lowering

emx_onnx_cgen/lowering/variadic.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 from shared.scalar_functions import ScalarFunction
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import MultiInputBinaryOp
+from ..ir.ops import MultiInputBinaryOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..lowering.common import node_dtype, value_dtype, value_shape

emx_onnx_cgen/lowering/where.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import WhereOp
+from ..ir.ops import WhereOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype as _value_dtype