emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of emx-onnx-cgen has been flagged as potentially problematic.

Files changed (99)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +372 -64
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +169 -343
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +193 -0
  11. emx_onnx_cgen/ir/op_context.py +65 -0
  12. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  13. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  14. emx_onnx_cgen/ir/ops/misc.py +421 -0
  15. emx_onnx_cgen/ir/ops/nn.py +580 -0
  16. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  17. emx_onnx_cgen/lowering/__init__.py +79 -1
  18. emx_onnx_cgen/lowering/adagrad.py +114 -0
  19. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  20. emx_onnx_cgen/lowering/attention.py +1 -1
  21. emx_onnx_cgen/lowering/average_pool.py +1 -1
  22. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  23. emx_onnx_cgen/lowering/cast.py +1 -1
  24. emx_onnx_cgen/lowering/common.py +406 -11
  25. emx_onnx_cgen/lowering/concat.py +1 -1
  26. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  27. emx_onnx_cgen/lowering/conv.py +1 -1
  28. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  29. emx_onnx_cgen/lowering/cumsum.py +1 -1
  30. emx_onnx_cgen/lowering/depth_space.py +1 -1
  31. emx_onnx_cgen/lowering/dropout.py +1 -1
  32. emx_onnx_cgen/lowering/einsum.py +153 -0
  33. emx_onnx_cgen/lowering/elementwise.py +152 -4
  34. emx_onnx_cgen/lowering/expand.py +1 -1
  35. emx_onnx_cgen/lowering/eye_like.py +1 -1
  36. emx_onnx_cgen/lowering/flatten.py +1 -1
  37. emx_onnx_cgen/lowering/gather.py +1 -1
  38. emx_onnx_cgen/lowering/gather_elements.py +2 -4
  39. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  40. emx_onnx_cgen/lowering/gemm.py +1 -1
  41. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  42. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  43. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  44. emx_onnx_cgen/lowering/hardmax.py +53 -0
  45. emx_onnx_cgen/lowering/identity.py +7 -6
  46. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  48. emx_onnx_cgen/lowering/logsoftmax.py +6 -2
  49. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  50. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  51. emx_onnx_cgen/lowering/lrn.py +1 -1
  52. emx_onnx_cgen/lowering/lstm.py +1 -1
  53. emx_onnx_cgen/lowering/matmul.py +7 -8
  54. emx_onnx_cgen/lowering/maxpool.py +1 -1
  55. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  56. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
  57. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  58. emx_onnx_cgen/lowering/nonzero.py +42 -0
  59. emx_onnx_cgen/lowering/one_hot.py +120 -0
  60. emx_onnx_cgen/lowering/pad.py +1 -1
  61. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  62. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  63. emx_onnx_cgen/lowering/range.py +1 -1
  64. emx_onnx_cgen/lowering/reduce.py +6 -7
  65. emx_onnx_cgen/lowering/registry.py +24 -5
  66. emx_onnx_cgen/lowering/reshape.py +224 -52
  67. emx_onnx_cgen/lowering/resize.py +1 -1
  68. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  69. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  70. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  71. emx_onnx_cgen/lowering/shape.py +6 -25
  72. emx_onnx_cgen/lowering/size.py +1 -1
  73. emx_onnx_cgen/lowering/slice.py +1 -1
  74. emx_onnx_cgen/lowering/softmax.py +6 -2
  75. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  76. emx_onnx_cgen/lowering/split.py +1 -1
  77. emx_onnx_cgen/lowering/squeeze.py +6 -6
  78. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  79. emx_onnx_cgen/lowering/tile.py +1 -1
  80. emx_onnx_cgen/lowering/topk.py +134 -0
  81. emx_onnx_cgen/lowering/transpose.py +1 -1
  82. emx_onnx_cgen/lowering/trilu.py +89 -0
  83. emx_onnx_cgen/lowering/unsqueeze.py +6 -6
  84. emx_onnx_cgen/lowering/variadic.py +1 -1
  85. emx_onnx_cgen/lowering/where.py +1 -1
  86. emx_onnx_cgen/onnx_import.py +4 -0
  87. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  88. emx_onnx_cgen/ops.py +4 -0
  89. emx_onnx_cgen/runtime/evaluator.py +785 -43
  90. emx_onnx_cgen/testbench.py +23 -0
  91. emx_onnx_cgen/verification.py +31 -0
  92. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
  93. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  94. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  95. shared/scalar_functions.py +60 -17
  96. shared/ulp.py +65 -0
  97. emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
  98. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  99. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,5 @@
+ from __future__ import annotations
+
+ from .c_emitter import CEmitter
+
+ __all__ = ["CEmitter"]
emx_onnx_cgen/compiler.py CHANGED
@@ -10,141 +10,24 @@ import onnx
 
  from shared.scalar_types import ScalarType
 
+ from .onnxruntime_utils import make_deterministic_session_options
  from .codegen.c_emitter import (
- AttentionOp,
- AveragePoolOp,
- BatchNormOp,
- LpNormalizationOp,
- InstanceNormalizationOp,
- GroupNormalizationOp,
- LayerNormalizationOp,
- MeanVarianceNormalizationOp,
- RMSNormalizationOp,
- BinaryOp,
- MultiInputBinaryOp,
- CastOp,
- ClipOp,
  CEmitter,
  ConstTensor,
- ConvOp,
- ConcatOp,
- ConstantOfShapeOp,
- CumSumOp,
- GemmOp,
- GatherOp,
- GatherElementsOp,
- ExpandOp,
- RangeOp,
- LrnOp,
- LstmOp,
- LogSoftmaxOp,
- NegativeLogLikelihoodLossOp,
- NodeInfo,
- PadOp,
- SplitOp,
- SoftmaxCrossEntropyLossOp,
  LoweredModel,
  ModelHeader,
- MatMulOp,
- MaxPoolOp,
- ReduceOp,
- ArgReduceOp,
- ReshapeOp,
- ResizeOp,
- GridSampleOp,
- SoftmaxOp,
- ShapeOp,
- SliceOp,
- TransposeOp,
- UnaryOp,
- WhereOp,
+ NodeInfo,
  )
  from .dtypes import dtype_info
  from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
- from .ir.model import Graph, Value
- from .lowering.attention import AttentionSpec, resolve_attention_spec
- from .lowering.average_pool import (
- lower_average_pool,
- lower_global_average_pool,
- )
- from .lowering.batch_normalization import lower_batch_normalization
- from .lowering.cast import lower_cast
- from .lowering.concat import lower_concat
- from .lowering.common import (
- ensure_supported_dtype,
- node_dtype,
- shape_product,
- value_dtype,
- value_shape,
- )
- from .lowering.conv import ConvSpec, resolve_conv_spec
- from .lowering.constant_of_shape import lower_constant_of_shape
- from .lowering.dropout import lower_dropout
- from .lowering import cumsum as _cumsum # noqa: F401
- from .lowering.flatten import lower_flatten
- from .lowering.gather import lower_gather
- from .lowering.gather_elements import lower_gather_elements
- from .lowering.gemm import resolve_gemm_spec, validate_gemm_bias_shape
- from .lowering.lrn import LrnSpec, resolve_lrn_spec
- from .lowering.logsoftmax import lower_logsoftmax
- from .lowering import group_normalization as _group_normalization # noqa: F401
- from .lowering import instance_normalization as _instance_normalization # noqa: F401
- from .lowering import layer_normalization as _layer_normalization # noqa: F401
- from .lowering import lp_normalization as _lp_normalization # noqa: F401
- from .lowering import mean_variance_normalization as _mean_variance_normalization # noqa: F401
- from .lowering.negative_log_likelihood_loss import (
- lower_negative_log_likelihood_loss,
- )
- from .lowering.expand import lower_expand
- from .lowering.range import lower_range
- from .lowering.split import lower_split
- from .lowering.softmax_cross_entropy_loss import (
- lower_softmax_cross_entropy_loss,
- )
- from .lowering.matmul import lower_matmul
- from .lowering.maxpool import MaxPoolSpec, resolve_maxpool_spec
- from .lowering import pad as _pad # noqa: F401
- from .lowering.reduce import (
- REDUCE_KIND_BY_OP,
- REDUCE_OUTPUTS_FLOAT_ONLY,
- )
- from .lowering import arg_reduce as _arg_reduce # noqa: F401
- from .lowering.reshape import lower_reshape
- from .lowering.resize import lower_resize
- from .lowering.grid_sample import lower_grid_sample
- from .lowering.slice import lower_slice
- from .lowering.squeeze import lower_squeeze
- from .lowering import depth_space as _depth_space # noqa: F401
- from .lowering import eye_like as _eye_like # noqa: F401
- from .lowering import identity as _identity # noqa: F401
- from .lowering import tile as _tile # noqa: F401
- from .lowering.shape import lower_shape
- from .lowering.size import lower_size
- from .lowering.softmax import lower_softmax
- from .lowering.transpose import lower_transpose
- from .lowering.unsqueeze import lower_unsqueeze
- from .lowering.where import lower_where
- from .lowering.elementwise import (
- lower_celu,
- lower_clip,
- lower_isinf,
- lower_isnan,
- lower_shrink,
- lower_swish,
- )
- from .lowering import variadic as _variadic # noqa: F401
- from .lowering import rms_normalization as _rms_normalization # noqa: F401
- from .lowering.registry import get_lowering_registry, resolve_dispatch
+ from .ir.context import GraphContext
+ from .ir.model import Graph, TensorType, Value
+ from .ir.op_base import OpBase
+ from .ir.op_context import OpContext
+ from .lowering import load_lowering_registry
+ from .lowering.common import ensure_supported_dtype, shape_product, value_dtype
+ from .lowering.registry import get_lowering_registry
  from .onnx_import import import_onnx
- from .ops import (
- BINARY_OP_TYPES,
- COMPARE_FUNCTIONS,
- UNARY_OP_TYPES,
- binary_op_symbol,
- unary_op_symbol,
- validate_unary_attrs,
- )
- from shared.scalar_functions import ScalarFunction, ScalarFunctionError
  from .runtime.evaluator import Evaluator
 
 
@@ -157,6 +40,16 @@ class CompilerOptions:
  model_checksum: str | None = None
  restrict_arrays: bool = True
  testbench_inputs: Mapping[str, np.ndarray] | None = None
+ truncate_weights_after: int | None = None
+ large_temp_threshold_bytes: int = 1024
+ large_weight_threshold: int = 1024 * 1024
+
+
+ def _onnx_elem_type(dtype: np.dtype) -> int:
+ for elem_type, info in onnx._mapping.TENSOR_TYPE_MAP.items():
+ if info.np_dtype == dtype:
+ return elem_type
+ raise UnsupportedOpError(f"Unsupported dtype {dtype} for ONNX output")
 
 
  class Compiler:
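
The new _onnx_elem_type helper resolves a NumPy dtype to the ONNX TensorProto element-type enum by scanning onnx's internal TENSOR_TYPE_MAP. A minimal, self-contained sketch of the same mapping using an explicit table instead of onnx internals (illustrative only, not the shipped helper):

    import numpy as np
    import onnx

    # Illustrative equivalent of _onnx_elem_type for the common dtypes; the shipped
    # helper derives the mapping from onnx._mapping.TENSOR_TYPE_MAP instead.
    _ELEM_TYPE_BY_DTYPE = {
        np.dtype("float32"): onnx.TensorProto.FLOAT,
        np.dtype("float64"): onnx.TensorProto.DOUBLE,
        np.dtype("int32"): onnx.TensorProto.INT32,
        np.dtype("int64"): onnx.TensorProto.INT64,
        np.dtype("bool"): onnx.TensorProto.BOOL,
    }

    def onnx_elem_type(dtype: np.dtype) -> int:
        try:
            return _ELEM_TYPE_BY_DTYPE[np.dtype(dtype)]
        except KeyError as exc:
            raise ValueError(f"Unsupported dtype {dtype} for ONNX output") from exc

    print(onnx_elem_type(np.dtype("float32")))  # 1 == onnx.TensorProto.FLOAT
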
@@ -165,11 +58,17 @@ class Compiler:
  options = CompilerOptions(template_dir=Path("templates"))
  self._options = options
  self._emitter = CEmitter(
- options.template_dir, restrict_arrays=options.restrict_arrays
+ options.template_dir,
+ restrict_arrays=options.restrict_arrays,
+ truncate_weights_after=options.truncate_weights_after,
+ large_temp_threshold_bytes=options.large_temp_threshold_bytes,
+ large_weight_threshold=options.large_weight_threshold,
  )
+ load_lowering_registry()
 
  def compile(self, model: onnx.ModelProto) -> str:
  graph = import_onnx(model)
+ graph = self._concretize_graph_shapes(model, graph)
  testbench_inputs = self._resolve_testbench_inputs(graph)
  variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
  graph
@@ -185,6 +84,7 @@ class Compiler:
 
  def compile_with_data_file(self, model: onnx.ModelProto) -> tuple[str, str]:
  graph = import_onnx(model)
+ graph = self._concretize_graph_shapes(model, graph)
  testbench_inputs = self._resolve_testbench_inputs(graph)
  variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
  graph
@@ -198,6 +98,46 @@ class Compiler:
  variable_dim_outputs=variable_dim_outputs,
  )
 
+ def compile_with_weight_data(
+ self, model: onnx.ModelProto
+ ) -> tuple[str, bytes | None]:
+ graph = import_onnx(model)
+ graph = self._concretize_graph_shapes(model, graph)
+ testbench_inputs = self._resolve_testbench_inputs(graph)
+ variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
+ graph
+ )
+ lowered = self._lower_model(model, graph)
+ generated = self._emitter.emit_model(
+ lowered,
+ emit_testbench=self._options.emit_testbench,
+ testbench_inputs=testbench_inputs,
+ variable_dim_inputs=variable_dim_inputs,
+ variable_dim_outputs=variable_dim_outputs,
+ )
+ weight_data = self._emitter.collect_weight_data(lowered.constants)
+ return generated, weight_data
+
+ def compile_with_data_file_and_weight_data(
+ self, model: onnx.ModelProto
+ ) -> tuple[str, str, bytes | None]:
+ graph = import_onnx(model)
+ graph = self._concretize_graph_shapes(model, graph)
+ testbench_inputs = self._resolve_testbench_inputs(graph)
+ variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
+ graph
+ )
+ lowered = self._lower_model(model, graph)
+ generated, data_source = self._emitter.emit_model_with_data_file(
+ lowered,
+ emit_testbench=self._options.emit_testbench,
+ testbench_inputs=testbench_inputs,
+ variable_dim_inputs=variable_dim_inputs,
+ variable_dim_outputs=variable_dim_outputs,
+ )
+ weight_data = self._emitter.collect_weight_data(lowered.constants)
+ return generated, data_source, weight_data
+
  @staticmethod
  def _collect_variable_dims(
  graph: Graph,
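
The two new entry points, compile_with_weight_data and compile_with_data_file_and_weight_data, return the generated C source alongside an optional raw weight blob collected from the lowered constants. A rough usage sketch (assuming Compiler accepts a CompilerOptions instance, that the remaining CompilerOptions fields have usable defaults, and that weight_data is None when no separate blob is produced; file paths are placeholders):

    from pathlib import Path

    import onnx

    from emx_onnx_cgen.compiler import Compiler, CompilerOptions

    # Sketch only: CompilerOptions fields beyond those visible in this diff are assumed
    # to default sensibly.
    compiler = Compiler(CompilerOptions(template_dir=Path("templates")))
    model = onnx.load("model.onnx")  # placeholder path

    c_source, weight_data = compiler.compile_with_weight_data(model)
    Path("model.c").write_text(c_source)
    if weight_data is not None:
        Path("model_weights.bin").write_bytes(weight_data)
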
@@ -219,7 +159,8 @@ class Compiler:
  return collect(graph.inputs), collect(graph.outputs)
 
  def _lower_model(self, model: onnx.ModelProto, graph: Graph) -> LoweredModel:
- constants = _lowered_constants(graph)
+ ctx = GraphContext(graph)
+ constants = _lowered_constants(ctx)
  self._validate_graph(graph)
  (
  input_names,
@@ -229,7 +170,14 @@ class Compiler:
  output_shapes,
  output_dtypes,
  ) = self._collect_io_specs(graph)
- ops, node_infos = self._lower_nodes(graph)
+ ops, node_infos = self._lower_nodes(ctx)
+ op_ctx = OpContext(ctx)
+ for op in ops:
+ op.validate(op_ctx)
+ for op in ops:
+ op.infer_types(op_ctx)
+ for op in ops:
+ op.infer_shapes(op_ctx)
  header = self._build_header(model, graph)
  return LoweredModel(
  name=self._options.model_name,
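
After lowering, _lower_model now drives three uniform passes over the OpBase list (validate, then infer_types, then infer_shapes), each against a shared OpContext. A hypothetical op illustrating the shape of those hooks; the real interfaces live in emx_onnx_cgen/ir/op_base.py and op_context.py, and the assumption here is that the context exposes GraphContext-style dtype/shape accessors:

    from dataclasses import dataclass

    # Hypothetical stand-in for an OpBase subclass; only the method names mirror the
    # calls made by _lower_model above, everything else is invented for illustration.
    @dataclass
    class CopyThroughOp:
        input0: str
        output: str

        def validate(self, ctx) -> None:
            # e.g. reject inputs whose dtype the backend cannot emit
            ctx.dtype(self.input0)

        def infer_types(self, ctx) -> None:
            ctx.set_dtype(self.output, ctx.dtype(self.input0))

        def infer_shapes(self, ctx) -> None:
            ctx.set_shape(self.output, ctx.shape(self.input0))
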
@@ -243,6 +191,7 @@ class Compiler:
  ops=tuple(ops),
  node_infos=tuple(node_infos),
  header=header,
+ op_context=op_ctx,
  )
 
  def _resolve_testbench_inputs(
@@ -282,15 +231,93 @@ class Compiler:
  resolved[name] = tuple(array.ravel().tolist())
  return resolved
 
+ def _concretize_graph_shapes(
+ self, model: onnx.ModelProto, graph: Graph
+ ) -> Graph:
+ if not self._options.testbench_inputs:
+ return graph
+ if not any(value.type.dim_params for value in graph.values):
+ if not any(value.type.dim_params for value in graph.inputs):
+ if not any(value.type.dim_params for value in graph.outputs):
+ return graph
+ try:
+ import onnxruntime as ort
+ except Exception:
+ return graph
+ try:
+ model_with_outputs = onnx.ModelProto()
+ model_with_outputs.CopyFrom(model)
+ existing_outputs = {
+ output.name for output in model_with_outputs.graph.output
+ }
+ value_info_by_name = {
+ value_info.name: value_info
+ for value_info in model_with_outputs.graph.value_info
+ }
+ for value in graph.values:
+ if value.name in existing_outputs:
+ continue
+ value_info = value_info_by_name.get(value.name)
+ if value_info is None:
+ dims: list[int | str | None] = []
+ for index, dim in enumerate(value.type.shape):
+ dim_param = None
+ if index < len(value.type.dim_params):
+ dim_param = value.type.dim_params[index]
+ dims.append(dim_param if dim_param else None)
+ elem_type = _onnx_elem_type(value.type.dtype.np_dtype)
+ value_info = onnx.helper.make_tensor_value_info(
+ value.name, elem_type, dims
+ )
+ model_with_outputs.graph.output.append(value_info)
+ existing_outputs.add(value.name)
+ output_names = [output.name for output in model_with_outputs.graph.output]
+ sess_options = make_deterministic_session_options(ort)
+ sess = ort.InferenceSession(
+ model_with_outputs.SerializeToString(),
+ sess_options=sess_options,
+ providers=["CPUExecutionProvider"],
+ )
+ output_arrays = sess.run(None, self._options.testbench_inputs)
+ except Exception:
+ return graph
+
+ shapes_by_name: dict[str, tuple[int, ...]] = {
+ name: tuple(int(dim) for dim in array.shape)
+ for name, array in zip(output_names, output_arrays)
+ }
+ for name, array in self._options.testbench_inputs.items():
+ shapes_by_name[name] = tuple(int(dim) for dim in array.shape)
+
+ def concretize_value(value: Value) -> Value:
+ shape = shapes_by_name.get(value.name)
+ if shape is None:
+ return value
+ return Value(
+ name=value.name,
+ type=TensorType(
+ dtype=value.type.dtype,
+ shape=shape,
+ dim_params=(None,) * len(shape),
+ ),
+ )
+
+ return Graph(
+ inputs=tuple(concretize_value(value) for value in graph.inputs),
+ outputs=tuple(concretize_value(value) for value in graph.outputs),
+ nodes=graph.nodes,
+ initializers=graph.initializers,
+ values=tuple(concretize_value(value) for value in graph.values),
+ opset_imports=graph.opset_imports,
+ )
+
  def _validate_graph(self, graph: Graph) -> None:
  if not graph.outputs:
  raise UnsupportedOpError("Graph must have at least one output")
  if not graph.nodes:
  raise UnsupportedOpError("Graph must contain at least one node")
  for value in graph.outputs:
- element_count = shape_product(value.type.shape)
- if element_count <= 0:
- raise ShapeInferenceError("Output shape must be fully defined")
+ shape_product(value.type.shape)
 
  def _collect_io_specs(
  self, graph: Graph
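
The new _concretize_graph_shapes pass handles models exported with symbolic dimensions: when testbench inputs are provided, it runs the model once through onnxruntime with every intermediate value promoted to a graph output, records the observed shapes, and rebuilds the graph's Values with static shapes and cleared dim_params. The substitution amounts to the following (pure illustration, not the shipped code):

    # A value declared with a symbolic batch dimension, e.g. ("N", 3, 224, 224),
    # becomes fully static once a concrete run has been observed.
    declared = ("N", 3, 224, 224)   # "N" is a symbolic dim_param
    observed = (1, 3, 224, 224)     # shape seen when running the testbench inputs

    concrete = tuple(
        obs if isinstance(decl, str) else decl
        for decl, obs in zip(declared, observed)
    )
    assert concrete == (1, 3, 224, 224)
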
@@ -322,107 +349,16 @@ class Compiler:
  )
 
  def _lower_nodes(
- self, graph: Graph
- ) -> tuple[
- list[
- BinaryOp
- | MultiInputBinaryOp
- | UnaryOp
- | ClipOp
- | CastOp
- | MatMulOp
- | GemmOp
- | AttentionOp
- | ConvOp
- | AveragePoolOp
- | BatchNormOp
- | LpNormalizationOp
- | InstanceNormalizationOp
- | GroupNormalizationOp
- | LayerNormalizationOp
- | MeanVarianceNormalizationOp
- | RMSNormalizationOp
- | LrnOp
- | LstmOp
- | SoftmaxOp
- | LogSoftmaxOp
- | NegativeLogLikelihoodLossOp
- | SoftmaxCrossEntropyLossOp
- | MaxPoolOp
- | ConcatOp
- | GatherElementsOp
- | GatherOp
- | TransposeOp
- | ConstantOfShapeOp
- | ReshapeOp
- | SliceOp
- | ResizeOp
- | GridSampleOp
- | ReduceOp
- | ArgReduceOp
- | ShapeOp
- | PadOp
- | ExpandOp
- | CumSumOp
- | RangeOp
- | SplitOp
- ],
- list[NodeInfo],
- ]:
- ops: list[
- BinaryOp
- | MultiInputBinaryOp
- | UnaryOp
- | ClipOp
- | CastOp
- | MatMulOp
- | GemmOp
- | AttentionOp
- | ConvOp
- | AveragePoolOp
- | BatchNormOp
- | LpNormalizationOp
- | InstanceNormalizationOp
- | GroupNormalizationOp
- | LayerNormalizationOp
- | MeanVarianceNormalizationOp
- | RMSNormalizationOp
- | LrnOp
- | LstmOp
- | SoftmaxOp
- | LogSoftmaxOp
- | NegativeLogLikelihoodLossOp
- | SoftmaxCrossEntropyLossOp
- | MaxPoolOp
- | ConcatOp
- | GatherElementsOp
- | GatherOp
- | TransposeOp
- | ConstantOfShapeOp
- | ReshapeOp
- | SliceOp
- | ResizeOp
- | ReduceOp
- | ArgReduceOp
- | ShapeOp
- | PadOp
- | ExpandOp
- | CumSumOp
- | RangeOp
- | SplitOp
- | WhereOp
- ] = []
+ self, ctx: GraphContext
+ ) -> tuple[list[OpBase], list[NodeInfo]]:
+ ops: list[OpBase] = []
  node_infos: list[NodeInfo] = []
- for node in graph.nodes:
- lowering = resolve_dispatch(
- node.op_type,
- get_lowering_registry(),
- binary_types=BINARY_OP_TYPES,
- unary_types=UNARY_OP_TYPES,
- binary_fallback=lambda: _lower_binary_unary,
- unary_fallback=lambda: _lower_binary_unary,
- )
- ops.append(lowering(graph, node))
+ registry = get_lowering_registry()
+ for node in ctx.nodes:
+ lowering = registry.get(node.op_type)
+ if lowering is None:
+ raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+ ops.append(lowering(ctx, node))
  node_infos.append(
  NodeInfo(
  op_type=node.op_type,
@@ -473,7 +409,7 @@ class Compiler:
  return evaluator.run(feeds)
 
 
- def _lowered_constants(graph: Graph) -> tuple[ConstTensor, ...]:
+ def _lowered_constants(graph: Graph | GraphContext) -> tuple[ConstTensor, ...]:
  constants: list[ConstTensor] = []
  for initializer in graph.initializers:
  dtype = ensure_supported_dtype(initializer.type.dtype)
@@ -489,113 +425,3 @@ def _lowered_constants(graph: Graph) -> tuple[ConstTensor, ...]:
  )
  )
  return tuple(constants)
-
-
- def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
- if node.op_type == "BitShift":
- if len(node.inputs) != 2 or len(node.outputs) != 1:
- raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
- direction_attr = node.attrs.get("direction", "LEFT")
- if isinstance(direction_attr, bytes):
- direction = direction_attr.decode()
- else:
- direction = str(direction_attr)
- if direction not in {"LEFT", "RIGHT"}:
- raise UnsupportedOpError(
- "BitShift direction must be LEFT or RIGHT"
- )
- op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
- if not op_dtype.is_integer:
- raise UnsupportedOpError("BitShift expects integer inputs")
- function = (
- ScalarFunction.BITWISE_LEFT_SHIFT
- if direction == "LEFT"
- else ScalarFunction.BITWISE_RIGHT_SHIFT
- )
- op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
- if op_spec is None:
- raise UnsupportedOpError("Unsupported op BitShift")
- output_shape = value_shape(graph, node.outputs[0], node)
- return BinaryOp(
- input0=node.inputs[0],
- input1=node.inputs[1],
- output=node.outputs[0],
- function=function,
- operator_kind=op_spec.kind,
- shape=output_shape,
- dtype=op_dtype,
- input_dtype=op_dtype,
- )
- if node.op_type == "Mod":
- fmod = int(node.attrs.get("fmod", 0))
- if fmod not in {0, 1}:
- raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
- function = (
- ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
- )
- else:
- try:
- function = ScalarFunction.from_onnx_op(node.op_type)
- except ScalarFunctionError as exc:
- raise UnsupportedOpError(
- f"Unsupported op {node.op_type}"
- ) from exc
- validate_unary_attrs(node.op_type, node.attrs)
- if function in COMPARE_FUNCTIONS:
- input_dtype = node_dtype(graph, node, *node.inputs)
- output_dtype = value_dtype(graph, node.outputs[0], node)
- op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
- if op_spec is None:
- raise UnsupportedOpError(f"Unsupported op {node.op_type}")
- if len(node.inputs) != 2 or len(node.outputs) != 1:
- raise UnsupportedOpError(
- f"{node.op_type} must have 2 inputs and 1 output"
- )
- if output_dtype != ScalarType.BOOL:
- raise UnsupportedOpError(
- f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
- )
- output_shape = value_shape(graph, node.outputs[0], node)
- return BinaryOp(
- input0=node.inputs[0],
- input1=node.inputs[1],
- output=node.outputs[0],
- function=function,
- operator_kind=op_spec.kind,
- shape=output_shape,
- dtype=output_dtype,
- input_dtype=input_dtype,
- )
- op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
- op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
- unary_symbol = unary_op_symbol(function, dtype=op_dtype)
- if op_spec is None and unary_symbol is None:
- raise UnsupportedOpError(f"Unsupported op {node.op_type}")
- if op_spec is not None:
- if len(node.inputs) != 2 or len(node.outputs) != 1:
- raise UnsupportedOpError(
- f"{node.op_type} must have 2 inputs and 1 output"
- )
- output_shape = value_shape(graph, node.outputs[0], node)
- return BinaryOp(
- input0=node.inputs[0],
- input1=node.inputs[1],
- output=node.outputs[0],
- function=function,
- operator_kind=op_spec.kind,
- shape=output_shape,
- dtype=op_dtype,
- input_dtype=op_dtype,
- )
- if len(node.inputs) != 1 or len(node.outputs) != 1:
- raise UnsupportedOpError(f"{node.op_type} must have 1 input and 1 output")
- output_shape = value_shape(graph, node.outputs[0], node)
- return UnaryOp(
- input0=node.inputs[0],
- output=node.outputs[0],
- function=function,
- shape=output_shape,
- dtype=op_dtype,
- input_dtype=op_dtype,
- params=(),
- )
emx_onnx_cgen/ir/context.py ADDED
@@ -0,0 +1,87 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from .model import Graph, Initializer, Node, Value
+ from shared.scalar_types import ScalarType
+
+
+ @dataclass
+ class GraphContext:
+ graph: Graph
+ _dtype_cache: dict[str, ScalarType] = field(default_factory=dict)
+ _shape_cache: dict[str, tuple[int, ...]] = field(default_factory=dict)
+ _initializer_cache: dict[str, Initializer] = field(default_factory=dict)
+ _producer_cache: dict[str, Node] = field(default_factory=dict)
+
+ def find_value(self, name: str) -> Value:
+ return self.graph.find_value(name)
+
+ def dtype(self, name: str, node: Node | None = None) -> ScalarType:
+ if name in self._dtype_cache:
+ return self._dtype_cache[name]
+ try:
+ value = self.graph.find_value(name)
+ except KeyError as exc:
+ op_type = node.op_type if node is not None else "unknown"
+ raise ShapeInferenceError(
+ f"Missing dtype for value '{name}' in op {op_type}. "
+ "Hint: run ONNX shape inference or export with static shapes."
+ ) from exc
+ dtype = value.type.dtype
+ if not isinstance(dtype, ScalarType):
+ raise UnsupportedOpError(f"Unsupported dtype {dtype}")
+ self._dtype_cache[name] = dtype
+ return dtype
+
+ def set_dtype(self, name: str, dtype: ScalarType) -> None:
+ self._dtype_cache[name] = dtype
+
+ def shape(self, name: str, node: Node | None = None) -> tuple[int, ...]:
+ if name in self._shape_cache:
+ return self._shape_cache[name]
+ try:
+ value = self.graph.find_value(name)
+ except KeyError as exc:
+ op_type = node.op_type if node is not None else "unknown"
+ raise ShapeInferenceError(
+ f"Missing shape for value '{name}' in op {op_type}. "
+ "Hint: run ONNX shape inference or export with static shapes."
+ ) from exc
+ self._shape_cache[name] = value.type.shape
+ return value.type.shape
+
+ def set_shape(self, name: str, shape: tuple[int, ...]) -> None:
+ self._shape_cache[name] = shape
+
+ def initializer(self, name: str) -> Initializer | None:
+ if name in self._initializer_cache:
+ return self._initializer_cache[name]
+ for initializer in self.graph.initializers:
+ if initializer.name == name:
+ self._initializer_cache[name] = initializer
+ return initializer
+ return None
+
+ def producer(self, output_name: str) -> Node | None:
+ if output_name in self._producer_cache:
+ return self._producer_cache[output_name]
+ for node in self.graph.nodes:
+ if output_name in node.outputs:
+ self._producer_cache[output_name] = node
+ return node
+ return None
+
+ def opset_version(self, domain: str = "") -> int | None:
+ if domain in {"", "ai.onnx"}:
+ domains = {"", "ai.onnx"}
+ else:
+ domains = {domain}
+ for opset_domain, version in self.graph.opset_imports:
+ if opset_domain in domains:
+ return int(version)
+ return None
+
+ def __getattr__(self, name: str):
+ return getattr(self.graph, name)
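
GraphContext wraps an imported Graph with memoized dtype, shape, initializer, and producer lookups, and __getattr__ forwards anything else to the underlying Graph, which is why _lowered_constants above can accept either type. A brief usage sketch (the model path is a placeholder):

    import onnx

    from emx_onnx_cgen.ir.context import GraphContext
    from emx_onnx_cgen.onnx_import import import_onnx

    graph = import_onnx(onnx.load("model.onnx"))  # placeholder path
    ctx = GraphContext(graph)

    # Cached lookups: repeated queries for the same name hit the in-memory caches.
    out = graph.outputs[0].name
    print(ctx.dtype(out), ctx.shape(out))

    # Anything GraphContext does not define falls through to the wrapped Graph.
    print(len(ctx.nodes), ctx.opset_version())
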