emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
@@ -41,21 +41,19 @@ def _unsupported_value_type(value_info: onnx.ValueInfoProto) -> UnsupportedOpErr
41
41
  )
42
42
 
43
43
 
44
- def _tensor_type(
45
- value_info: onnx.ValueInfoProto,
44
+ def _tensor_type_from_proto(
45
+ tensor_type: onnx.TypeProto.Tensor,
46
+ name: str,
46
47
  *,
47
48
  dim_param_override: tuple[str | None, ...] | None = None,
48
49
  ) -> TensorType:
49
- if value_info.type.WhichOneof("value") != "tensor_type":
50
- raise _unsupported_value_type(value_info)
51
- tensor_type = value_info.type.tensor_type
52
50
  if not tensor_type.HasField("elem_type"):
53
- raise ShapeInferenceError(f"Missing elem_type for tensor '{value_info.name}'")
51
+ raise ShapeInferenceError(f"Missing elem_type for tensor '{name}'")
54
52
  dtype = scalar_type_from_onnx(tensor_type.elem_type)
55
53
  if dtype is None:
56
54
  raise UnsupportedOpError(
57
55
  "Unsupported elem_type "
58
- f"{_format_elem_type(tensor_type.elem_type)} for tensor '{value_info.name}'."
56
+ f"{_format_elem_type(tensor_type.elem_type)} for tensor '{name}'."
59
57
  )
60
58
  shape = []
61
59
  dim_params = []
@@ -72,7 +70,7 @@ def _tensor_type(
72
70
  if dim_param:
73
71
  shape.append(1)
74
72
  continue
75
- raise ShapeInferenceError(f"Dynamic dim for tensor '{value_info.name}'")
73
+ raise ShapeInferenceError(f"Dynamic dim for tensor '{name}'")
76
74
  shape.append(dim.dim_value)
77
75
  return TensorType(
78
76
  dtype=dtype,
@@ -81,6 +79,40 @@ def _tensor_type(
81
79
  )
82
80
 
83
81
 
82
+ def _value_type(
83
+ value_info: onnx.ValueInfoProto,
84
+ *,
85
+ dim_param_override: tuple[str | None, ...] | None = None,
86
+ ) -> TensorType:
87
+ value_kind = value_info.type.WhichOneof("value")
88
+ if value_kind == "tensor_type":
89
+ return _tensor_type_from_proto(
90
+ value_info.type.tensor_type,
91
+ value_info.name,
92
+ dim_param_override=dim_param_override,
93
+ )
94
+ if value_kind == "optional_type":
95
+ elem_type = value_info.type.optional_type.elem_type
96
+ elem_kind = elem_type.WhichOneof("value")
97
+ if elem_kind != "tensor_type":
98
+ raise UnsupportedOpError(
99
+ f"Unsupported optional element type '{elem_kind}' for '{value_info.name}'. "
100
+ "Hint: export the model with optional tensor inputs/outputs."
101
+ )
102
+ tensor_type = _tensor_type_from_proto(
103
+ elem_type.tensor_type,
104
+ value_info.name,
105
+ dim_param_override=dim_param_override,
106
+ )
107
+ return TensorType(
108
+ dtype=tensor_type.dtype,
109
+ shape=tensor_type.shape,
110
+ dim_params=tensor_type.dim_params,
111
+ is_optional=True,
112
+ )
113
+ raise _unsupported_value_type(value_info)
114
+
115
+
84
116
  def _values(
85
117
  value_infos: Iterable[onnx.ValueInfoProto],
86
118
  *,
@@ -90,7 +122,7 @@ def _values(
90
122
  return tuple(
91
123
  Value(
92
124
  name=vi.name,
93
- type=_tensor_type(
125
+ type=_value_type(
94
126
  vi, dim_param_override=dim_param_by_name.get(vi.name)
95
127
  ),
96
128
  )
@@ -103,8 +135,18 @@ def _collect_dim_params(
103
135
  ) -> dict[str, tuple[str | None, ...]]:
104
136
  dim_params: dict[str, tuple[str | None, ...]] = {}
105
137
  for value_info in value_infos:
138
+ value_kind = value_info.type.WhichOneof("value")
139
+ if value_kind == "tensor_type":
140
+ tensor_type = value_info.type.tensor_type
141
+ elif value_kind == "optional_type":
142
+ elem_type = value_info.type.optional_type.elem_type
143
+ if elem_type.WhichOneof("value") != "tensor_type":
144
+ continue
145
+ tensor_type = elem_type.tensor_type
146
+ else:
147
+ continue
106
148
  dims = []
107
- for dim in value_info.type.tensor_type.shape.dim:
149
+ for dim in tensor_type.shape.dim:
108
150
  dim_param = dim.dim_param if dim.HasField("dim_param") else ""
109
151
  dims.append(dim_param or None)
110
152
  if any(dims):
@@ -112,6 +154,61 @@ def _collect_dim_params(
112
154
  return dim_params
113
155
 
114
156
 
157
+ def _value_info_complete(value_info: onnx.ValueInfoProto) -> bool:
158
+ value_kind = value_info.type.WhichOneof("value")
159
+ if value_kind == "tensor_type":
160
+ tensor_type = value_info.type.tensor_type
161
+ elif value_kind == "optional_type":
162
+ elem_type = value_info.type.optional_type.elem_type
163
+ if elem_type.WhichOneof("value") != "tensor_type":
164
+ return False
165
+ tensor_type = elem_type.tensor_type
166
+ else:
167
+ return False
168
+ if not tensor_type.HasField("elem_type"):
169
+ return False
170
+ if not tensor_type.HasField("shape"):
171
+ return False
172
+ for dim in tensor_type.shape.dim:
173
+ if dim.HasField("dim_value"):
174
+ continue
175
+ if dim.HasField("dim_param"):
176
+ continue
177
+ return False
178
+ return True
179
+
180
+
181
+ def _needs_shape_inference(model: onnx.ModelProto) -> bool:
182
+ graph = model.graph
183
+ value_info_by_name = {
184
+ value_info.name: value_info for value_info in graph.value_info
185
+ }
186
+ output_names = {value_info.name for value_info in graph.output}
187
+ initializer_names = {initializer.name for initializer in graph.initializer}
188
+ initializer_names.update(
189
+ sparse_init.name for sparse_init in graph.sparse_initializer
190
+ )
191
+ for node in graph.node:
192
+ for output in node.output:
193
+ if not output:
194
+ continue
195
+ if output in output_names or output in value_info_by_name:
196
+ continue
197
+ return True
198
+ for value_info in graph.value_info:
199
+ if not _value_info_complete(value_info):
200
+ return True
201
+ for value_info in graph.output:
202
+ if not _value_info_complete(value_info):
203
+ return True
204
+ for value_info in graph.input:
205
+ if value_info.name in initializer_names:
206
+ continue
207
+ if not _value_info_complete(value_info):
208
+ return True
209
+ return False
210
+
211
+
115
212
  def _initializer(value: onnx.TensorProto) -> Initializer:
116
213
  dtype = scalar_type_from_onnx(value.data_type)
117
214
  if dtype is None:
@@ -136,6 +233,471 @@ def _node_attrs(node: onnx.NodeProto) -> dict[str, object]:
136
233
  return {attr.name: helper.get_attribute_value(attr) for attr in node.attribute}
137
234
 
138
235
 
236
+ def _find_value_info(
237
+ graph: onnx.GraphProto, name: str
238
+ ) -> onnx.ValueInfoProto | None:
239
+ for value_info in graph.input:
240
+ if value_info.name == name:
241
+ return value_info
242
+ for value_info in graph.value_info:
243
+ if value_info.name == name:
244
+ return value_info
245
+ for value_info in graph.output:
246
+ if value_info.name == name:
247
+ return value_info
248
+ return None
249
+
250
+
251
+ def _tensor_shape_from_value_info(
252
+ graph: onnx.GraphProto, name: str
253
+ ) -> tuple[int, ...]:
254
+ value_info = _find_value_info(graph, name)
255
+ if value_info is None:
256
+ for initializer in graph.initializer:
257
+ if initializer.name == name:
258
+ return tuple(int(dim) for dim in initializer.dims)
259
+ raise ShapeInferenceError(
260
+ f"Missing shape for '{name}' in Scan expansion. "
261
+ "Hint: run ONNX shape inference or export with static shapes."
262
+ )
263
+ tensor_type = value_info.type.tensor_type
264
+ if not tensor_type.HasField("shape"):
265
+ raise ShapeInferenceError(
266
+ f"Missing shape for '{name}' in Scan expansion. "
267
+ "Hint: run ONNX shape inference or export with static shapes."
268
+ )
269
+ dims: list[int] = []
270
+ for dim in tensor_type.shape.dim:
271
+ if not dim.HasField("dim_value"):
272
+ raise ShapeInferenceError(
273
+ f"Dynamic dim for '{name}' in Scan expansion. "
274
+ "Hint: export with static shapes."
275
+ )
276
+ dims.append(int(dim.dim_value))
277
+ return tuple(dims)
278
+
279
+
280
+ def _scan_attr_ints(
281
+ attrs: dict[str, object],
282
+ key: str,
283
+ *,
284
+ default: tuple[int, ...],
285
+ ) -> tuple[int, ...]:
286
+ value = attrs.get(key)
287
+ if value is None:
288
+ return default
289
+ return tuple(int(item) for item in value)
290
+
291
+
292
+ def _onnx_opset_version(model: onnx.ModelProto) -> int | None:
293
+ for opset in model.opset_import:
294
+ if opset.domain in {"", "ai.onnx"}:
295
+ return int(opset.version)
296
+ return None
297
+
298
+
299
+ def _scan_expected_axis(is_opset8: bool) -> int:
300
+ return 1 if is_opset8 else 0
301
+
302
+
303
+ def _scan_axes_and_directions(
304
+ attrs: dict[str, object],
305
+ *,
306
+ num_scan_inputs: int,
307
+ scan_output_count: int,
308
+ is_opset8: bool,
309
+ ) -> None:
310
+ default_axis = _scan_expected_axis(is_opset8)
311
+ scan_input_axes = _scan_attr_ints(
312
+ attrs,
313
+ "scan_input_axes",
314
+ default=(default_axis,) * num_scan_inputs,
315
+ )
316
+ scan_output_axes = _scan_attr_ints(
317
+ attrs,
318
+ "scan_output_axes",
319
+ default=(default_axis,) * scan_output_count,
320
+ )
321
+ scan_input_directions = _scan_attr_ints(
322
+ attrs,
323
+ "scan_input_directions",
324
+ default=(0,) * num_scan_inputs,
325
+ )
326
+ scan_output_directions = _scan_attr_ints(
327
+ attrs,
328
+ "scan_output_directions",
329
+ default=(0,) * scan_output_count,
330
+ )
331
+ if any(axis != default_axis for axis in scan_input_axes):
332
+ raise UnsupportedOpError(
333
+ f"Scan only supports scan_input_axes={default_axis}"
334
+ )
335
+ if any(axis != default_axis for axis in scan_output_axes):
336
+ raise UnsupportedOpError(
337
+ f"Scan only supports scan_output_axes={default_axis}"
338
+ )
339
+ if any(direction != 0 for direction in scan_input_directions):
340
+ raise UnsupportedOpError(
341
+ "Scan only supports scan_input_directions=0"
342
+ )
343
+ if any(direction != 0 for direction in scan_output_directions):
344
+ raise UnsupportedOpError(
345
+ "Scan only supports scan_output_directions=0"
346
+ )
347
+
348
+
349
+ def _scan_sequence_length(
350
+ graph: onnx.GraphProto,
351
+ scan_input_names: list[str],
352
+ *,
353
+ is_opset8: bool,
354
+ ) -> tuple[int, int | None]:
355
+ scan_input_shapes = [
356
+ _tensor_shape_from_value_info(graph, name)
357
+ for name in scan_input_names
358
+ ]
359
+ if not scan_input_shapes:
360
+ raise UnsupportedOpError("Scan requires scan inputs")
361
+ if is_opset8:
362
+ if any(len(shape) < 2 for shape in scan_input_shapes):
363
+ raise UnsupportedOpError(
364
+ "Scan opset 8 inputs must include batch and sequence dims"
365
+ )
366
+ batch_size = scan_input_shapes[0][0]
367
+ sequence_len = scan_input_shapes[0][1]
368
+ if batch_size != 1:
369
+ raise UnsupportedOpError(
370
+ "Scan opset 8 currently supports batch size 1 only"
371
+ )
372
+ if sequence_len <= 0:
373
+ raise UnsupportedOpError("Scan requires positive sequence length")
374
+ if any(
375
+ shape[0] != batch_size or shape[1] != sequence_len
376
+ for shape in scan_input_shapes
377
+ ):
378
+ raise UnsupportedOpError(
379
+ "Scan inputs must share the same batch and sequence length"
380
+ )
381
+ return sequence_len, batch_size
382
+ sequence_len = scan_input_shapes[0][0]
383
+ if sequence_len <= 0:
384
+ raise UnsupportedOpError("Scan requires positive sequence length")
385
+ if any(shape[0] != sequence_len for shape in scan_input_shapes):
386
+ raise UnsupportedOpError(
387
+ "Scan inputs must share the same sequence length"
388
+ )
389
+ return sequence_len, None
390
+
391
+
392
+ def _scan_body_initializers(
393
+ body: onnx.GraphProto,
394
+ *,
395
+ prefix: str,
396
+ new_initializers: list[onnx.TensorProto],
397
+ ) -> dict[str, str]:
398
+ initializer_map: dict[str, str] = {}
399
+ for initializer in body.initializer:
400
+ new_name = f"{prefix}_init_{initializer.name}"
401
+ initializer_map[initializer.name] = new_name
402
+ array = numpy_helper.to_array(initializer)
403
+ new_initializers.append(numpy_helper.from_array(array, name=new_name))
404
+ return initializer_map
405
+
406
+
407
+ def _scan_state_inputs(
408
+ graph: onnx.GraphProto,
409
+ *,
410
+ prefix: str,
411
+ state_input_names: list[str],
412
+ new_nodes: list[onnx.NodeProto],
413
+ is_opset8: bool,
414
+ batch_size: int | None,
415
+ ) -> list[str]:
416
+ state_names = list(state_input_names)
417
+ if is_opset8 and state_input_names:
418
+ for state_index, state_name in enumerate(state_input_names):
419
+ state_shape = _tensor_shape_from_value_info(graph, state_name)
420
+ if not state_shape:
421
+ raise UnsupportedOpError(
422
+ "Scan opset 8 state inputs must be tensors"
423
+ )
424
+ if batch_size is not None and state_shape[0] != batch_size:
425
+ raise UnsupportedOpError(
426
+ "Scan opset 8 state inputs must match batch size"
427
+ )
428
+ squeezed_name = f"{prefix}_state{state_index}_squeezed"
429
+ new_nodes.append(
430
+ helper.make_node(
431
+ "Squeeze",
432
+ inputs=[state_name],
433
+ outputs=[squeezed_name],
434
+ name=f"{squeezed_name}_node",
435
+ axes=[0],
436
+ )
437
+ )
438
+ state_names[state_index] = squeezed_name
439
+ return state_names
440
+
441
+
442
+ def _scan_iteration_inputs(
443
+ *,
444
+ prefix: str,
445
+ iter_index: int,
446
+ scan_input_names: list[str],
447
+ new_nodes: list[onnx.NodeProto],
448
+ is_opset8: bool,
449
+ ) -> list[str]:
450
+ scan_iter_inputs: list[str] = []
451
+ slice_axis = _scan_expected_axis(is_opset8)
452
+ squeeze_axes = [0, 1] if is_opset8 else [0]
453
+ for scan_index, scan_name in enumerate(scan_input_names):
454
+ slice_out = f"{prefix}_iter{iter_index}_scan{scan_index}_slice"
455
+ squeeze_out = f"{prefix}_iter{iter_index}_scan{scan_index}_value"
456
+ new_nodes.append(
457
+ helper.make_node(
458
+ "Slice",
459
+ inputs=[scan_name],
460
+ outputs=[slice_out],
461
+ name=f"{slice_out}_node",
462
+ starts=[iter_index],
463
+ ends=[iter_index + 1],
464
+ axes=[slice_axis],
465
+ )
466
+ )
467
+ new_nodes.append(
468
+ helper.make_node(
469
+ "Squeeze",
470
+ inputs=[slice_out],
471
+ outputs=[squeeze_out],
472
+ name=f"{squeeze_out}_node",
473
+ axes=squeeze_axes,
474
+ )
475
+ )
476
+ scan_iter_inputs.append(squeeze_out)
477
+ return scan_iter_inputs
478
+
479
+
480
+ def _expand_scan_nodes(model: onnx.ModelProto) -> tuple[onnx.ModelProto, bool]:
481
+ graph = model.graph
482
+ opset_version = _onnx_opset_version(model)
483
+ if opset_version is None:
484
+ return model, False
485
+
486
+ new_nodes: list[onnx.NodeProto] = []
487
+ new_initializers: list[onnx.TensorProto] = []
488
+ scan_index = 0
489
+ expanded = False
490
+ is_opset8 = opset_version <= 8
491
+
492
+ for node in graph.node:
493
+ if node.op_type != "Scan":
494
+ new_nodes.append(node)
495
+ continue
496
+
497
+ expanded = True
498
+ scan_index += 1
499
+ attrs = _node_attrs(node)
500
+ body = attrs.get("body")
501
+ if not isinstance(body, onnx.GraphProto):
502
+ raise UnsupportedOpError("Scan requires a body graph")
503
+ num_scan_inputs = int(attrs.get("num_scan_inputs", 0))
504
+ if num_scan_inputs <= 0:
505
+ raise UnsupportedOpError("Scan requires num_scan_inputs")
506
+ input_names = list(node.input)
507
+ if is_opset8:
508
+ if not input_names:
509
+ raise UnsupportedOpError("Scan in opset 8 requires inputs")
510
+ sequence_lens = input_names.pop(0)
511
+ if sequence_lens:
512
+ raise UnsupportedOpError(
513
+ "Scan sequence_lens input is not supported"
514
+ )
515
+ num_state_inputs = len(input_names) - num_scan_inputs
516
+ if num_state_inputs < 0:
517
+ raise UnsupportedOpError("Scan input count is invalid")
518
+ if len(body.input) != num_state_inputs + num_scan_inputs:
519
+ raise UnsupportedOpError(
520
+ "Scan body input count must match state and scan inputs"
521
+ )
522
+ if len(body.output) != len(node.output):
523
+ raise UnsupportedOpError(
524
+ "Scan body output count must match Scan outputs"
525
+ )
526
+ scan_output_count = len(node.output) - num_state_inputs
527
+ _scan_axes_and_directions(
528
+ attrs,
529
+ num_scan_inputs=num_scan_inputs,
530
+ scan_output_count=scan_output_count,
531
+ is_opset8=is_opset8,
532
+ )
533
+
534
+ state_input_names = input_names[:num_state_inputs]
535
+ scan_input_names = input_names[num_state_inputs:]
536
+ sequence_len, batch_size = _scan_sequence_length(
537
+ graph,
538
+ scan_input_names,
539
+ is_opset8=is_opset8,
540
+ )
541
+
542
+ prefix = node.name or f"scan_{scan_index}"
543
+ initializer_map = _scan_body_initializers(
544
+ body,
545
+ prefix=prefix,
546
+ new_initializers=new_initializers,
547
+ )
548
+
549
+ state_names = _scan_state_inputs(
550
+ graph,
551
+ prefix=prefix,
552
+ state_input_names=state_input_names,
553
+ new_nodes=new_nodes,
554
+ is_opset8=is_opset8,
555
+ batch_size=batch_size,
556
+ )
557
+ scan_output_buffers: list[list[str]] = [
558
+ [] for _ in range(scan_output_count)
559
+ ]
560
+
561
+ for iter_index in range(sequence_len):
562
+ scan_iter_inputs = _scan_iteration_inputs(
563
+ prefix=prefix,
564
+ iter_index=iter_index,
565
+ scan_input_names=scan_input_names,
566
+ new_nodes=new_nodes,
567
+ is_opset8=is_opset8,
568
+ )
569
+ name_map: dict[str, str] = {}
570
+ for index, value in enumerate(body.input[:num_state_inputs]):
571
+ name_map[value.name] = state_names[index]
572
+ for index, value in enumerate(
573
+ body.input[num_state_inputs : num_state_inputs + num_scan_inputs]
574
+ ):
575
+ name_map[value.name] = scan_iter_inputs[index]
576
+ for original, mapped in initializer_map.items():
577
+ name_map[original] = mapped
578
+
579
+ for body_node in body.node:
580
+ body_attrs = _node_attrs(body_node)
581
+ mapped_inputs = [
582
+ name_map.get(input_name, input_name)
583
+ for input_name in body_node.input
584
+ ]
585
+ mapped_outputs: list[str] = []
586
+ for output_name in body_node.output:
587
+ if not output_name:
588
+ mapped_outputs.append("")
589
+ continue
590
+ mapped_name = (
591
+ f"{prefix}_iter{iter_index}_{output_name}"
592
+ )
593
+ name_map[output_name] = mapped_name
594
+ mapped_outputs.append(mapped_name)
595
+ new_nodes.append(
596
+ helper.make_node(
597
+ body_node.op_type,
598
+ inputs=mapped_inputs,
599
+ outputs=mapped_outputs,
600
+ name=(
601
+ f"{prefix}_iter{iter_index}_{body_node.name}"
602
+ if body_node.name
603
+ else ""
604
+ ),
605
+ domain=body_node.domain,
606
+ **body_attrs,
607
+ )
608
+ )
609
+
610
+ for index, output in enumerate(body.output[:num_state_inputs]):
611
+ mapped_output = name_map.get(output.name)
612
+ if mapped_output is None:
613
+ raise UnsupportedOpError(
614
+ "Scan body did not produce a required state output"
615
+ )
616
+ state_names[index] = mapped_output
617
+
618
+ for output_index, output in enumerate(
619
+ body.output[
620
+ num_state_inputs : num_state_inputs + scan_output_count
621
+ ]
622
+ ):
623
+ mapped_output = name_map.get(output.name)
624
+ if mapped_output is None:
625
+ raise UnsupportedOpError(
626
+ "Scan body did not produce a required scan output"
627
+ )
628
+ unsqueeze_out = (
629
+ f"{prefix}_iter{iter_index}_scanout{output_index}"
630
+ )
631
+ unsqueeze_axes = [0, 1] if is_opset8 else [0]
632
+ new_nodes.append(
633
+ helper.make_node(
634
+ "Unsqueeze",
635
+ inputs=[mapped_output],
636
+ outputs=[unsqueeze_out],
637
+ name=f"{unsqueeze_out}_node",
638
+ axes=unsqueeze_axes,
639
+ )
640
+ )
641
+ scan_output_buffers[output_index].append(unsqueeze_out)
642
+
643
+ for index, output_name in enumerate(node.output[:num_state_inputs]):
644
+ state_value = state_names[index]
645
+ if is_opset8:
646
+ expanded_state = f"{prefix}_state_output_{index}_expanded"
647
+ new_nodes.append(
648
+ helper.make_node(
649
+ "Unsqueeze",
650
+ inputs=[state_value],
651
+ outputs=[expanded_state],
652
+ name=f"{expanded_state}_node",
653
+ axes=[0],
654
+ )
655
+ )
656
+ state_value = expanded_state
657
+ if state_value == output_name:
658
+ continue
659
+ new_nodes.append(
660
+ helper.make_node(
661
+ "Identity",
662
+ inputs=[state_value],
663
+ outputs=[output_name],
664
+ name=f"{prefix}_state_output_{index}",
665
+ )
666
+ )
667
+
668
+ for output_index, output_name in enumerate(
669
+ node.output[num_state_inputs : num_state_inputs + scan_output_count]
670
+ ):
671
+ buffer = scan_output_buffers[output_index]
672
+ concat_axis = _scan_expected_axis(is_opset8)
673
+ if len(buffer) == 1:
674
+ new_nodes.append(
675
+ helper.make_node(
676
+ "Identity",
677
+ inputs=buffer,
678
+ outputs=[output_name],
679
+ name=f"{prefix}_scan_output_{output_index}",
680
+ )
681
+ )
682
+ else:
683
+ new_nodes.append(
684
+ helper.make_node(
685
+ "Concat",
686
+ inputs=buffer,
687
+ outputs=[output_name],
688
+ name=f"{prefix}_scan_output_{output_index}",
689
+ axis=concat_axis,
690
+ )
691
+ )
692
+
693
+ if expanded:
694
+ del graph.node[:]
695
+ graph.node.extend(new_nodes)
696
+ if new_initializers:
697
+ graph.initializer.extend(new_initializers)
698
+ return model, expanded
699
+
700
+
139
701
  def _constant_initializer(node: onnx.NodeProto) -> Initializer:
140
702
  if len(node.output) != 1:
141
703
  raise UnsupportedOpError("Constant must have exactly one output")
@@ -209,16 +771,18 @@ def _constant_initializer(node: onnx.NodeProto) -> Initializer:
209
771
 
210
772
 
211
773
  def import_onnx(model: onnx.ModelProto) -> Graph:
774
+ model, _ = _expand_scan_nodes(model)
212
775
  dim_param_by_name = _collect_dim_params(
213
776
  tuple(model.graph.input) + tuple(model.graph.output)
214
777
  )
215
778
  opset_imports = tuple(
216
779
  (opset.domain, opset.version) for opset in model.opset_import
217
780
  )
218
- try:
219
- model = shape_inference.infer_shapes(model, data_prop=True)
220
- except Exception as exc: # pragma: no cover - onnx inference errors
221
- raise ShapeInferenceError("ONNX shape inference failed") from exc
781
+ if _needs_shape_inference(model):
782
+ try:
783
+ model = shape_inference.infer_shapes(model, data_prop=True)
784
+ except Exception as exc: # pragma: no cover - onnx inference errors
785
+ raise ShapeInferenceError("ONNX shape inference failed") from exc
222
786
  graph = model.graph
223
787
  base_initializers = [_initializer(value) for value in graph.initializer]
224
788
  constant_initializers: list[Initializer] = []
emx_onnx_cgen/ops.py CHANGED
@@ -554,6 +554,9 @@ def unary_op_symbol(function: ScalarFunction, *, dtype: ScalarType) -> str | Non
554
554
  def apply_binary_op(
555
555
  op_spec: BinaryOpSpec, left: np.ndarray, right: np.ndarray
556
556
  ) -> np.ndarray:
557
+ if op_spec.apply is np.power:
558
+ with np.errstate(invalid="ignore", divide="ignore", over="ignore"):
559
+ return op_spec.apply(left, right)
557
560
  return op_spec.apply(left, right)
558
561
 
559
562
 
@@ -0,0 +1,16 @@
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const {{ c_type }} r = {{ rate }}[0] / ({{ one_literal }} + ({{ c_type }}){{ timestep }}[0] * {{ decay_factor_literal }});
3
+ {% for tensor in tensors %}
4
+ {% for dim in tensor.shape %}
5
+ for (idx_t {{ tensor.loop_vars[loop.index0] }} = 0; {{ tensor.loop_vars[loop.index0] }} < {{ dim }}; ++{{ tensor.loop_vars[loop.index0] }}) {
6
+ {% endfor %}
7
+ {{ c_type }} g_regularized = {{ norm_coefficient_literal }} * {{ tensor.input_expr }} + {{ tensor.grad_expr }};
8
+ {{ c_type }} h_new = {{ tensor.acc_expr }} + g_regularized * g_regularized;
9
+ {{ tensor.acc_output_expr }} = h_new;
10
+ {{ c_type }} h_adaptive = {{ sqrt_fn }}(h_new) + {{ epsilon_literal }};
11
+ {{ tensor.output_expr }} = {{ tensor.input_expr }} - r * g_regularized / h_adaptive;
12
+ {% for _ in tensor.shape %}
13
+ }
14
+ {% endfor %}
15
+ {% endfor %}
16
+ }