emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/average_pool.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import math
 from dataclasses import dataclass
 
 from ..ir.ops import AveragePoolOp
@@ -12,16 +13,26 @@ from .registry import register_lowering
 class _AveragePoolSpec:
     batch: int
     channels: int
+    spatial_rank: int
+    in_d: int
     in_h: int
     in_w: int
+    out_d: int
     out_h: int
     out_w: int
+    kernel_d: int
     kernel_h: int
     kernel_w: int
+    dilation_d: int
+    dilation_h: int
+    dilation_w: int
+    stride_d: int
     stride_h: int
     stride_w: int
+    pad_front: int
     pad_top: int
     pad_left: int
+    pad_back: int
     pad_bottom: int
     pad_right: int
     count_include_pad: bool
@@ -54,6 +65,7 @@ def _resolve_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
         "auto_pad",
         "ceil_mode",
         "count_include_pad",
+        "dilations",
         "kernel_shape",
         "pads",
         "strides",
@@ -63,11 +75,9 @@ def _resolve_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
     auto_pad = node.attrs.get("auto_pad", b"NOTSET")
     if isinstance(auto_pad, bytes):
         auto_pad = auto_pad.decode("utf-8", errors="ignore")
-    if auto_pad not in ("", "NOTSET"):
-        raise UnsupportedOpError("AveragePool supports auto_pad=NOTSET only")
     ceil_mode = int(node.attrs.get("ceil_mode", 0))
-    if ceil_mode != 0:
-        raise UnsupportedOpError("AveragePool supports ceil_mode=0 only")
+    if ceil_mode not in (0, 1):
+        raise UnsupportedOpError("AveragePool supports ceil_mode=0 or 1 only")
     count_include_pad = int(node.attrs.get("count_include_pad", 0))
     if count_include_pad not in (0, 1):
         raise UnsupportedOpError("AveragePool supports count_include_pad 0 or 1")
@@ -75,47 +85,128 @@ def _resolve_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
     if kernel_shape is None:
         raise UnsupportedOpError("AveragePool requires kernel_shape")
     kernel_shape = tuple(int(value) for value in kernel_shape)
-    if len(kernel_shape) != 2:
-        raise UnsupportedOpError("AveragePool expects 2D kernel_shape")
-    kernel_h, kernel_w = kernel_shape
-    strides = tuple(int(value) for value in node.attrs.get("strides", (1, 1)))
-    if len(strides) != 2:
-        raise UnsupportedOpError("AveragePool expects 2D strides")
-    pads = tuple(int(value) for value in node.attrs.get("pads", (0, 0, 0, 0)))
-    if len(pads) != 4:
-        raise UnsupportedOpError("AveragePool expects 4D pads")
-    pad_top, pad_left, pad_bottom, pad_right = pads
     input_shape = _value_shape(graph, node.inputs[0], node)
-    if len(input_shape) != 4:
-        raise UnsupportedOpError("AveragePool supports NCHW 2D inputs only")
-    batch, channels, in_h, in_w = input_shape
-    stride_h, stride_w = strides
-    out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
-    out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1
-    if out_h < 0 or out_w < 0:
+    if len(input_shape) < 3:
+        raise UnsupportedOpError("AveragePool expects NCHW inputs with spatial dims")
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("AveragePool supports 1D/2D/3D inputs only")
+    if len(kernel_shape) != spatial_rank:
         raise ShapeInferenceError(
-            "AveragePool output shape must be non-negative"
+            "AveragePool kernel_shape must have "
+            f"{spatial_rank} dims, got {kernel_shape}"
         )
+    strides = tuple(
+        int(value) for value in node.attrs.get("strides", (1,) * spatial_rank)
+    )
+    if len(strides) != spatial_rank:
+        raise UnsupportedOpError("AveragePool stride rank mismatch")
+    dilations = tuple(
+        int(value)
+        for value in node.attrs.get("dilations", (1,) * spatial_rank)
+    )
+    if len(dilations) != spatial_rank:
+        raise UnsupportedOpError("AveragePool dilation rank mismatch")
+    pads = tuple(
+        int(value) for value in node.attrs.get("pads", (0,) * (2 * spatial_rank))
+    )
+    if len(pads) != 2 * spatial_rank:
+        raise UnsupportedOpError("AveragePool pads rank mismatch")
+    if auto_pad in ("", "NOTSET"):
+        pad_begin = pads[:spatial_rank]
+        pad_end = pads[spatial_rank:]
+    elif auto_pad == "VALID":
+        pad_begin = (0,) * spatial_rank
+        pad_end = (0,) * spatial_rank
+    elif auto_pad in {"SAME_UPPER", "SAME_LOWER"}:
+        pad_begin = []
+        pad_end = []
+        for dim, stride, dilation, kernel in zip(
+            input_shape[2:], strides, dilations, kernel_shape
+        ):
+            effective_kernel = dilation * (kernel - 1) + 1
+            out_dim = math.ceil(dim / stride)
+            pad_needed = max(0, (out_dim - 1) * stride + effective_kernel - dim)
+            if auto_pad == "SAME_UPPER":
+                pad_start = pad_needed // 2
+            else:
+                pad_start = (pad_needed + 1) // 2
+            pad_begin.append(pad_start)
+            pad_end.append(pad_needed - pad_start)
+        pad_begin = tuple(pad_begin)
+        pad_end = tuple(pad_end)
+    else:
+        raise UnsupportedOpError("AveragePool has unsupported auto_pad mode")
+    batch, channels = input_shape[:2]
+    in_spatial = input_shape[2:]
+    out_spatial = []
+    for dim, stride, dilation, kernel, pad_start, pad_finish in zip(
+        in_spatial, strides, dilations, kernel_shape, pad_begin, pad_end
+    ):
+        effective_kernel = dilation * (kernel - 1) + 1
+        numerator = dim + pad_start + pad_finish - effective_kernel
+        if ceil_mode:
+            out_dim = (numerator + stride - 1) // stride + 1
+            if (out_dim - 1) * stride >= dim + pad_start:
+                out_dim -= 1
+        else:
+            out_dim = numerator // stride + 1
+        if out_dim < 0:
+            raise ShapeInferenceError(
+                "AveragePool output shape must be non-negative"
+            )
+        out_spatial.append(out_dim)
     output_shape = _value_shape(graph, node.outputs[0], node)
-    expected_output_shape = (batch, channels, out_h, out_w)
+    expected_output_shape = (batch, channels, *out_spatial)
     if output_shape != expected_output_shape:
         raise ShapeInferenceError(
             "AveragePool output shape must be "
             f"{expected_output_shape}, got {output_shape}"
         )
+    in_d = in_spatial[0] if spatial_rank == 3 else 1
+    in_h = in_spatial[-2] if spatial_rank >= 2 else 1
+    in_w = in_spatial[-1]
+    out_d = out_spatial[0] if spatial_rank == 3 else 1
+    out_h = out_spatial[-2] if spatial_rank >= 2 else 1
+    out_w = out_spatial[-1]
+    kernel_d = kernel_shape[0] if spatial_rank == 3 else 1
+    kernel_h = kernel_shape[-2] if spatial_rank >= 2 else 1
+    kernel_w = kernel_shape[-1]
+    dilation_d = dilations[0] if spatial_rank == 3 else 1
+    dilation_h = dilations[-2] if spatial_rank >= 2 else 1
+    dilation_w = dilations[-1]
+    stride_d = strides[0] if spatial_rank == 3 else 1
+    stride_h = strides[-2] if spatial_rank >= 2 else 1
+    stride_w = strides[-1]
+    pad_front = pad_begin[0] if spatial_rank == 3 else 0
+    pad_top = pad_begin[-2] if spatial_rank >= 2 else 0
+    pad_left = pad_begin[-1]
+    pad_back = pad_end[0] if spatial_rank == 3 else 0
+    pad_bottom = pad_end[-2] if spatial_rank >= 2 else 0
+    pad_right = pad_end[-1]
     return _AveragePoolSpec(
         batch=batch,
         channels=channels,
+        spatial_rank=spatial_rank,
+        in_d=in_d,
         in_h=in_h,
         in_w=in_w,
+        out_d=out_d,
         out_h=out_h,
         out_w=out_w,
+        kernel_d=kernel_d,
         kernel_h=kernel_h,
         kernel_w=kernel_w,
+        dilation_d=dilation_d,
+        dilation_h=dilation_h,
+        dilation_w=dilation_w,
+        stride_d=stride_d,
         stride_h=stride_h,
         stride_w=stride_w,
+        pad_front=pad_front,
         pad_top=pad_top,
         pad_left=pad_left,
+        pad_back=pad_back,
         pad_bottom=pad_bottom,
         pad_right=pad_right,
         count_include_pad=bool(count_include_pad),
@@ -128,29 +219,48 @@ def _resolve_global_average_pool_spec(graph: Graph, node: Node) -> _AveragePoolSpec:
     if node.attrs:
         raise UnsupportedOpError("GlobalAveragePool has unsupported attributes")
     input_shape = _value_shape(graph, node.inputs[0], node)
-    if len(input_shape) != 4:
-        raise UnsupportedOpError("GlobalAveragePool supports NCHW 2D inputs only")
-    batch, channels, in_h, in_w = input_shape
+    if len(input_shape) < 3:
+        raise UnsupportedOpError(
+            "GlobalAveragePool expects NCHW inputs with spatial dims"
+        )
+    spatial_rank = len(input_shape) - 2
+    if spatial_rank not in {1, 2, 3}:
+        raise UnsupportedOpError("GlobalAveragePool supports 1D/2D/3D inputs only")
+    batch, channels = input_shape[:2]
+    in_spatial = input_shape[2:]
     output_shape = _value_shape(graph, node.outputs[0], node)
-    expected_output_shape = (batch, channels, 1, 1)
+    expected_output_shape = (batch, channels, *([1] * spatial_rank))
    if output_shape != expected_output_shape:
         raise ShapeInferenceError(
             "GlobalAveragePool output shape must be "
             f"{expected_output_shape}, got {output_shape}"
         )
+    in_d = in_spatial[0] if spatial_rank == 3 else 1
+    in_h = in_spatial[-2] if spatial_rank >= 2 else 1
+    in_w = in_spatial[-1]
     return _AveragePoolSpec(
         batch=batch,
         channels=channels,
+        spatial_rank=spatial_rank,
+        in_d=in_d,
         in_h=in_h,
         in_w=in_w,
+        out_d=1,
         out_h=1,
         out_w=1,
+        kernel_d=in_d,
         kernel_h=in_h,
         kernel_w=in_w,
+        dilation_d=1,
+        dilation_h=1,
+        dilation_w=1,
+        stride_d=1,
         stride_h=1,
         stride_w=1,
+        pad_front=0,
         pad_top=0,
         pad_left=0,
+        pad_back=0,
         pad_bottom=0,
         pad_right=0,
         count_include_pad=False,
@@ -176,16 +286,26 @@ def lower_average_pool(graph: Graph, node: Node) -> AveragePoolOp:
         output=node.outputs[0],
         batch=spec.batch,
         channels=spec.channels,
+        spatial_rank=spec.spatial_rank,
+        in_d=spec.in_d,
         in_h=spec.in_h,
         in_w=spec.in_w,
+        out_d=spec.out_d,
         out_h=spec.out_h,
         out_w=spec.out_w,
+        kernel_d=spec.kernel_d,
         kernel_h=spec.kernel_h,
         kernel_w=spec.kernel_w,
+        dilation_d=spec.dilation_d,
+        dilation_h=spec.dilation_h,
+        dilation_w=spec.dilation_w,
+        stride_d=spec.stride_d,
         stride_h=spec.stride_h,
         stride_w=spec.stride_w,
+        pad_front=spec.pad_front,
         pad_top=spec.pad_top,
         pad_left=spec.pad_left,
+        pad_back=spec.pad_back,
         pad_bottom=spec.pad_bottom,
         pad_right=spec.pad_right,
         count_include_pad=spec.count_include_pad,
@@ -212,16 +332,26 @@ def lower_global_average_pool(graph: Graph, node: Node) -> AveragePoolOp:
         output=node.outputs[0],
         batch=spec.batch,
         channels=spec.channels,
+        spatial_rank=spec.spatial_rank,
+        in_d=spec.in_d,
         in_h=spec.in_h,
         in_w=spec.in_w,
+        out_d=spec.out_d,
         out_h=spec.out_h,
         out_w=spec.out_w,
+        kernel_d=spec.kernel_d,
         kernel_h=spec.kernel_h,
         kernel_w=spec.kernel_w,
+        dilation_d=spec.dilation_d,
+        dilation_h=spec.dilation_h,
+        dilation_w=spec.dilation_w,
+        stride_d=spec.stride_d,
         stride_h=spec.stride_h,
         stride_w=spec.stride_w,
+        pad_front=spec.pad_front,
         pad_top=spec.pad_top,
         pad_left=spec.pad_left,
+        pad_back=spec.pad_back,
         pad_bottom=spec.pad_bottom,
         pad_right=spec.pad_right,
         count_include_pad=spec.count_include_pad,
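
Note: the new spec resolution replaces the fixed 2D NCHW path with per-axis arithmetic over 1D/2D/3D pooling. The following is a standalone restatement of that per-axis output-size rule; the helper name pool_out_dim is illustrative and not part of the package.

def pool_out_dim(dim, kernel, stride, dilation, pad_begin, pad_end, ceil_mode):
    # Effective kernel extent once dilation is applied.
    effective_kernel = dilation * (kernel - 1) + 1
    numerator = dim + pad_begin + pad_end - effective_kernel
    if ceil_mode:
        out_dim = (numerator + stride - 1) // stride + 1
        # Drop a window that would start entirely inside the end padding.
        if (out_dim - 1) * stride >= dim + pad_begin:
            out_dim -= 1
        return out_dim
    return numerator // stride + 1

# Length-6 axis, kernel 3, stride 2, no padding:
assert pool_out_dim(6, 3, 2, 1, 0, 0, ceil_mode=False) == 2
assert pool_out_dim(6, 3, 2, 1, 0, 0, ceil_mode=True) == 3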
emx_onnx_cgen/lowering/bernoulli.py (new file)

@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..dtypes import dtype_info
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..ir.ops import BernoulliOp
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+_SUPPORTED_INPUT_DTYPES = {ScalarType.F16, ScalarType.F32, ScalarType.F64}
+_SUPPORTED_OUTPUT_DTYPES = {
+    ScalarType.U8,
+    ScalarType.U16,
+    ScalarType.U32,
+    ScalarType.U64,
+    ScalarType.I8,
+    ScalarType.I16,
+    ScalarType.I32,
+    ScalarType.I64,
+    ScalarType.F16,
+    ScalarType.F32,
+    ScalarType.F64,
+    ScalarType.BOOL,
+}
+
+
+@register_lowering("Bernoulli")
+def lower_bernoulli(graph: Graph, node: Node) -> BernoulliOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Bernoulli must have 1 input and 1 output")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    if input_shape != output_shape:
+        raise ShapeInferenceError(
+            "Bernoulli output shape must match input shape, "
+            f"got {output_shape} for input {input_shape}"
+        )
+    input_dtype = _value_dtype(graph, node.inputs[0], node)
+    if input_dtype not in _SUPPORTED_INPUT_DTYPES:
+        raise UnsupportedOpError(
+            "Bernoulli input dtype must be float, "
+            f"got {input_dtype.onnx_name}"
+        )
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    dtype_attr = node.attrs.get("dtype")
+    if dtype_attr is not None:
+        attr_dtype = dtype_info(int(dtype_attr))
+        if attr_dtype != output_dtype:
+            raise UnsupportedOpError(
+                "Bernoulli dtype attribute does not match output dtype"
+            )
+    if output_dtype not in _SUPPORTED_OUTPUT_DTYPES:
+        raise UnsupportedOpError(
+            "Bernoulli output dtype must be numeric or bool, "
+            f"got {output_dtype.onnx_name}"
+        )
+    seed_value = node.attrs.get("seed")
+    seed = None
+    if seed_value is not None:
+        seed = int(seed_value)
+    return BernoulliOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        input_dtype=input_dtype,
+        dtype=output_dtype,
+        seed=seed,
+    )
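
For orientation: ONNX Bernoulli treats each input element as a probability and draws the corresponding output element from a Bernoulli distribution, cast to the requested dtype, with the optional seed fixing the random stream. A minimal NumPy sketch of those semantics (reference only, not the C emitted from bernoulli_op.c.j2):

import numpy as np

def bernoulli_reference(probs, out_dtype, seed=None):
    rng = np.random.default_rng(seed)
    # Each output element is 1 with probability probs[i], else 0.
    return (rng.random(probs.shape) < probs).astype(out_dtype)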
emx_onnx_cgen/lowering/common.py

@@ -50,6 +50,8 @@ def value_shape(
     if isinstance(graph, GraphContext):
         shape = graph.shape(name, node)
         value = graph.find_value(name)
+        if graph.has_shape(name):
+            return shape
     else:
         try:
             value = graph.find_value(name)
@@ -219,6 +221,37 @@ def _shape_values_from_input(
         return [int(l / r) if r != 0 else 0 for l, r in zip(left, right)]
     if source_node.op_type == "Mod":
         return [l % r if r != 0 else 0 for l, r in zip(left, right)]
+    if source_node.op_type in {"Add", "Sub", "Mul"}:
+        if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+            raise UnsupportedOpError(
+                f"{source_node.op_type} must have 2 inputs and 1 output"
+            )
+        left = _shape_values_from_input(
+            graph,
+            source_node.inputs[0],
+            node,
+            _visited=_visited,
+        )
+        right = _shape_values_from_input(
+            graph,
+            source_node.inputs[1],
+            node,
+            _visited=_visited,
+        )
+        if left is None or right is None:
+            return None
+        if len(left) == 1 and len(right) != 1:
+            left = left * len(right)
+        if len(right) == 1 and len(left) != 1:
+            right = right * len(left)
+        if len(left) != len(right):
+            return None
+        if source_node.op_type == "Add":
+            return [l + r for l, r in zip(left, right)]
+        if source_node.op_type == "Sub":
+            return [l - r for l, r in zip(left, right)]
+        if source_node.op_type == "Mul":
+            return [l * r for l, r in zip(left, right)]
     if source_node.op_type == "Not":
         if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
             raise UnsupportedOpError("Not must have 1 input and 1 output")
@@ -465,3 +498,18 @@ def optional_name(names: Sequence[str], index: int) -> str | None:
         return None
     name = names[index]
     return name or None
+
+
+def resolve_int_list_from_value(
+    graph: Graph | GraphContext,
+    name: str,
+    node: Node | None = None,
+) -> list[int] | None:
+    return _shape_values_from_input(graph, name, node)
+
+
+def value_has_dim_params(
+    graph: Graph | GraphContext,
+    name: str,
+) -> bool:
+    return any(graph.find_value(name).type.dim_params)
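
The new Add/Sub/Mul branch lets constant shape values flow through simple integer arithmetic; a length-1 operand is broadcast against the other side before the elementwise op. Illustratively:

left, right = [2], [4, 8, 16]
if len(left) == 1 and len(right) != 1:
    left = left * len(right)          # broadcast scalar -> [2, 2, 2]
mul = [l * r for l, r in zip(left, right)]
assert mul == [8, 16, 32]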
emx_onnx_cgen/lowering/concat.py

@@ -1,12 +1,14 @@
 from __future__ import annotations
 
-from ..ir.ops import ConcatOp
 from ..errors import UnsupportedOpError
+from ..ir.context import GraphContext
 from ..ir.model import Graph, Node
+from ..ir.ops import ConcatOp
 from .common import node_dtype as _node_dtype
+from .common import value_has_dim_params as _value_has_dim_params
 from .common import value_shape as _value_shape
 from .registry import register_lowering
-from ..validation import validate_concat_shapes
+from ..validation import normalize_concat_axis, validate_concat_shapes
 
 
 @register_lowering("Concat")
@@ -15,12 +17,44 @@ def lower_concat(graph: Graph, node: Node) -> ConcatOp:
         raise UnsupportedOpError("Concat must have at least 1 input and 1 output")
     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
     output_shape = _value_shape(graph, node.outputs[0], node)
+    if _value_has_dim_params(graph, node.outputs[0]):
+        output_shape = ()
     input_shapes = tuple(_value_shape(graph, name, node) for name in node.inputs)
-    axis = validate_concat_shapes(
-        input_shapes,
-        output_shape,
-        int(node.attrs.get("axis", 0)),
-    )
+    axis = int(node.attrs.get("axis", 0))
+    if output_shape:
+        axis = validate_concat_shapes(
+            input_shapes,
+            output_shape,
+            axis,
+        )
+    else:
+        ranks = {len(shape) for shape in input_shapes}
+        if len(ranks) != 1:
+            raise UnsupportedOpError(
+                f"Concat inputs must have matching ranks, got {input_shapes}"
+            )
+        rank = ranks.pop()
+        axis = normalize_concat_axis(axis, rank)
+        base_shape = list(input_shapes[0])
+        axis_dim = 0
+        for shape in input_shapes:
+            if len(shape) != rank:
+                raise UnsupportedOpError(
+                    f"Concat inputs must have matching ranks, got {input_shapes}"
+                )
+            for dim_index, dim in enumerate(shape):
+                if dim_index == axis:
+                    continue
+                if dim != base_shape[dim_index]:
+                    raise UnsupportedOpError(
+                        "Concat inputs must match on non-axis dimensions, "
+                        f"got {input_shapes}"
+                    )
+            axis_dim += shape[axis]
+        base_shape[axis] = axis_dim
+        output_shape = tuple(base_shape)
+    if isinstance(graph, GraphContext):
+        graph.set_shape(node.outputs[0], output_shape)
     return ConcatOp(
         inputs=node.inputs,
         output=node.outputs[0],
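
When the declared output shape carries symbolic dim_params, the lowering now recomputes the concrete shape from the inputs instead: non-axis dimensions must agree, and the axis dimension is the sum over inputs. A small worked example of that fallback:

input_shapes = [(2, 3, 4), (2, 5, 4)]
axis = 1
base_shape = list(input_shapes[0])
base_shape[axis] = sum(shape[axis] for shape in input_shapes)
assert tuple(base_shape) == (2, 8, 4)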
emx_onnx_cgen/lowering/conv.py

@@ -26,9 +26,14 @@ class ConvSpec:
     group: int
 
 
-def resolve_conv_spec(graph: Graph, node: Node) -> ConvSpec:
-    if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
-        raise UnsupportedOpError("Conv must have 2 or 3 inputs and 1 output")
+def resolve_conv_spec(
+    graph: Graph,
+    node: Node,
+    *,
+    input_name: str,
+    weight_name: str,
+    bias_name: str | None,
+) -> ConvSpec:
     supported_attrs = {
         "auto_pad",
         "dilations",
@@ -39,8 +44,8 @@ def resolve_conv_spec(graph: Graph, node: Node) -> ConvSpec:
     }
     if set(node.attrs) - supported_attrs:
         raise UnsupportedOpError("Conv has unsupported attributes")
-    input_shape = _value_shape(graph, node.inputs[0], node)
-    weight_shape = _value_shape(graph, node.inputs[1], node)
+    input_shape = _value_shape(graph, input_name, node)
+    weight_shape = _value_shape(graph, weight_name, node)
     if len(input_shape) < 3:
         raise UnsupportedOpError("Conv expects NCHW inputs with spatial dims")
     spatial_rank = len(input_shape) - 2
@@ -79,8 +84,8 @@ def resolve_conv_spec(graph: Graph, node: Node) -> ConvSpec:
             "Conv input channels must match weight channels, "
             f"got {in_channels} and {weight_in_channels * group}"
         )
-    if len(node.inputs) == 3:
-        bias_shape = _value_shape(graph, node.inputs[2], node)
+    if bias_name is not None:
+        bias_shape = _value_shape(graph, bias_name, node)
         if bias_shape != (out_channels,):
             raise ShapeInferenceError(
                 f"Conv bias shape must be {(out_channels,)}, got {bias_shape}"
@@ -171,7 +176,13 @@ def lower_conv(graph: Graph, node: Node) -> ConvOp:
         raise UnsupportedOpError(
             "Conv supports float16, float, and double inputs only"
         )
-    spec = resolve_conv_spec(graph, node)
+    spec = resolve_conv_spec(
+        graph,
+        node,
+        input_name=node.inputs[0],
+        weight_name=node.inputs[1],
+        bias_name=node.inputs[2] if len(node.inputs) == 3 else None,
+    )
     return ConvOp(
         input0=node.inputs[0],
         weights=node.inputs[1],
emx_onnx_cgen/lowering/conv_integer.py (new file)

@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from ..ir.ops import ConvIntegerOp
+from .common import optional_name, value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .conv import resolve_conv_spec
+from .registry import register_lowering
+
+
+def _ensure_scalar_shape(shape: tuple[int, ...], label: str) -> None:
+    if shape not in {(), (1,)}:
+        raise UnsupportedOpError(
+            f"ConvInteger {label} must be a scalar, got shape {shape}"
+        )
+
+
+def _resolve_w_zero_point_shape(
+    shape: tuple[int, ...], out_channels: int
+) -> bool:
+    if shape in {(), (1,)}:
+        return False
+    if shape == (out_channels,):
+        return True
+    raise UnsupportedOpError(
+        "ConvInteger w_zero_point must be scalar or 1D per output channel, "
+        f"got shape {shape}"
+    )
+
+
+@register_lowering("ConvInteger")
+def lower_conv_integer(graph: Graph, node: Node) -> ConvIntegerOp:
+    if len(node.inputs) not in {2, 3, 4} or len(node.outputs) != 1:
+        raise UnsupportedOpError(
+            "ConvInteger must have 2 to 4 inputs and 1 output"
+        )
+    input_name = node.inputs[0]
+    weight_name = node.inputs[1]
+    x_zero_point_name = optional_name(node.inputs, 2)
+    w_zero_point_name = optional_name(node.inputs, 3)
+    input_dtype = _value_dtype(graph, input_name, node)
+    weight_dtype = _value_dtype(graph, weight_name, node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if input_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError("ConvInteger supports uint8/int8 inputs only")
+    if weight_dtype not in {ScalarType.U8, ScalarType.I8}:
+        raise UnsupportedOpError("ConvInteger supports uint8/int8 weights only")
+    if output_dtype != ScalarType.I32:
+        raise UnsupportedOpError("ConvInteger expects int32 outputs only")
+    x_zero_shape = None
+    if x_zero_point_name is not None:
+        x_zero_shape = _value_shape(graph, x_zero_point_name, node)
+        _ensure_scalar_shape(x_zero_shape, "x_zero_point")
+        if _value_dtype(graph, x_zero_point_name, node) != input_dtype:
+            raise UnsupportedOpError(
+                "ConvInteger x_zero_point dtype must match input dtype"
+            )
+    w_zero_shape = None
+    w_zero_point_per_channel = False
+    if w_zero_point_name is not None:
+        w_zero_shape = _value_shape(graph, w_zero_point_name, node)
+        if _value_dtype(graph, w_zero_point_name, node) != weight_dtype:
+            raise UnsupportedOpError(
+                "ConvInteger w_zero_point dtype must match weight dtype"
+            )
+    spec = resolve_conv_spec(
+        graph,
+        node,
+        input_name=input_name,
+        weight_name=weight_name,
+        bias_name=None,
+    )
+    if w_zero_shape is not None:
+        w_zero_point_per_channel = _resolve_w_zero_point_shape(
+            w_zero_shape, spec.out_channels
+        )
+    return ConvIntegerOp(
+        input0=input_name,
+        weights=weight_name,
+        x_zero_point=x_zero_point_name,
+        w_zero_point=w_zero_point_name,
+        output=node.outputs[0],
+        batch=spec.batch,
+        in_channels=spec.in_channels,
+        out_channels=spec.out_channels,
+        spatial_rank=spec.spatial_rank,
+        in_spatial=spec.in_spatial,
+        out_spatial=spec.out_spatial,
+        kernel_shape=spec.kernel_shape,
+        strides=spec.strides,
+        pads=spec.pads,
+        dilations=spec.dilations,
+        group=spec.group,
+        input_dtype=input_dtype,
+        weight_dtype=weight_dtype,
+        dtype=output_dtype,
+        x_zero_point_shape=x_zero_shape,
+        w_zero_point_shape=w_zero_shape,
+        w_zero_point_per_channel=w_zero_point_per_channel,
+    )
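
This lowering only validates dtypes, zero-point shapes, and the conv geometry; the arithmetic itself is emitted from templates/conv_integer_op.c.j2. Per the ONNX definition, ConvInteger subtracts the zero points before the multiply-accumulate and keeps the accumulator in int32, roughly:

# One multiply-accumulate term of the int32 accumulator (illustrative values).
x, x_zero_point = 200, 128      # uint8 activation and its zero point
w, w_zero_point = 3, 0          # int8 weight and its zero point
acc_term = (int(x) - x_zero_point) * (int(w) - w_zero_point)
assert acc_term == 216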