emx-onnx-cgen 0.2.0 (emx_onnx_cgen-0.2.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/compiler.py
@@ -0,0 +1,601 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ import hashlib
+ from pathlib import Path
+ from typing import Mapping
+
+ import numpy as np
+ import onnx
+
+ from shared.scalar_types import ScalarType
+
+ from .codegen.c_emitter import (
+     AttentionOp,
+     AveragePoolOp,
+     BatchNormOp,
+     LpNormalizationOp,
+     InstanceNormalizationOp,
+     GroupNormalizationOp,
+     LayerNormalizationOp,
+     MeanVarianceNormalizationOp,
+     RMSNormalizationOp,
+     BinaryOp,
+     MultiInputBinaryOp,
+     CastOp,
+     ClipOp,
+     CEmitter,
+     ConstTensor,
+     ConvOp,
+     ConcatOp,
+     ConstantOfShapeOp,
+     CumSumOp,
+     GemmOp,
+     GatherOp,
+     GatherElementsOp,
+     ExpandOp,
+     RangeOp,
+     LrnOp,
+     LstmOp,
+     LogSoftmaxOp,
+     NegativeLogLikelihoodLossOp,
+     NodeInfo,
+     PadOp,
+     SplitOp,
+     SoftmaxCrossEntropyLossOp,
+     LoweredModel,
+     ModelHeader,
+     MatMulOp,
+     MaxPoolOp,
+     ReduceOp,
+     ArgReduceOp,
+     ReshapeOp,
+     ResizeOp,
+     GridSampleOp,
+     SoftmaxOp,
+     ShapeOp,
+     SliceOp,
+     TransposeOp,
+     UnaryOp,
+     WhereOp,
+ )
+ from .dtypes import dtype_info
+ from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
+ from .ir.model import Graph, Value
+ from .lowering.attention import AttentionSpec, resolve_attention_spec
+ from .lowering.average_pool import (
+     lower_average_pool,
+     lower_global_average_pool,
+ )
+ from .lowering.batch_normalization import lower_batch_normalization
+ from .lowering.cast import lower_cast
+ from .lowering.concat import lower_concat
+ from .lowering.common import (
+     ensure_supported_dtype,
+     node_dtype,
+     shape_product,
+     value_dtype,
+     value_shape,
+ )
+ from .lowering.conv import ConvSpec, resolve_conv_spec
+ from .lowering.constant_of_shape import lower_constant_of_shape
+ from .lowering.dropout import lower_dropout
+ from .lowering import cumsum as _cumsum  # noqa: F401
+ from .lowering.flatten import lower_flatten
+ from .lowering.gather import lower_gather
+ from .lowering.gather_elements import lower_gather_elements
+ from .lowering.gemm import resolve_gemm_spec, validate_gemm_bias_shape
+ from .lowering.lrn import LrnSpec, resolve_lrn_spec
+ from .lowering.logsoftmax import lower_logsoftmax
+ from .lowering import group_normalization as _group_normalization  # noqa: F401
+ from .lowering import instance_normalization as _instance_normalization  # noqa: F401
+ from .lowering import layer_normalization as _layer_normalization  # noqa: F401
+ from .lowering import lp_normalization as _lp_normalization  # noqa: F401
+ from .lowering import mean_variance_normalization as _mean_variance_normalization  # noqa: F401
+ from .lowering.negative_log_likelihood_loss import (
+     lower_negative_log_likelihood_loss,
+ )
+ from .lowering.expand import lower_expand
+ from .lowering.range import lower_range
+ from .lowering.split import lower_split
+ from .lowering.softmax_cross_entropy_loss import (
+     lower_softmax_cross_entropy_loss,
+ )
+ from .lowering.matmul import lower_matmul
+ from .lowering.maxpool import MaxPoolSpec, resolve_maxpool_spec
+ from .lowering import pad as _pad  # noqa: F401
+ from .lowering.reduce import (
+     REDUCE_KIND_BY_OP,
+     REDUCE_OUTPUTS_FLOAT_ONLY,
+ )
+ from .lowering import arg_reduce as _arg_reduce  # noqa: F401
+ from .lowering.reshape import lower_reshape
+ from .lowering.resize import lower_resize
+ from .lowering.grid_sample import lower_grid_sample
+ from .lowering.slice import lower_slice
+ from .lowering.squeeze import lower_squeeze
+ from .lowering import depth_space as _depth_space  # noqa: F401
+ from .lowering import eye_like as _eye_like  # noqa: F401
+ from .lowering import identity as _identity  # noqa: F401
+ from .lowering import tile as _tile  # noqa: F401
+ from .lowering.shape import lower_shape
+ from .lowering.size import lower_size
+ from .lowering.softmax import lower_softmax
+ from .lowering.transpose import lower_transpose
+ from .lowering.unsqueeze import lower_unsqueeze
+ from .lowering.where import lower_where
+ from .lowering.elementwise import (
+     lower_celu,
+     lower_clip,
+     lower_isinf,
+     lower_isnan,
+     lower_shrink,
+     lower_swish,
+ )
+ from .lowering import variadic as _variadic  # noqa: F401
+ from .lowering import rms_normalization as _rms_normalization  # noqa: F401
+ from .lowering.registry import get_lowering_registry, resolve_dispatch
+ from .onnx_import import import_onnx
+ from .ops import (
+     BINARY_OP_TYPES,
+     COMPARE_FUNCTIONS,
+     UNARY_OP_TYPES,
+     binary_op_symbol,
+     unary_op_symbol,
+     validate_unary_attrs,
+ )
+ from shared.scalar_functions import ScalarFunction, ScalarFunctionError
+ from .runtime.evaluator import Evaluator
+
+
+ @dataclass(frozen=True)
+ class CompilerOptions:
+     template_dir: Path
+     model_name: str = "model"
+     emit_testbench: bool = False
+     command_line: str | None = None
+     model_checksum: str | None = None
+     restrict_arrays: bool = True
+     testbench_inputs: Mapping[str, np.ndarray] | None = None
+
+
+ class Compiler:
+     def __init__(self, options: CompilerOptions | None = None) -> None:
+         if options is None:
+             options = CompilerOptions(template_dir=Path("templates"))
+         self._options = options
+         self._emitter = CEmitter(
+             options.template_dir, restrict_arrays=options.restrict_arrays
+         )
+
+     def compile(self, model: onnx.ModelProto) -> str:
+         graph = import_onnx(model)
+         testbench_inputs = self._resolve_testbench_inputs(graph)
+         variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
+             graph
+         )
+         lowered = self._lower_model(model, graph)
+         return self._emitter.emit_model(
+             lowered,
+             emit_testbench=self._options.emit_testbench,
+             testbench_inputs=testbench_inputs,
+             variable_dim_inputs=variable_dim_inputs,
+             variable_dim_outputs=variable_dim_outputs,
+         )
+
+     def compile_with_data_file(self, model: onnx.ModelProto) -> tuple[str, str]:
+         graph = import_onnx(model)
+         testbench_inputs = self._resolve_testbench_inputs(graph)
+         variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
+             graph
+         )
+         lowered = self._lower_model(model, graph)
+         return self._emitter.emit_model_with_data_file(
+             lowered,
+             emit_testbench=self._options.emit_testbench,
+             testbench_inputs=testbench_inputs,
+             variable_dim_inputs=variable_dim_inputs,
+             variable_dim_outputs=variable_dim_outputs,
+         )
+
+     @staticmethod
+     def _collect_variable_dims(
+         graph: Graph,
+     ) -> tuple[dict[int, dict[int, str]], dict[int, dict[int, str]]]:
+         def collect(values: tuple[Value, ...]) -> dict[int, dict[int, str]]:
+             dim_map: dict[int, dict[int, str]] = {}
+             for index, value in enumerate(values):
+                 dims = {
+                     dim_index: dim_param
+                     for dim_index, dim_param in enumerate(
+                         value.type.dim_params
+                     )
+                     if dim_param
+                 }
+                 if dims:
+                     dim_map[index] = dims
+             return dim_map
+
+         return collect(graph.inputs), collect(graph.outputs)
+
+     def _lower_model(self, model: onnx.ModelProto, graph: Graph) -> LoweredModel:
+         constants = _lowered_constants(graph)
+         self._validate_graph(graph)
+         (
+             input_names,
+             input_shapes,
+             input_dtypes,
+             output_names,
+             output_shapes,
+             output_dtypes,
+         ) = self._collect_io_specs(graph)
+         ops, node_infos = self._lower_nodes(graph)
+         header = self._build_header(model, graph)
+         return LoweredModel(
+             name=self._options.model_name,
+             input_names=input_names,
+             input_shapes=input_shapes,
+             input_dtypes=input_dtypes,
+             output_names=output_names,
+             output_shapes=output_shapes,
+             output_dtypes=output_dtypes,
+             constants=constants,
+             ops=tuple(ops),
+             node_infos=tuple(node_infos),
+             header=header,
+         )
+
+     def _resolve_testbench_inputs(
+         self, graph: Graph
+     ) -> Mapping[str, tuple[float | int | bool, ...]] | None:
+         if not self._options.testbench_inputs:
+             return None
+         input_specs = {value.name: value for value in graph.inputs}
+         unknown_inputs = sorted(
+             name
+             for name in self._options.testbench_inputs
+             if name not in input_specs
+         )
+         if unknown_inputs:
+             raise CodegenError(
+                 "Testbench inputs include unknown inputs: "
+                 + ", ".join(unknown_inputs)
+             )
+         resolved: dict[str, tuple[float | int | bool, ...]] = {}
+         for name, values in self._options.testbench_inputs.items():
+             if not isinstance(values, np.ndarray):
+                 raise CodegenError(
+                     f"Testbench input {name} must be a numpy array"
+                 )
+             input_value = input_specs[name]
+             dtype = value_dtype(graph, name)
+             info = dtype_info(dtype)
+             expected_shape = input_value.type.shape
+             expected_count = shape_product(expected_shape)
+             array = values.astype(info.np_dtype, copy=False)
+             if array.size != expected_count:
+                 raise CodegenError(
+                     "Testbench input "
+                     f"{name} has {array.size} elements, expected {expected_count}"
+                 )
+             array = array.reshape(expected_shape)
+             resolved[name] = tuple(array.ravel().tolist())
+         return resolved
+
+     def _validate_graph(self, graph: Graph) -> None:
+         if not graph.outputs:
+             raise UnsupportedOpError("Graph must have at least one output")
+         if not graph.nodes:
+             raise UnsupportedOpError("Graph must contain at least one node")
+         for value in graph.outputs:
+             element_count = shape_product(value.type.shape)
+             if element_count <= 0:
+                 raise ShapeInferenceError("Output shape must be fully defined")
+
+     def _collect_io_specs(
+         self, graph: Graph
+     ) -> tuple[
+         tuple[str, ...],
+         tuple[tuple[int, ...], ...],
+         tuple[ScalarType, ...],
+         tuple[str, ...],
+         tuple[tuple[int, ...], ...],
+         tuple[ScalarType, ...],
+     ]:
+         input_names = tuple(value.name for value in graph.inputs)
+         input_shapes = tuple(value.type.shape for value in graph.inputs)
+         input_dtypes = tuple(
+             value_dtype(graph, value.name) for value in graph.inputs
+         )
+         output_names = tuple(value.name for value in graph.outputs)
+         output_shapes = tuple(value.type.shape for value in graph.outputs)
+         output_dtypes = tuple(
+             value_dtype(graph, value.name) for value in graph.outputs
+         )
+         return (
+             input_names,
+             input_shapes,
+             input_dtypes,
+             output_names,
+             output_shapes,
+             output_dtypes,
+         )
+
+     def _lower_nodes(
+         self, graph: Graph
+     ) -> tuple[
+         list[
+             BinaryOp
+             | MultiInputBinaryOp
+             | UnaryOp
+             | ClipOp
+             | CastOp
+             | MatMulOp
+             | GemmOp
+             | AttentionOp
+             | ConvOp
+             | AveragePoolOp
+             | BatchNormOp
+             | LpNormalizationOp
+             | InstanceNormalizationOp
+             | GroupNormalizationOp
+             | LayerNormalizationOp
+             | MeanVarianceNormalizationOp
+             | RMSNormalizationOp
+             | LrnOp
+             | LstmOp
+             | SoftmaxOp
+             | LogSoftmaxOp
+             | NegativeLogLikelihoodLossOp
+             | SoftmaxCrossEntropyLossOp
+             | MaxPoolOp
+             | ConcatOp
+             | GatherElementsOp
+             | GatherOp
+             | TransposeOp
+             | ConstantOfShapeOp
+             | ReshapeOp
+             | SliceOp
+             | ResizeOp
+             | GridSampleOp
+             | ReduceOp
+             | ArgReduceOp
+             | ShapeOp
+             | PadOp
+             | ExpandOp
+             | CumSumOp
+             | RangeOp
+             | SplitOp
+         ],
+         list[NodeInfo],
+     ]:
+         ops: list[
+             BinaryOp
+             | MultiInputBinaryOp
+             | UnaryOp
+             | ClipOp
+             | CastOp
+             | MatMulOp
+             | GemmOp
+             | AttentionOp
+             | ConvOp
+             | AveragePoolOp
+             | BatchNormOp
+             | LpNormalizationOp
+             | InstanceNormalizationOp
+             | GroupNormalizationOp
+             | LayerNormalizationOp
+             | MeanVarianceNormalizationOp
+             | RMSNormalizationOp
+             | LrnOp
+             | LstmOp
+             | SoftmaxOp
+             | LogSoftmaxOp
+             | NegativeLogLikelihoodLossOp
+             | SoftmaxCrossEntropyLossOp
+             | MaxPoolOp
+             | ConcatOp
+             | GatherElementsOp
+             | GatherOp
+             | TransposeOp
+             | ConstantOfShapeOp
+             | ReshapeOp
+             | SliceOp
+             | ResizeOp
+             | ReduceOp
+             | ArgReduceOp
+             | ShapeOp
+             | PadOp
+             | ExpandOp
+             | CumSumOp
+             | RangeOp
+             | SplitOp
+             | WhereOp
+         ] = []
+         node_infos: list[NodeInfo] = []
+         for node in graph.nodes:
+             lowering = resolve_dispatch(
+                 node.op_type,
+                 get_lowering_registry(),
+                 binary_types=BINARY_OP_TYPES,
+                 unary_types=UNARY_OP_TYPES,
+                 binary_fallback=lambda: _lower_binary_unary,
+                 unary_fallback=lambda: _lower_binary_unary,
+             )
+             ops.append(lowering(graph, node))
+             node_infos.append(
+                 NodeInfo(
+                     op_type=node.op_type,
+                     name=node.name,
+                     inputs=tuple(node.inputs),
+                     outputs=tuple(node.outputs),
+                     attrs=dict(node.attrs),
+                 )
+             )
+         return ops, node_infos
+
+     def _build_header(self, model: onnx.ModelProto, graph: Graph) -> ModelHeader:
+         metadata_props = tuple(
+             (prop.key, prop.value) for prop in model.metadata_props
+         )
+         opset_imports = tuple(
+             (opset.domain, opset.version) for opset in model.opset_import
+         )
+         checksum = self._options.model_checksum
+         if checksum is None:
+             checksum = hashlib.sha256(model.SerializeToString()).hexdigest()
+         return ModelHeader(
+             generator="Generated by emmtrix ONNX-to-C Code Generator (emx-onnx-cgen)",
+             command_line=self._options.command_line,
+             model_checksum=checksum,
+             model_name=self._options.model_name,
+             graph_name=model.graph.name or None,
+             description=model.doc_string or None,
+             graph_description=model.graph.doc_string or None,
+             producer_name=model.producer_name or None,
+             producer_version=model.producer_version or None,
+             domain=model.domain or None,
+             model_version=model.model_version or None,
+             ir_version=model.ir_version or None,
+             opset_imports=opset_imports,
+             metadata_props=metadata_props,
+             input_count=len(graph.inputs),
+             output_count=len(graph.outputs),
+             node_count=len(graph.nodes),
+             initializer_count=len(graph.initializers),
+         )
+
+     def run(
+         self, model: onnx.ModelProto, feeds: Mapping[str, np.ndarray]
+     ) -> dict[str, np.ndarray]:
+         graph = import_onnx(model)
+         evaluator = Evaluator(graph)
+         return evaluator.run(feeds)
+
+
+ def _lowered_constants(graph: Graph) -> tuple[ConstTensor, ...]:
+     constants: list[ConstTensor] = []
+     for initializer in graph.initializers:
+         dtype = ensure_supported_dtype(initializer.type.dtype)
+         constants.append(
+             ConstTensor(
+                 name=initializer.name,
+                 shape=initializer.type.shape,
+                 data=tuple(
+                     dtype.np_dtype.type(value)
+                     for value in initializer.data.ravel()
+                 ),
+                 dtype=dtype,
+             )
+         )
+     return tuple(constants)
+
+
+ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
+     if node.op_type == "BitShift":
+         if len(node.inputs) != 2 or len(node.outputs) != 1:
+             raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
+         direction_attr = node.attrs.get("direction", "LEFT")
+         if isinstance(direction_attr, bytes):
+             direction = direction_attr.decode()
+         else:
+             direction = str(direction_attr)
+         if direction not in {"LEFT", "RIGHT"}:
+             raise UnsupportedOpError(
+                 "BitShift direction must be LEFT or RIGHT"
+             )
+         op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+         if not op_dtype.is_integer:
+             raise UnsupportedOpError("BitShift expects integer inputs")
+         function = (
+             ScalarFunction.BITWISE_LEFT_SHIFT
+             if direction == "LEFT"
+             else ScalarFunction.BITWISE_RIGHT_SHIFT
+         )
+         op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+         if op_spec is None:
+             raise UnsupportedOpError("Unsupported op BitShift")
+         output_shape = value_shape(graph, node.outputs[0], node)
+         return BinaryOp(
+             input0=node.inputs[0],
+             input1=node.inputs[1],
+             output=node.outputs[0],
+             function=function,
+             operator_kind=op_spec.kind,
+             shape=output_shape,
+             dtype=op_dtype,
+             input_dtype=op_dtype,
+         )
+     if node.op_type == "Mod":
+         fmod = int(node.attrs.get("fmod", 0))
+         if fmod not in {0, 1}:
+             raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
+         function = (
+             ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
+         )
+     else:
+         try:
+             function = ScalarFunction.from_onnx_op(node.op_type)
+         except ScalarFunctionError as exc:
+             raise UnsupportedOpError(
+                 f"Unsupported op {node.op_type}"
+             ) from exc
+     validate_unary_attrs(node.op_type, node.attrs)
+     if function in COMPARE_FUNCTIONS:
+         input_dtype = node_dtype(graph, node, *node.inputs)
+         output_dtype = value_dtype(graph, node.outputs[0], node)
+         op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
+         if op_spec is None:
+             raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+         if len(node.inputs) != 2 or len(node.outputs) != 1:
+             raise UnsupportedOpError(
+                 f"{node.op_type} must have 2 inputs and 1 output"
+             )
+         if output_dtype != ScalarType.BOOL:
+             raise UnsupportedOpError(
+                 f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
+             )
+         output_shape = value_shape(graph, node.outputs[0], node)
+         return BinaryOp(
+             input0=node.inputs[0],
+             input1=node.inputs[1],
+             output=node.outputs[0],
+             function=function,
+             operator_kind=op_spec.kind,
+             shape=output_shape,
+             dtype=output_dtype,
+             input_dtype=input_dtype,
+         )
+     op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+     op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
+     unary_symbol = unary_op_symbol(function, dtype=op_dtype)
+     if op_spec is None and unary_symbol is None:
+         raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+     if op_spec is not None:
+         if len(node.inputs) != 2 or len(node.outputs) != 1:
+             raise UnsupportedOpError(
+                 f"{node.op_type} must have 2 inputs and 1 output"
+             )
+         output_shape = value_shape(graph, node.outputs[0], node)
+         return BinaryOp(
+             input0=node.inputs[0],
+             input1=node.inputs[1],
+             output=node.outputs[0],
+             function=function,
+             operator_kind=op_spec.kind,
+             shape=output_shape,
+             dtype=op_dtype,
+             input_dtype=op_dtype,
+         )
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError(f"{node.op_type} must have 1 input and 1 output")
+     output_shape = value_shape(graph, node.outputs[0], node)
+     return UnaryOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         function=function,
+         shape=output_shape,
+         dtype=op_dtype,
+         input_dtype=op_dtype,
+         params=(),
+     )
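
A minimal usage sketch of the Compiler API shown above. The import path, the ONNX file name, and the template directory layout are assumptions (they are not documented in this diff); only the names visible in compiler.py are used.

    from pathlib import Path

    import onnx

    from emx_onnx_cgen.compiler import Compiler, CompilerOptions

    # Assumed inputs: an ONNX model on disk and a template directory of the
    # kind CEmitter expects (the templates themselves are not part of this diff).
    model = onnx.load("model.onnx")
    options = CompilerOptions(template_dir=Path("templates"), model_name="model")
    compiler = Compiler(options)
    c_source = compiler.compile(model)  # generated C source returned as a string
    Path("model.c").write_text(c_source)
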
emx_onnx_cgen/dtypes.py
@@ -0,0 +1,40 @@
+ from __future__ import annotations
+
+ import onnx
+
+ from shared.scalar_types import ScalarFunctionError, ScalarType
+
+ from .errors import UnsupportedOpError
+
+ ONNX_TO_SCALAR_TYPE: dict[int, ScalarType] = {
+     onnx.TensorProto.FLOAT16: ScalarType.F16,
+     onnx.TensorProto.FLOAT: ScalarType.F32,
+     onnx.TensorProto.DOUBLE: ScalarType.F64,
+     onnx.TensorProto.BOOL: ScalarType.BOOL,
+     onnx.TensorProto.UINT8: ScalarType.U8,
+     onnx.TensorProto.UINT16: ScalarType.U16,
+     onnx.TensorProto.UINT32: ScalarType.U32,
+     onnx.TensorProto.UINT64: ScalarType.U64,
+     onnx.TensorProto.INT8: ScalarType.I8,
+     onnx.TensorProto.INT16: ScalarType.I16,
+     onnx.TensorProto.INT32: ScalarType.I32,
+     onnx.TensorProto.INT64: ScalarType.I64,
+ }
+
+
+ def scalar_type_from_onnx(elem_type: int) -> ScalarType | None:
+     return ONNX_TO_SCALAR_TYPE.get(elem_type)
+
+
+ def dtype_info(dtype: ScalarType | int | str) -> ScalarType:
+     if isinstance(dtype, ScalarType):
+         return dtype
+     if isinstance(dtype, int):
+         scalar = scalar_type_from_onnx(dtype)
+         if scalar is None:
+             raise UnsupportedOpError(f"Unsupported ONNX dtype enum: {dtype}")
+         return scalar
+     try:
+         return ScalarType.from_onnx_name(dtype)
+     except ScalarFunctionError:
+         raise UnsupportedOpError(f"Unsupported dtype: {dtype}") from None
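
A small sketch of the helpers above, assuming ScalarType behaves as the enum used throughout the package; only mappings visible in the table are exercised.

    import onnx

    from emx_onnx_cgen.dtypes import dtype_info, scalar_type_from_onnx
    from shared.scalar_types import ScalarType

    # ONNX element-type enum -> ScalarType; enums missing from the table
    # (e.g. STRING) yield None.
    assert scalar_type_from_onnx(onnx.TensorProto.FLOAT) == ScalarType.F32
    assert scalar_type_from_onnx(onnx.TensorProto.STRING) is None

    # dtype_info normalizes a ScalarType or an ONNX enum to a ScalarType;
    # unsupported enums raise UnsupportedOpError.
    assert dtype_info(onnx.TensorProto.INT64) == ScalarType.I64
    assert dtype_info(ScalarType.BOOL) == ScalarType.BOOL
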
emx_onnx_cgen/errors.py
@@ -0,0 +1,14 @@
+ class CompilerError(RuntimeError):
+     """Base error for compiler failures."""
+
+
+ class UnsupportedOpError(CompilerError):
+     """Raised when an ONNX operator is not supported."""
+
+
+ class ShapeInferenceError(CompilerError):
+     """Raised when tensor shapes cannot be resolved."""
+
+
+ class CodegenError(CompilerError):
+     """Raised when C code generation fails."""
emx_onnx_cgen/ir/__init__.py
@@ -0,0 +1,3 @@
+ from .model import Graph, Node, TensorType, Value
+
+ __all__ = ["Graph", "Node", "TensorType", "Value"]