emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl

This diff shows the contents of publicly released versions of the package. It is provided for informational purposes only and reflects the changes between versions as published in their respective public registries.

Note: this release of emx-onnx-cgen is flagged as potentially problematic.

Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0

emx_onnx_cgen/lowering/tfidf_vectorizer.py
@@ -0,0 +1,199 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..ir.ops import TfIdfVectorizerOp
+ from ..lowering.common import value_dtype, value_shape
+ from .registry import register_lowering
+
+ _SUPPORTED_INPUT_DTYPES = {ScalarType.I32, ScalarType.I64}
+ _SUPPORTED_OUTPUT_DTYPES = {ScalarType.F32}
+ _SUPPORTED_MODES = {"TF", "IDF", "TFIDF"}
+
+
+ def _decode_mode(value: object) -> str:
+     if isinstance(value, bytes):
+         return value.decode()
+     return str(value)
+
+
+ def _ensure_int_list(
+     values: object | None, *, name: str, node: Node
+ ) -> tuple[int, ...]:
+     if values is None:
+         raise UnsupportedOpError(f"{node.op_type} requires {name} attribute")
+     try:
+         return tuple(int(value) for value in values)  # type: ignore[arg-type]
+     except TypeError as exc:
+         raise UnsupportedOpError(
+             f"{node.op_type} {name} attribute must be a list of integers"
+         ) from exc
+
+
+ def _ensure_float_list(
+     values: object | None, *, name: str, node: Node
+ ) -> tuple[float, ...] | None:
+     if values is None:
+         return None
+     try:
+         return tuple(float(value) for value in values)  # type: ignore[arg-type]
+     except TypeError as exc:
+         raise UnsupportedOpError(
+             f"{node.op_type} {name} attribute must be a list of floats"
+         ) from exc
+
+
+ def _validate_output_shape(
+     node: Node,
+     input_shape: tuple[int, ...],
+     output_shape: tuple[int, ...],
+     output_dim: int,
+ ) -> None:
+     if len(input_shape) == 1:
+         expected = (output_dim,)
+     else:
+         expected = (input_shape[0], output_dim)
+     if output_shape != expected:
+         raise ShapeInferenceError(
+             f"{node.op_type} output shape must be {expected}, got {output_shape}"
+         )
+
+
+ @register_lowering("TfIdfVectorizer")
+ def lower_tfidf_vectorizer(graph: Graph, node: Node) -> TfIdfVectorizerOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects 1 input and 1 output"
+         )
+     input_name = node.inputs[0]
+     output_name = node.outputs[0]
+     input_shape = value_shape(graph, input_name, node)
+     output_shape = value_shape(graph, output_name, node)
+     input_dtype = value_dtype(graph, input_name, node)
+     output_dtype = value_dtype(graph, output_name, node)
+     if input_dtype not in _SUPPORTED_INPUT_DTYPES:
+         raise UnsupportedOpError(
+             f"{node.op_type} input dtype must be int32 or int64, "
+             f"got {input_dtype.onnx_name}"
+         )
+     if output_dtype not in _SUPPORTED_OUTPUT_DTYPES:
+         raise UnsupportedOpError(
+             f"{node.op_type} output dtype must be float, "
+             f"got {output_dtype.onnx_name}"
+         )
+     if len(input_shape) not in {1, 2}:
+         raise UnsupportedOpError(
+             f"{node.op_type} input rank must be 1 or 2, got {len(input_shape)}"
+         )
+     mode_value = node.attrs.get("mode")
+     if mode_value is None:
+         raise UnsupportedOpError(
+             f"{node.op_type} requires mode attribute"
+         )
+     mode = _decode_mode(mode_value)
+     if mode not in _SUPPORTED_MODES:
+         raise UnsupportedOpError(
+             f"{node.op_type} mode must be one of {sorted(_SUPPORTED_MODES)}, "
+             f"got {mode}"
+         )
+     min_gram_length = int(node.attrs.get("min_gram_length", 0))
+     max_gram_length = int(node.attrs.get("max_gram_length", 0))
+     max_skip_count = int(node.attrs.get("max_skip_count", 0))
+     if min_gram_length <= 0 or max_gram_length <= 0:
+         raise UnsupportedOpError(
+             f"{node.op_type} requires positive min/max gram lengths"
+         )
+     if min_gram_length > max_gram_length:
+         raise UnsupportedOpError(
+             f"{node.op_type} min_gram_length {min_gram_length} exceeds "
+             f"max_gram_length {max_gram_length}"
+         )
+     if max_skip_count < 0:
+         raise UnsupportedOpError(
+             f"{node.op_type} max_skip_count must be non-negative"
+         )
+     ngram_counts = _ensure_int_list(
+         node.attrs.get("ngram_counts"), name="ngram_counts", node=node
+     )
+     ngram_indexes = _ensure_int_list(
+         node.attrs.get("ngram_indexes"), name="ngram_indexes", node=node
+     )
+     if "pool_strings" in node.attrs:
+         raise UnsupportedOpError(
+             f"{node.op_type} string pools are not supported"
+         )
+     pool_int64s = _ensure_int_list(
+         node.attrs.get("pool_int64s"), name="pool_int64s", node=node
+     )
+     weights = _ensure_float_list(
+         node.attrs.get("weights"), name="weights", node=node
+     )
+     if len(ngram_counts) < max_gram_length:
+         raise UnsupportedOpError(
+             f"{node.op_type} ngram_counts length must be >= max_gram_length"
+         )
+     if ngram_counts and ngram_counts[0] != 0:
+         raise UnsupportedOpError(
+             f"{node.op_type} ngram_counts must start with 0"
+         )
+     if any(value < 0 for value in ngram_counts):
+         raise UnsupportedOpError(
+             f"{node.op_type} ngram_counts must be non-negative"
+         )
+     if any(
+         later < earlier
+         for earlier, later in zip(ngram_counts, ngram_counts[1:])
+     ):
+         raise UnsupportedOpError(
+             f"{node.op_type} ngram_counts must be non-decreasing"
+         )
+     pool_size = len(pool_int64s)
+     if ngram_counts and ngram_counts[-1] > pool_size:
+         raise UnsupportedOpError(
+             f"{node.op_type} ngram_counts exceeds pool_int64s length"
+         )
+     total_ngrams = 0
+     for gram_length in range(1, max_gram_length + 1):
+         start = ngram_counts[gram_length - 1]
+         end = (
+             ngram_counts[gram_length]
+             if gram_length < len(ngram_counts)
+             else pool_size
+         )
+         count = end - start
+         if count < 0 or count % gram_length != 0:
+             raise UnsupportedOpError(
+                 f"{node.op_type} pool size for {gram_length}-grams "
+                 "must be divisible by gram length"
+             )
+         total_ngrams += count // gram_length
+     if total_ngrams != len(ngram_indexes):
+         raise UnsupportedOpError(
+             f"{node.op_type} ngram_indexes length {len(ngram_indexes)} "
+             f"does not match pool ngram count {total_ngrams}"
+         )
+     if weights is not None and len(weights) != len(ngram_indexes):
+         raise UnsupportedOpError(
+             f"{node.op_type} weights length {len(weights)} does not match "
+             f"ngram_indexes length {len(ngram_indexes)}"
+         )
+     output_dim = max(ngram_indexes, default=-1) + 1
+     _validate_output_shape(node, input_shape, output_shape, output_dim)
+     return TfIdfVectorizerOp(
+         input0=input_name,
+         output=output_name,
+         input_shape=input_shape,
+         output_shape=output_shape,
+         input_dtype=input_dtype,
+         output_dtype=output_dtype,
+         min_gram_length=min_gram_length,
+         max_gram_length=max_gram_length,
+         max_skip_count=max_skip_count,
+         mode=mode,
+         ngram_counts=ngram_counts,
+         ngram_indexes=ngram_indexes,
+         pool_int64s=pool_int64s,
+         weights=weights,
+     )
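
For reference, a minimal sketch (hypothetical attribute values, not taken from the package or any real model) of how the validation above fits together: ngram_counts partitions pool_int64s by gram length, and the emitted output width is derived from the largest ngram index.

    # Hypothetical TfIdfVectorizer attributes.
    ngram_counts = (0, 4)                   # 1-grams start at offset 0, 2-grams at offset 4
    pool_int64s = (2, 3, 5, 7, 2, 3, 5, 7)  # four 1-grams followed by two 2-grams
    ngram_indexes = (0, 1, 2, 3, 4, 5)
    max_gram_length = 2

    # Same partitioning arithmetic as the lowering above.
    pool_size = len(pool_int64s)
    total_ngrams = 0
    for gram_length in range(1, max_gram_length + 1):
        start = ngram_counts[gram_length - 1]
        end = ngram_counts[gram_length] if gram_length < len(ngram_counts) else pool_size
        total_ngrams += (end - start) // gram_length
    assert total_ngrams == len(ngram_indexes)  # 4 one-grams + 2 two-grams == 6

    # One output column per ngram index.
    output_dim = max(ngram_indexes, default=-1) + 1  # -> 6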

emx_onnx_cgen/lowering/tile.py
@@ -30,6 +30,37 @@ def _read_repeats(graph: Graph, name: str, node: Node) -> tuple[int, ...] | None
      return tuple(int(value) for value in values)


+ def _infer_repeats_from_shapes(
+     input_shape: tuple[int, ...],
+     output_shape: tuple[int, ...],
+ ) -> tuple[int, ...]:
+     if len(input_shape) != len(output_shape):
+         raise ShapeInferenceError(
+             "Tile repeats must have the same rank as input shape"
+         )
+     repeats: list[int] = []
+     for input_dim, output_dim in zip(input_shape, output_shape):
+         if input_dim < 0 or output_dim < 0:
+             raise ShapeInferenceError(
+                 "Tile repeats input must be constant when shapes are dynamic"
+             )
+         if input_dim == 0:
+             if output_dim != 0:
+                 raise ShapeInferenceError(
+                     "Tile output shape mismatch: "
+                     f"expected 0 for dimension, got {output_dim}"
+                 )
+             repeats.append(0)
+             continue
+         if output_dim % input_dim != 0:
+             raise ShapeInferenceError(
+                 "Tile output shape mismatch: "
+                 f"expected multiple of {input_dim}, got {output_dim}"
+             )
+         repeats.append(int(output_dim // input_dim))
+     return tuple(repeats)
+
+
  def _compute_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
      strides: list[int] = []
      stride = 1
@@ -54,7 +85,13 @@ def lower_tile(graph: Graph, node: Node) -> TileOp:
      )
      repeats = _read_repeats(graph, node.inputs[1], node)
      if repeats is None:
-         raise UnsupportedOpError("Tile repeats input must be a constant initializer")
+         repeats_shape = value_shape(graph, node.inputs[1], node)
+         repeats_dtype = value_dtype(graph, node.inputs[1], node)
+         if repeats_dtype not in {ScalarType.I64, ScalarType.I32}:
+             raise UnsupportedOpError("Tile repeats input must be int64 or int32")
+         if len(repeats_shape) != 1:
+             raise UnsupportedOpError("Tile repeats input must be a 1D tensor")
+         repeats = _infer_repeats_from_shapes(input_shape, output_shape)
      if len(repeats) != len(input_shape):
          raise ShapeInferenceError(
              "Tile repeats must have the same rank as input shape"

emx_onnx_cgen/lowering/topk.py
@@ -117,17 +117,13 @@ def lower_topk(graph: Graph, node: Node) -> TopKOp:
      sorted_output = bool(int(node.attrs.get("sorted", 1)))
      return TopKOp(
          input0=input_name,
+         k_input=k_name,
          output_values=output_values,
          output_indices=output_indices,
-         input_shape=input_shape,
-         output_shape=output_shape,
          axis=axis,
          k=k,
          largest=largest,
          sorted=sorted_output,
-         input_dtype=input_dtype,
-         output_values_dtype=values_dtype,
-         output_indices_dtype=indices_dtype,
      )



emx_onnx_cgen/lowering/transpose.py
@@ -1,9 +1,11 @@
  from __future__ import annotations

- from ..ir.ops import TransposeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.context import GraphContext
  from ..ir.model import Graph, Node
+ from ..ir.ops import TransposeOp
  from .common import node_dtype as _node_dtype
+ from .common import value_has_dim_params as _value_has_dim_params
  from .common import value_shape as _value_shape
  from .registry import register_lowering

@@ -14,6 +16,8 @@ def lower_transpose(graph: Graph, node: Node) -> TransposeOp:
          raise UnsupportedOpError("Transpose must have 1 input and 1 output")
      input_shape = _value_shape(graph, node.inputs[0], node)
      output_shape = _value_shape(graph, node.outputs[0], node)
+     if _value_has_dim_params(graph, node.outputs[0]) or not output_shape:
+         output_shape = ()
      perm = node.attrs.get("perm")
      if perm is None:
          perm = tuple(reversed(range(len(input_shape))))
@@ -29,18 +33,20 @@ def lower_transpose(graph: Graph, node: Node) -> TransposeOp:
              f"Transpose perm must be a permutation, got {perm}"
          )
      expected_shape = tuple(input_shape[axis] for axis in perm)
-     if output_shape != expected_shape:
+     if output_shape and output_shape != expected_shape:
          raise ShapeInferenceError(
              "Transpose output shape must match permuted input shape, "
              f"expected {expected_shape}, got {output_shape}"
          )
+     if isinstance(graph, GraphContext):
+         graph.set_shape(node.outputs[0], expected_shape)
      op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
      return TransposeOp(
          input0=node.inputs[0],
          output=node.outputs[0],
          perm=perm,
          input_shape=input_shape,
-         output_shape=output_shape,
+         output_shape=expected_shape,
          dtype=op_dtype,
          input_dtype=op_dtype,
      )
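
Illustrative values only: the lowering now always recomputes the output shape from perm, so a symbolic output dimension in the ONNX graph no longer blocks Transpose.

    input_shape = (2, 3, 4)
    perm = (2, 0, 1)
    expected_shape = tuple(input_shape[axis] for axis in perm)  # as in the lowering above
    print(expected_shape)  # -> (4, 2, 3)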

emx_onnx_cgen/lowering/unsqueeze.py
@@ -2,30 +2,20 @@ from __future__ import annotations

  from shared.scalar_types import ScalarType

- from ..ir.ops import ReshapeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.context import GraphContext
  from ..ir.model import Graph, Initializer, Node
+ from ..ir.ops import ReshapeOp
+ from ..lowering.common import value_dtype, value_has_dim_params, value_shape
  from .registry import register_lowering


  def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
-     try:
-         return graph.find_value(name).type.shape
-     except KeyError as exc:
-         raise ShapeInferenceError(
-             f"Missing shape for value '{name}' in op {node.op_type}. "
-             "Hint: run ONNX shape inference or export with static shapes."
-         ) from exc
+     return value_shape(graph, name, node)


  def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
-     try:
-         return graph.find_value(name).type.dtype
-     except KeyError as exc:
-         raise ShapeInferenceError(
-             f"Missing dtype for value '{name}' in op {node.op_type}. "
-             "Hint: run ONNX shape inference or export with static shapes."
-         ) from exc
+     return value_dtype(graph, name, node)


  def _find_initializer(graph: Graph, name: str) -> Initializer | None:
@@ -105,6 +95,8 @@ def lower_unsqueeze(graph: Graph, node: Node) -> ReshapeOp:
          raise UnsupportedOpError("Unsqueeze must have 1 or 2 inputs and 1 output")
      input_shape = _value_shape(graph, node.inputs[0], node)
      output_shape = _value_shape(graph, node.outputs[0], node)
+     if value_has_dim_params(graph, node.outputs[0]):
+         output_shape = ()
      _validate_shape(input_shape, node, "input")
      _validate_shape(output_shape, node, "output")
      input_dtype = _value_dtype(graph, node.inputs[0], node)
@@ -142,11 +134,14 @@ def lower_unsqueeze(graph: Graph, node: Node) -> ReshapeOp:
          )
      else:
          expected_shape = _expected_output_shape(input_shape, axes, node)
-         if expected_shape != output_shape:
+         if output_shape and expected_shape != output_shape:
              raise ShapeInferenceError(
                  "Unsqueeze output shape must be "
                  f"{expected_shape}, got {output_shape}"
              )
+         output_shape = expected_shape
+     if isinstance(graph, GraphContext):
+         graph.set_shape(node.outputs[0], output_shape)
      return ReshapeOp(
          input0=node.inputs[0],
          output=node.outputs[0],
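
For context, ONNX Unsqueeze inserts length-1 dimensions at the given axes. A plain sketch of the shape the check above compares against (an illustrative reimplementation, not the package's _expected_output_shape helper):

    def expected_unsqueeze_shape(input_shape, axes):
        out_rank = len(input_shape) + len(axes)
        normalized = {axis % out_rank for axis in axes}  # handle negative axes
        dims = iter(input_shape)
        return tuple(1 if i in normalized else next(dims) for i in range(out_rank))

    print(expected_unsqueeze_shape((3, 4), (0, 3)))  # -> (1, 3, 4, 1)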

emx_onnx_cgen/lowering/upsample.py
@@ -0,0 +1,151 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Initializer, Node
+ from ..ir.ops import ResizeOp
+ from .common import value_dtype, value_shape
+ from .registry import register_lowering
+
+ _SUPPORTED_MODES = {"nearest", "linear"}
+
+
+ def _decode_attr(value: object, default: str) -> str:
+     if value is None:
+         return default
+     if isinstance(value, bytes):
+         return value.decode("utf-8", errors="ignore")
+     if isinstance(value, str):
+         return value
+     return str(value)
+
+
+ def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+     for initializer in graph.initializers:
+         if initializer.name == name:
+             return initializer
+     return None
+
+
+ def _load_initializer_values(
+     graph: Graph, name: str, node: Node
+ ) -> tuple[float | int, ...] | None:
+     initializer = _find_initializer(graph, name)
+     if initializer is None:
+         return None
+     if initializer.type.dtype not in {
+         ScalarType.F16,
+         ScalarType.F32,
+         ScalarType.F64,
+     }:
+         raise UnsupportedOpError(
+             "Upsample scales initializer must be float16/float32/float64"
+         )
+     data = initializer.data.reshape(-1)
+     return tuple(data.tolist())
+
+
+ def _validate_output_shape(
+     expected: tuple[int, ...],
+     actual: tuple[int, ...],
+ ) -> None:
+     if expected != actual:
+         raise ShapeInferenceError(
+             f"Upsample output shape must be {expected}, got {actual}"
+         )
+     if any(dim < 0 for dim in actual):
+         raise ShapeInferenceError("Upsample output shape must be non-negative")
+
+
+ @register_lowering("Upsample")
+ def lower_upsample(graph: Graph, node: Node) -> ResizeOp:
+     if len(node.outputs) != 1:
+         raise UnsupportedOpError("Upsample expects one output")
+     if len(node.inputs) not in {1, 2}:
+         raise UnsupportedOpError("Upsample expects 1 or 2 inputs")
+     mode = _decode_attr(node.attrs.get("mode"), "nearest")
+     if mode not in _SUPPORTED_MODES:
+         raise UnsupportedOpError(f"Upsample mode {mode!r} is not supported")
+     input_name = node.inputs[0]
+     output_name = node.outputs[0]
+     input_shape = value_shape(graph, input_name, node)
+     output_shape = value_shape(graph, output_name, node)
+     input_dtype = value_dtype(graph, input_name, node)
+     output_dtype = value_dtype(graph, output_name, node)
+     if input_dtype != output_dtype:
+         raise UnsupportedOpError(
+             "Upsample expects matching input/output dtypes, "
+             f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     rank = len(input_shape)
+     axes = tuple(range(rank))
+     scales_input = None
+     scales_shape = None
+     scales_dtype = None
+     scales_axes = None
+     scales: tuple[float, ...]
+     if len(node.inputs) == 2 and node.inputs[1]:
+         scales_input = node.inputs[1]
+         scales_shape = value_shape(graph, scales_input, node)
+         if len(scales_shape) != 1:
+             raise UnsupportedOpError("Upsample expects scales to be 1D")
+         if scales_shape[0] != rank:
+             raise UnsupportedOpError("Upsample scales length mismatch")
+         scales_dtype = value_dtype(graph, scales_input, node)
+         if scales_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+             raise UnsupportedOpError(
+                 "Upsample expects scales input to be float16/float32/float64"
+             )
+         values = _load_initializer_values(graph, scales_input, node)
+         if values is None:
+             scales = tuple(
+                 output_shape[axis] / input_shape[axis]
+                 for axis in range(rank)
+             )
+         else:
+             scales = tuple(float(value) for value in values)
+             expected = tuple(
+                 int(input_shape[axis] * scales[axis]) for axis in range(rank)
+             )
+             _validate_output_shape(expected, output_shape)
+     else:
+         scales_attr = node.attrs.get("scales")
+         if scales_attr is None:
+             raise UnsupportedOpError("Upsample requires scales attribute or input")
+         scales = tuple(float(value) for value in scales_attr)
+         if len(scales) != rank:
+             raise UnsupportedOpError("Upsample scales length mismatch")
+         expected = tuple(
+             int(input_shape[axis] * scales[axis]) for axis in range(rank)
+         )
+         _validate_output_shape(expected, output_shape)
+     return ResizeOp(
+         input0=input_name,
+         output=output_name,
+         input_shape=input_shape,
+         output_shape=output_shape,
+         scales=scales,
+         scales_input=scales_input,
+         sizes_input=None,
+         roi_input=None,
+         axes=axes,
+         scales_shape=scales_shape,
+         sizes_shape=None,
+         roi_shape=None,
+         scales_dtype=scales_dtype,
+         sizes_dtype=None,
+         roi_dtype=None,
+         scales_axes=scales_axes,
+         sizes_axes=None,
+         roi_axes=None,
+         mode=mode,
+         coordinate_transformation_mode="asymmetric",
+         nearest_mode="floor",
+         cubic_coeff_a=-0.75,
+         exclude_outside=False,
+         extrapolation_value=0.0,
+         antialias=False,
+         keep_aspect_ratio_policy="stretch",
+         dtype=input_dtype,
+     )
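
To illustrate the scale handling above (made-up shapes): Upsample is lowered onto ResizeOp with asymmetric coordinate mapping and floor nearest rounding, and the expected output dims are the truncated products of input dims and scales.

    input_shape = (1, 3, 4, 5)
    scales = (1.0, 1.0, 2.0, 1.5)
    expected = tuple(int(dim * scale) for dim, scale in zip(input_shape, scales))
    print(expected)  # -> (1, 3, 8, 7), matching int(input_shape[axis] * scales[axis]) above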

emx_onnx_cgen/lowering/variadic.py
@@ -53,7 +53,7 @@ def _lower_variadic(graph: Graph, node: Node) -> MultiInputBinaryOp:
          output=node.outputs[0],
          function=VARIADIC_OP_FUNCTIONS[node.op_type],
          operator_kind=VARIADIC_OP_OPERATOR_KINDS[node.op_type],
-         min_inputs=2,
+         min_inputs=1 if node.op_type not in BINARY_ONLY_OPS else 2,
          max_inputs=2 if node.op_type in BINARY_ONLY_OPS else None,
      )


emx_onnx_cgen/lowering/where.py
@@ -65,9 +65,4 @@ def lower_where(graph: Graph, node: Node) -> WhereOp:
          input_x=x_name,
          input_y=y_name,
          output=output_name,
-         condition_shape=condition_shape,
-         x_shape=x_shape,
-         y_shape=y_shape,
-         output_shape=output_shape,
-         dtype=output_dtype,
      )