emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/matmul.py
@@ -1,119 +1,17 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
-
 from ..ir.ops import MatMulOp
-from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
-from .common import node_dtype as _node_dtype
-from .common import value_shape as _value_shape
 from .registry import register_lowering
 
 
-@dataclass(frozen=True)
-class MatMulSpec:
-    input0_shape: tuple[int, ...]
-    input1_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    batch_shape: tuple[int, ...]
-    input0_batch_shape: tuple[int, ...]
-    input1_batch_shape: tuple[int, ...]
-    m: int
-    n: int
-    k: int
-    left_vector: bool
-    right_vector: bool
-
-
-def resolve_matmul_spec(graph: Graph, node: Node) -> MatMulSpec:
-    if len(node.inputs) != 2 or len(node.outputs) != 1:
-        raise UnsupportedOpError("MatMul must have 2 inputs and 1 output")
-    input0_shape = _value_shape(graph, node.inputs[0], node)
-    input1_shape = _value_shape(graph, node.inputs[1], node)
-    if len(input0_shape) < 1 or len(input1_shape) < 1:
-        raise UnsupportedOpError(
-            "MatMul inputs must be at least 1D, "
-            f"got {input0_shape} x {input1_shape}"
-        )
-    left_vector = len(input0_shape) == 1
-    right_vector = len(input1_shape) == 1
-    input0_effective = (1, input0_shape[0]) if left_vector else input0_shape
-    input1_effective = (input1_shape[0], 1) if right_vector else input1_shape
-    m, k_left = input0_effective[-2], input0_effective[-1]
-    k_right, n = input1_effective[-2], input1_effective[-1]
-    if k_left != k_right:
-        raise ShapeInferenceError(
-            f"MatMul inner dimensions must match, got {k_left} and {k_right}"
-        )
-    batch_shape, input0_batch_shape, input1_batch_shape = (
-        _broadcast_batch_shapes(
-            input0_effective[:-2], input1_effective[:-2], node
-        )
-    )
-    if left_vector and right_vector:
-        output_shape = batch_shape
-    elif left_vector:
-        output_shape = batch_shape + (n,)
-    elif right_vector:
-        output_shape = batch_shape + (m,)
-    else:
-        output_shape = batch_shape + (m, n)
-    expected_output_shape = _value_shape(graph, node.outputs[0], node)
-    if expected_output_shape != output_shape:
-        raise ShapeInferenceError(
-            "MatMul output shape must be "
-            f"{output_shape}, got {expected_output_shape}"
-        )
-    return MatMulSpec(
-        input0_shape=input0_shape,
-        input1_shape=input1_shape,
-        output_shape=output_shape,
-        batch_shape=batch_shape,
-        input0_batch_shape=input0_batch_shape,
-        input1_batch_shape=input1_batch_shape,
-        m=m,
-        n=n,
-        k=k_left,
-        left_vector=left_vector,
-        right_vector=right_vector,
-    )
-
-
-def _broadcast_batch_shapes(
-    left: tuple[int, ...], right: tuple[int, ...], node: Node
-) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
-    max_rank = max(len(left), len(right))
-    left_padded = (1,) * (max_rank - len(left)) + left
-    right_padded = (1,) * (max_rank - len(right)) + right
-    broadcast_shape = []
-    for left_dim, right_dim in zip(left_padded, right_padded):
-        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
-            raise ShapeInferenceError(
-                "MatMul batch dimensions must be broadcastable, "
-                f"got {left} x {right}"
-            )
-        broadcast_shape.append(max(left_dim, right_dim))
-    return tuple(broadcast_shape), left_padded, right_padded
-
-
 @register_lowering("MatMul")
 def lower_matmul(graph: Graph, node: Node) -> MatMulOp:
-    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
-    spec = resolve_matmul_spec(graph, node)
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("MatMul must have 2 inputs and 1 output")
     return MatMulOp(
         input0=node.inputs[0],
         input1=node.inputs[1],
         output=node.outputs[0],
-        input0_shape=spec.input0_shape,
-        input1_shape=spec.input1_shape,
-        output_shape=spec.output_shape,
-        batch_shape=spec.batch_shape,
-        input0_batch_shape=spec.input0_batch_shape,
-        input1_batch_shape=spec.input1_batch_shape,
-        m=spec.m,
-        n=spec.n,
-        k=spec.k,
-        left_vector=spec.left_vector,
-        right_vector=spec.right_vector,
-        dtype=op_dtype,
     )
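With resolve_matmul_spec and _broadcast_batch_shapes gone, the lowering no longer derives M/N/K, the 1-D vector promotion, or the batch broadcast; MatMulOp presumably recovers that information itself. For reference while reviewing, here is a minimal standalone sketch of the shape rule the removed helpers implemented (illustrative only; matmul_output_shape is not part of the package):

def matmul_output_shape(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
    # A 1-D left operand acts as (1, K); a 1-D right operand acts as (K, 1).
    left_vector, right_vector = len(a) == 1, len(b) == 1
    a_eff = (1, a[0]) if left_vector else a
    b_eff = (b[0], 1) if right_vector else b
    if a_eff[-1] != b_eff[-2]:
        raise ValueError(f"inner dimensions must match, got {a_eff[-1]} and {b_eff[-2]}")
    # Broadcast the leading (batch) dimensions NumPy-style.
    batch_a, batch_b = a_eff[:-2], b_eff[:-2]
    rank = max(len(batch_a), len(batch_b))
    batch_a = (1,) * (rank - len(batch_a)) + batch_a
    batch_b = (1,) * (rank - len(batch_b)) + batch_b
    batch = []
    for da, db in zip(batch_a, batch_b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"batch dims not broadcastable: {a} x {b}")
        batch.append(max(da, db))
    m, n = a_eff[-2], b_eff[-1]
    # Dimensions that were added by vector promotion are dropped from the result.
    if left_vector and right_vector:
        return tuple(batch)
    if left_vector:
        return tuple(batch) + (n,)
    if right_vector:
        return tuple(batch) + (m,)
    return tuple(batch) + (m, n)

assert matmul_output_shape((2, 1, 3, 4), (5, 4, 6)) == (2, 5, 3, 6)
assert matmul_output_shape((4,), (4, 6)) == (6,)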
emx_onnx_cgen/lowering/optional_has_element.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from ..errors import UnsupportedOpError
+from ..ir.context import GraphContext
+from ..ir.model import Node
+from ..ir.ops import OptionalHasElementOp
+from .registry import register_lowering
+
+
+@register_lowering("OptionalHasElement")
+def lower_optional_has_element(
+    ctx: GraphContext, node: Node
+) -> OptionalHasElementOp:
+    if len(node.inputs) != 1 or not node.inputs[0]:
+        raise UnsupportedOpError(
+            "OptionalHasElement expects exactly one non-empty input."
+        )
+    if len(node.outputs) != 1 or not node.outputs[0]:
+        raise UnsupportedOpError(
+            "OptionalHasElement expects exactly one output."
+        )
+    input_name = node.inputs[0]
+    value = ctx.find_value(input_name)
+    if not value.type.is_optional:
+        raise UnsupportedOpError(
+            "OptionalHasElement expects an optional input."
+        )
+    return OptionalHasElementOp(input0=input_name, output=node.outputs[0])
emx_onnx_cgen/lowering/qlinear_mul.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..ir.op_base import BroadcastingOpBase
+from ..ir.ops import QLinearMulOp
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+def _ensure_scalar_input(
+    graph: Graph, name: str, node: Node, label: str
+) -> tuple[int, ...]:
+    shape = _value_shape(graph, name, node)
+    if shape not in {(), (1,)}:
+        raise UnsupportedOpError(
+            f"QLinearMul {label} must be scalar, got shape {shape}"
+        )
+    return shape
+
+
+def _ensure_scale_dtype(dtype: ScalarType, label: str) -> None:
+    if not dtype.is_float:
+        raise UnsupportedOpError(
+            f"QLinearMul {label} must be float16/float/double"
+        )
+
+
+@register_lowering("QLinearMul")
+def lower_qlinear_mul(graph: Graph, node: Node) -> QLinearMulOp:
+    if len(node.inputs) != 8 or len(node.outputs) != 1:
+        raise UnsupportedOpError("QLinearMul must have 8 inputs and 1 output")
+    input0_shape = _value_shape(graph, node.inputs[0], node)
+    input1_shape = _value_shape(graph, node.inputs[3], node)
+    output_shape = BroadcastingOpBase.broadcast_shapes(
+        input0_shape, input1_shape
+    )
+    expected_output_shape = _value_shape(graph, node.outputs[0], node)
+    if expected_output_shape != output_shape:
+        raise ShapeInferenceError(
+            "QLinearMul output shape must be "
+            f"{output_shape}, got {expected_output_shape}"
+        )
+    input0_dtype = _value_dtype(graph, node.inputs[0], node)
+    input1_dtype = _value_dtype(graph, node.inputs[3], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input0_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError("QLinearMul supports uint8/int8 inputs only")
+    if input1_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError("QLinearMul supports uint8/int8 inputs only")
+    if output_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError(
+            "QLinearMul supports uint8/int8 outputs only"
+        )
+    input0_scale_dtype = _value_dtype(graph, node.inputs[1], node)
+    input1_scale_dtype = _value_dtype(graph, node.inputs[4], node)
+    output_scale_dtype = _value_dtype(graph, node.inputs[6], node)
+    _ensure_scale_dtype(input0_scale_dtype, "a_scale")
+    _ensure_scale_dtype(input1_scale_dtype, "b_scale")
+    _ensure_scale_dtype(output_scale_dtype, "y_scale")
+    input0_zero_dtype = _value_dtype(graph, node.inputs[2], node)
+    input1_zero_dtype = _value_dtype(graph, node.inputs[5], node)
+    output_zero_dtype = _value_dtype(graph, node.inputs[7], node)
+    if input0_zero_dtype != input0_dtype:
+        raise UnsupportedOpError("QLinearMul a_zero_point dtype must match a")
+    if input1_zero_dtype != input1_dtype:
+        raise UnsupportedOpError("QLinearMul b_zero_point dtype must match b")
+    if output_zero_dtype != output_dtype:
+        raise UnsupportedOpError("QLinearMul y_zero_point dtype must match y")
+    input0_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[1], node, "a_scale"
+    )
+    input1_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[4], node, "b_scale"
+    )
+    output_scale_shape = _ensure_scalar_input(
+        graph, node.inputs[6], node, "y_scale"
+    )
+    input0_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[2], node, "a_zero_point"
+    )
+    input1_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[5], node, "b_zero_point"
+    )
+    output_zero_shape = _ensure_scalar_input(
+        graph, node.inputs[7], node, "y_zero_point"
+    )
+    return QLinearMulOp(
+        input0=node.inputs[0],
+        input0_scale=node.inputs[1],
+        input0_zero_point=node.inputs[2],
+        input1=node.inputs[3],
+        input1_scale=node.inputs[4],
+        input1_zero_point=node.inputs[5],
+        output_scale=node.inputs[6],
+        output_zero_point=node.inputs[7],
+        output=node.outputs[0],
+        input0_shape=input0_shape,
+        input1_shape=input1_shape,
+        output_shape=output_shape,
+        input0_dtype=input0_dtype,
+        input1_dtype=input1_dtype,
+        dtype=output_dtype,
+        input0_scale_dtype=input0_scale_dtype,
+        input1_scale_dtype=input1_scale_dtype,
+        output_scale_dtype=output_scale_dtype,
+        input0_scale_shape=input0_scale_shape,
+        input1_scale_shape=input1_scale_shape,
+        output_scale_shape=output_scale_shape,
+        input0_zero_shape=input0_zero_shape,
+        input1_zero_shape=input1_zero_shape,
+        output_zero_shape=output_zero_shape,
+    )
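The new lowering only validates operand dtypes and scalar scale/zero-point inputs; the dequantize-multiply-requantize arithmetic itself is emitted from qlinear_mul_op.c.j2. As a cross-check, the per-element reference math for ONNX QLinearMul looks roughly like the sketch below (qlinear_mul_ref is a hypothetical helper, and the template's exact rounding and saturation details may differ):

import numpy as np

def qlinear_mul_ref(a, a_scale, a_zp, b, b_scale, b_zp, y_scale, y_zp, out_dtype=np.uint8):
    # Dequantize both operands, multiply in real arithmetic, then requantize.
    real = (a.astype(np.float64) - a_zp) * a_scale * (b.astype(np.float64) - b_zp) * b_scale
    q = np.rint(real / y_scale) + y_zp              # round to nearest, shift by zero point
    info = np.iinfo(out_dtype)
    return np.clip(q, info.min, info.max).astype(out_dtype)  # saturating cast

a = np.array([130, 2], dtype=np.uint8)
b = np.array([4, 200], dtype=np.uint8)
print(qlinear_mul_ref(a, 0.05, 128, b, 0.1, 0, 0.2, 128))    # -> [128   0]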
emx_onnx_cgen/lowering/reduce.py
@@ -525,17 +525,12 @@ def lower_reduce(graph: Graph, node: Node) -> ReduceOp | ReshapeOp:
     return ReduceOp(
         input0=node.inputs[0],
         output=node.outputs[0],
-        input_shape=input_shape,
-        output_shape=spec.output_shape,
         axes=spec.axes or (),
         axes_input=spec.axes_input,
-        axes_input_shape=spec.axes_input_shape,
-        axes_input_dtype=spec.axes_input_dtype,
         keepdims=spec.keepdims,
         noop_with_empty_axes=bool(int(node.attrs.get("noop_with_empty_axes", 0))),
         reduce_kind=REDUCE_KIND_BY_OP[node.op_type],
         reduce_count=spec.reduce_count,
-        dtype=op_dtype,
     )
 
 
emx_onnx_cgen/lowering/reshape.py
@@ -2,31 +2,20 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Initializer, Node
-from .common import value_shape as resolved_value_shape
+from ..ir.ops import ReshapeOp
+from .common import value_dtype, value_shape as resolved_value_shape
 from .registry import register_lowering
 
 
 def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
-    try:
-        return graph.find_value(name).type.shape
-    except KeyError as exc:
-        raise ShapeInferenceError(
-            f"Missing shape for value '{name}' in op {node.op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
+    return resolved_value_shape(graph, name, node)
 
 
 def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
-    try:
-        return graph.find_value(name).type.dtype
-    except KeyError as exc:
-        raise ShapeInferenceError(
-            f"Missing dtype for value '{name}' in op {node.op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
+    return value_dtype(graph, name, node)
 
 
 def _shape_product(shape: tuple[int, ...]) -> int:
@@ -350,6 +339,8 @@ def lower_reshape(graph: Graph, node: Node) -> ReshapeOp:
     for dim in output_shape:
         if dim < 0:
             raise ShapeInferenceError("Dynamic dims are not supported")
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], output_shape)
     return ReshapeOp(
         input0=node.inputs[0],
         output=node.outputs[0],
emx_onnx_cgen/lowering/shape.py
@@ -2,10 +2,11 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..ir.ops import ShapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
-from .common import value_dtype, value_shape
+from ..ir.ops import ShapeOp
+from .common import value_dtype, value_has_dim_params, value_shape
 from .registry import register_lowering
 
 
@@ -29,10 +30,13 @@ def lower_shape(graph: Graph, node: Node) -> ShapeOp:
         raise UnsupportedOpError("Shape must have 1 input and 1 output")
     input_shape = value_shape(graph, node.inputs[0], node)
     output_shape = value_shape(graph, node.outputs[0], node)
-    if len(output_shape) != 1:
-        raise ShapeInferenceError("Shape output must be 1D")
-    if output_shape[0] < 0:
-        raise ShapeInferenceError("Shape output length must be non-negative")
+    if value_has_dim_params(graph, node.outputs[0]) or not output_shape:
+        output_shape = ()
+    if output_shape:
+        if len(output_shape) != 1:
+            raise ShapeInferenceError("Shape output must be 1D")
+        if output_shape[0] < 0:
+            raise ShapeInferenceError("Shape output length must be non-negative")
     input_dtype = value_dtype(graph, node.inputs[0], node)
     output_dtype = value_dtype(graph, node.outputs[0], node)
     if output_dtype != ScalarType.I64:
@@ -43,16 +47,18 @@ def lower_shape(graph: Graph, node: Node) -> ShapeOp:
         len(input_shape), start=start, end=end
     )
     expected_shape = (max(0, end_index - start_index),)
-    if expected_shape != output_shape:
+    if output_shape and expected_shape != output_shape:
        raise ShapeInferenceError(
            "Shape output shape must be "
            f"{expected_shape}, got {output_shape}"
        )
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], expected_shape)
     return ShapeOp(
         input0=node.inputs[0],
         output=node.outputs[0],
         input_shape=input_shape,
-        output_shape=output_shape,
+        output_shape=expected_shape,
         values=input_shape[start_index:end_index],
         dtype=output_dtype,
         input_dtype=input_dtype,
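Shape lowering now tolerates outputs whose declared shape carries dim_params: it recomputes the expected 1-D length from the start/end attributes and records it through GraphContext.set_shape. For orientation, the start/end slicing that produces values and expected_shape follows the ONNX Shape semantics, roughly as sketched below (a sketch assuming spec-style clamping, not this package's helper):

def shape_op(input_shape, start=0, end=None):
    # Negative start/end count from the back; both are clamped to [0, rank].
    rank = len(input_shape)
    end = rank if end is None else end
    start = max(0, min(rank, start + rank if start < 0 else start))
    end = max(0, min(rank, end + rank if end < 0 else end))
    return list(input_shape[start:end])  # 1-D int64 tensor of length max(0, end - start)

print(shape_op((2, 3, 4, 5)))                    # [2, 3, 4, 5]
print(shape_op((2, 3, 4, 5), start=1, end=-1))   # [3, 4]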
emx_onnx_cgen/lowering/slice.py
@@ -6,10 +6,16 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..ir.ops import SliceOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Initializer, Node
-from ..lowering.common import value_dtype, value_shape
+from ..ir.ops import SliceOp
+from ..lowering.common import (
+    resolve_int_list_from_value,
+    value_has_dim_params,
+    value_dtype,
+    value_shape,
+)
 from ..validation import normalize_axis
 from .registry import register_lowering
 
@@ -70,7 +76,7 @@ def _maybe_read_int_list(
 ) -> list[int] | None:
     initializer = _find_initializer(graph, name)
     if initializer is None:
-        return None
+        return resolve_int_list_from_value(graph, name, node)
     return _read_int_list(graph, name, node, label=label)
 
 
@@ -335,6 +341,8 @@ def resolve_slice_spec(graph: Graph, node: Node) -> SliceSpec:
 def lower_slice(graph: Graph, node: Node) -> SliceOp:
     input_shape = value_shape(graph, node.inputs[0], node)
     output_shape = value_shape(graph, node.outputs[0], node)
+    if value_has_dim_params(graph, node.outputs[0]):
+        output_shape = ()
     input_dtype = value_dtype(graph, node.inputs[0], node)
     output_dtype = value_dtype(graph, node.outputs[0], node)
     if input_dtype != output_dtype:
@@ -356,6 +364,8 @@ def lower_slice(graph: Graph, node: Node) -> SliceOp:
             f"{node.op_type} output shape must be "
            f"{computed_output_shape}, got {output_shape}"
        )
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], computed_output_shape)
     return SliceOp(
         input0=node.inputs[0],
         output=node.outputs[0],
@@ -379,7 +389,7 @@ def lower_slice(graph: Graph, node: Node) -> SliceOp:
            dtype=input_dtype,
            input_dtype=input_dtype,
        )
-    if len(output_shape) != len(input_shape):
+    if output_shape and len(output_shape) != len(input_shape):
        raise ShapeInferenceError(
            f"{node.op_type} output rank must match input rank"
        )
emx_onnx_cgen/lowering/softmax.py
@@ -3,49 +3,15 @@ from __future__ import annotations
 from ..ir.ops import SoftmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
-from .common import node_dtype as _node_dtype
-from .common import onnx_opset_version as _onnx_opset_version
-from .common import shape_product as _shape_product
-from .common import value_shape as _value_shape
 from .registry import register_lowering
-from ..validation import ensure_output_shape_matches_input
-from ..validation import normalize_axis as _normalize_axis
 
 
 @register_lowering("Softmax")
 def lower_softmax(graph: Graph, node: Node) -> SoftmaxOp:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
         raise UnsupportedOpError("Softmax must have 1 input and 1 output")
-    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
-    if not op_dtype.is_float:
-        raise UnsupportedOpError(
-            "Softmax supports float16, float, and double inputs only"
-        )
-    input_shape = _value_shape(graph, node.inputs[0], node)
-    output_shape = _value_shape(graph, node.outputs[0], node)
-    ensure_output_shape_matches_input(node, input_shape, output_shape)
-    opset_version = _onnx_opset_version(graph)
-    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
-    axis_attr = node.attrs.get("axis", default_axis)
-    axis = _normalize_axis(
-        int(axis_attr),
-        input_shape,
-        node,
-    )
-    outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
-    axis_size = input_shape[axis]
-    inner = (
-        _shape_product(input_shape[axis + 1 :])
-        if axis + 1 < len(input_shape)
-        else 1
-    )
     return SoftmaxOp(
         input0=node.inputs[0],
         output=node.outputs[0],
-        outer=outer,
-        axis_size=axis_size,
-        inner=inner,
-        axis=axis,
-        shape=input_shape,
-        dtype=op_dtype,
+        axis=int(node.attrs["axis"]) if "axis" in node.attrs else None,
     )
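The float-only dtype check, the opset-dependent default axis (1 before opset 13, -1 afterwards), and the outer/axis_size/inner factorization all leave the lowering; SoftmaxOp now receives only the raw axis attribute (or None), so those steps presumably happen inside the op or the emitter. The factorization the old code computed amounts to the following sketch (softmax_by_decomposition is illustrative, not package code):

import math
import numpy as np

def softmax_by_decomposition(x: np.ndarray, axis: int) -> np.ndarray:
    # View the tensor as (outer, axis_size, inner) and apply softmax per lane,
    # which is how a generated C kernel can handle any axis with three loops.
    axis = axis % x.ndim
    outer = math.prod(x.shape[:axis])
    axis_size = x.shape[axis]
    inner = math.prod(x.shape[axis + 1:])
    flat = x.reshape(outer, axis_size, inner)
    out = np.empty_like(flat)
    for o in range(outer):
        for i in range(inner):
            lane = flat[o, :, i]
            e = np.exp(lane - lane.max())  # subtract max for numerical stability
            out[o, :, i] = e / e.sum()
    return out.reshape(x.shape)

x = np.random.rand(2, 3, 4).astype(np.float32)
ref = np.exp(x - x.max(axis=1, keepdims=True))
ref /= ref.sum(axis=1, keepdims=True)
assert np.allclose(softmax_by_decomposition(x, 1), ref, atol=1e-6)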
emx_onnx_cgen/lowering/split.py
@@ -6,8 +6,14 @@ from shared.scalar_types import ScalarType
 
 from ..ir.ops import SplitOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Initializer, Node
-from ..lowering.common import optional_name, value_dtype, value_shape
+from ..lowering.common import (
+    optional_name,
+    resolve_int_list_from_value,
+    value_dtype,
+    value_shape,
+)
 from ..validation import normalize_axis
 from .registry import register_lowering
 
@@ -46,6 +52,22 @@ def _validate_static_dims(shape: tuple[int, ...], node: Node) -> None:
         )
 
 
+def _validate_output_ranks(
+    output_shapes: list[tuple[int, ...]],
+    input_shape: tuple[int, ...],
+    node: Node,
+) -> None:
+    expected_rank = len(input_shape)
+    for output_shape in output_shapes:
+        if not output_shape:
+            continue
+        if len(output_shape) != expected_rank:
+            raise ShapeInferenceError(
+                f"{node.op_type} output rank must match input rank "
+                f"{expected_rank}, got {len(output_shape)}"
+            )
+
+
 def _normalize_num_outputs(node: Node, output_count: int) -> int:
     num_outputs_attr = node.attrs.get("num_outputs")
     if num_outputs_attr is None:
@@ -75,6 +97,7 @@ def lower_split(graph: Graph, node: Node) -> SplitOp:
     output_shapes = [
         value_shape(graph, output, node) for output in node.outputs
     ]
+    _validate_output_ranks(output_shapes, input_shape, node)
     input_dtype = value_dtype(graph, input_name, node)
     output_dtypes = {value_dtype(graph, output, node) for output in node.outputs}
     if output_dtypes != {input_dtype}:
@@ -107,7 +130,15 @@ def lower_split(graph: Graph, node: Node) -> SplitOp:
            raise ShapeInferenceError(
                f"Split expects {len(node.outputs)} outputs, got {split_shape[0]}"
            )
-        split_sizes = [shape[axis] for shape in output_shapes]
+        split_sizes = resolve_int_list_from_value(graph, split_name, node)
+        if split_sizes is None:
+            if all(output_shape for output_shape in output_shapes):
+                split_sizes = [shape[axis] for shape in output_shapes]
+            else:
+                raise ShapeInferenceError(
+                    "Split sizes must be constant when output shapes "
+                    "are unavailable"
+                )
     if len(split_sizes) != len(node.outputs):
         raise ShapeInferenceError(
             f"Split expects {len(split_sizes)} outputs, got {len(node.outputs)}"
@@ -133,11 +164,14 @@
         shape = list(input_shape)
         shape[axis] = size
         computed_shape = tuple(shape)
-        if output_shape != computed_shape:
+        if output_shape and output_shape != computed_shape:
            raise ShapeInferenceError(
                f"Split output shape must be {computed_shape}, got {output_shape}"
            )
         computed_shapes.append(computed_shape)
+    if isinstance(graph, GraphContext):
+        for output_name, shape in zip(node.outputs, computed_shapes):
+            graph.set_shape(output_name, shape)
     return SplitOp(
         input0=input_name,
         outputs=tuple(node.outputs),