emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen has been flagged by the registry.

Files changed (42)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +340 -59
  4. emx_onnx_cgen/codegen/c_emitter.py +2369 -111
  5. emx_onnx_cgen/compiler.py +188 -5
  6. emx_onnx_cgen/ir/model.py +1 -0
  7. emx_onnx_cgen/lowering/common.py +379 -2
  8. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  9. emx_onnx_cgen/lowering/einsum.py +153 -0
  10. emx_onnx_cgen/lowering/gather_elements.py +1 -3
  11. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  12. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  13. emx_onnx_cgen/lowering/hardmax.py +53 -0
  14. emx_onnx_cgen/lowering/identity.py +6 -5
  15. emx_onnx_cgen/lowering/logsoftmax.py +5 -1
  16. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  17. emx_onnx_cgen/lowering/matmul.py +6 -7
  18. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
  19. emx_onnx_cgen/lowering/nonzero.py +42 -0
  20. emx_onnx_cgen/lowering/one_hot.py +120 -0
  21. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  22. emx_onnx_cgen/lowering/reduce.py +5 -6
  23. emx_onnx_cgen/lowering/reshape.py +223 -51
  24. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  25. emx_onnx_cgen/lowering/softmax.py +5 -1
  26. emx_onnx_cgen/lowering/squeeze.py +5 -5
  27. emx_onnx_cgen/lowering/topk.py +116 -0
  28. emx_onnx_cgen/lowering/trilu.py +89 -0
  29. emx_onnx_cgen/lowering/unsqueeze.py +5 -5
  30. emx_onnx_cgen/onnx_import.py +4 -0
  31. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  32. emx_onnx_cgen/ops.py +4 -0
  33. emx_onnx_cgen/runtime/evaluator.py +460 -42
  34. emx_onnx_cgen/testbench.py +23 -0
  35. emx_onnx_cgen/verification.py +61 -0
  36. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
  37. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
  38. shared/scalar_functions.py +49 -17
  39. shared/ulp.py +48 -0
  40. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
  41. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/compiler.py CHANGED
@@ -10,6 +10,7 @@ import onnx
 
 from shared.scalar_types import ScalarType
 
+from .onnxruntime_utils import make_deterministic_session_options
 from .codegen.c_emitter import (
     AttentionOp,
     AveragePoolOp,
@@ -27,18 +28,26 @@ from .codegen.c_emitter import (
     CEmitter,
     ConstTensor,
     ConvOp,
+    ConvTransposeOp,
     ConcatOp,
     ConstantOfShapeOp,
     CumSumOp,
     GemmOp,
     GatherOp,
     GatherElementsOp,
+    GatherNDOp,
+    ScatterNDOp,
     ExpandOp,
     RangeOp,
+    OneHotOp,
+    LpPoolOp,
+    QuantizeLinearOp,
     LrnOp,
     LstmOp,
     LogSoftmaxOp,
+    HardmaxOp,
     NegativeLogLikelihoodLossOp,
+    NonZeroOp,
     NodeInfo,
     PadOp,
     SplitOp,
@@ -52,6 +61,7 @@ from .codegen.c_emitter import (
     ReshapeOp,
     ResizeOp,
     GridSampleOp,
+    HardmaxOp,
     SoftmaxOp,
     ShapeOp,
     SliceOp,
@@ -61,12 +71,13 @@ from .codegen.c_emitter import (
 )
 from .dtypes import dtype_info
 from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
-from .ir.model import Graph, Value
+from .ir.model import Graph, TensorType, Value
 from .lowering.attention import AttentionSpec, resolve_attention_spec
 from .lowering.average_pool import (
     lower_average_pool,
     lower_global_average_pool,
 )
+from .lowering import global_max_pool as _global_max_pool  # noqa: F401
 from .lowering.batch_normalization import lower_batch_normalization
 from .lowering.cast import lower_cast
 from .lowering.concat import lower_concat
@@ -78,25 +89,33 @@ from .lowering.common import (
     value_shape,
 )
 from .lowering.conv import ConvSpec, resolve_conv_spec
+from .lowering import conv_transpose as _conv_transpose  # noqa: F401
 from .lowering.constant_of_shape import lower_constant_of_shape
 from .lowering.dropout import lower_dropout
 from .lowering import cumsum as _cumsum  # noqa: F401
+from .lowering import einsum as _einsum  # noqa: F401
 from .lowering.flatten import lower_flatten
 from .lowering.gather import lower_gather
 from .lowering.gather_elements import lower_gather_elements
+from .lowering.gather_nd import lower_gather_nd
+from .lowering import scatter_nd as _scatter_nd  # noqa: F401
 from .lowering.gemm import resolve_gemm_spec, validate_gemm_bias_shape
 from .lowering.lrn import LrnSpec, resolve_lrn_spec
 from .lowering.logsoftmax import lower_logsoftmax
+from .lowering import hardmax as _hardmax  # noqa: F401
 from .lowering import group_normalization as _group_normalization  # noqa: F401
 from .lowering import instance_normalization as _instance_normalization  # noqa: F401
 from .lowering import layer_normalization as _layer_normalization  # noqa: F401
 from .lowering import lp_normalization as _lp_normalization  # noqa: F401
+from .lowering import lp_pool as _lp_pool  # noqa: F401
 from .lowering import mean_variance_normalization as _mean_variance_normalization  # noqa: F401
 from .lowering.negative_log_likelihood_loss import (
     lower_negative_log_likelihood_loss,
 )
+from .lowering import nonzero as _nonzero  # noqa: F401
 from .lowering.expand import lower_expand
 from .lowering.range import lower_range
+from .lowering import one_hot as _one_hot  # noqa: F401
 from .lowering.split import lower_split
 from .lowering.softmax_cross_entropy_loss import (
     lower_softmax_cross_entropy_loss,
@@ -109,15 +128,18 @@ from .lowering.reduce import (
     REDUCE_OUTPUTS_FLOAT_ONLY,
 )
 from .lowering import arg_reduce as _arg_reduce  # noqa: F401
+from .lowering import topk as _topk  # noqa: F401
 from .lowering.reshape import lower_reshape
 from .lowering.resize import lower_resize
 from .lowering.grid_sample import lower_grid_sample
+from .lowering import quantize_linear as _quantize_linear  # noqa: F401
 from .lowering.slice import lower_slice
 from .lowering.squeeze import lower_squeeze
 from .lowering import depth_space as _depth_space  # noqa: F401
 from .lowering import eye_like as _eye_like  # noqa: F401
 from .lowering import identity as _identity  # noqa: F401
 from .lowering import tile as _tile  # noqa: F401
+from .lowering import trilu as _trilu  # noqa: F401
 from .lowering.shape import lower_shape
 from .lowering.size import lower_size
 from .lowering.softmax import lower_softmax
@@ -157,6 +179,16 @@ class CompilerOptions:
     model_checksum: str | None = None
     restrict_arrays: bool = True
     testbench_inputs: Mapping[str, np.ndarray] | None = None
+    truncate_weights_after: int | None = None
+    large_temp_threshold_bytes: int = 1024
+    large_weight_threshold: int = 1024
+
+
+def _onnx_elem_type(dtype: np.dtype) -> int:
+    for elem_type, info in onnx._mapping.TENSOR_TYPE_MAP.items():
+        if info.np_dtype == dtype:
+            return elem_type
+    raise UnsupportedOpError(f"Unsupported dtype {dtype} for ONNX output")
 
 
 class Compiler:
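The new module-level `_onnx_elem_type` helper reverse-maps a NumPy dtype to an ONNX tensor element type by scanning the private `onnx._mapping.TENSOR_TYPE_MAP` table. For orientation, recent `onnx` releases (1.13+) expose the same lookup through a public helper; the snippet below illustrates that equivalence and is not code from this package:

import numpy as np
import onnx

# Public equivalent of the private TENSOR_TYPE_MAP scan in _onnx_elem_type.
elem_type = onnx.helper.np_dtype_to_tensor_dtype(np.dtype(np.float32))
assert elem_type == onnx.TensorProto.FLOAT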
@@ -165,11 +197,16 @@ class Compiler:
             options = CompilerOptions(template_dir=Path("templates"))
         self._options = options
         self._emitter = CEmitter(
-            options.template_dir, restrict_arrays=options.restrict_arrays
+            options.template_dir,
+            restrict_arrays=options.restrict_arrays,
+            truncate_weights_after=options.truncate_weights_after,
+            large_temp_threshold_bytes=options.large_temp_threshold_bytes,
+            large_weight_threshold=options.large_weight_threshold,
         )
 
     def compile(self, model: onnx.ModelProto) -> str:
         graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
         testbench_inputs = self._resolve_testbench_inputs(graph)
         variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
             graph
@@ -185,6 +222,7 @@ class Compiler:
 
     def compile_with_data_file(self, model: onnx.ModelProto) -> tuple[str, str]:
         graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
         testbench_inputs = self._resolve_testbench_inputs(graph)
         variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
             graph
@@ -198,6 +236,46 @@ class Compiler:
             variable_dim_outputs=variable_dim_outputs,
         )
 
+    def compile_with_weight_data(
+        self, model: onnx.ModelProto
+    ) -> tuple[str, bytes | None]:
+        graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
+        testbench_inputs = self._resolve_testbench_inputs(graph)
+        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
+            graph
+        )
+        lowered = self._lower_model(model, graph)
+        generated = self._emitter.emit_model(
+            lowered,
+            emit_testbench=self._options.emit_testbench,
+            testbench_inputs=testbench_inputs,
+            variable_dim_inputs=variable_dim_inputs,
+            variable_dim_outputs=variable_dim_outputs,
+        )
+        weight_data = self._emitter.collect_weight_data(lowered.constants)
+        return generated, weight_data
+
+    def compile_with_data_file_and_weight_data(
+        self, model: onnx.ModelProto
+    ) -> tuple[str, str, bytes | None]:
+        graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
+        testbench_inputs = self._resolve_testbench_inputs(graph)
+        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
+            graph
+        )
+        lowered = self._lower_model(model, graph)
+        generated, data_source = self._emitter.emit_model_with_data_file(
+            lowered,
+            emit_testbench=self._options.emit_testbench,
+            testbench_inputs=testbench_inputs,
+            variable_dim_inputs=variable_dim_inputs,
+            variable_dim_outputs=variable_dim_outputs,
+        )
+        weight_data = self._emitter.collect_weight_data(lowered.constants)
+        return generated, data_source, weight_data
+
     @staticmethod
     def _collect_variable_dims(
         graph: Graph,
282
360
  resolved[name] = tuple(array.ravel().tolist())
283
361
  return resolved
284
362
 
363
+ def _concretize_graph_shapes(
364
+ self, model: onnx.ModelProto, graph: Graph
365
+ ) -> Graph:
366
+ if not self._options.testbench_inputs:
367
+ return graph
368
+ if not any(value.type.dim_params for value in graph.values):
369
+ if not any(value.type.dim_params for value in graph.inputs):
370
+ if not any(value.type.dim_params for value in graph.outputs):
371
+ return graph
372
+ try:
373
+ import onnxruntime as ort
374
+ except Exception:
375
+ return graph
376
+ try:
377
+ model_with_outputs = onnx.ModelProto()
378
+ model_with_outputs.CopyFrom(model)
379
+ existing_outputs = {
380
+ output.name for output in model_with_outputs.graph.output
381
+ }
382
+ value_info_by_name = {
383
+ value_info.name: value_info
384
+ for value_info in model_with_outputs.graph.value_info
385
+ }
386
+ for value in graph.values:
387
+ if value.name in existing_outputs:
388
+ continue
389
+ value_info = value_info_by_name.get(value.name)
390
+ if value_info is None:
391
+ dims: list[int | str | None] = []
392
+ for index, dim in enumerate(value.type.shape):
393
+ dim_param = None
394
+ if index < len(value.type.dim_params):
395
+ dim_param = value.type.dim_params[index]
396
+ dims.append(dim_param if dim_param else None)
397
+ elem_type = _onnx_elem_type(value.type.dtype.np_dtype)
398
+ value_info = onnx.helper.make_tensor_value_info(
399
+ value.name, elem_type, dims
400
+ )
401
+ model_with_outputs.graph.output.append(value_info)
402
+ existing_outputs.add(value.name)
403
+ output_names = [output.name for output in model_with_outputs.graph.output]
404
+ sess_options = make_deterministic_session_options(ort)
405
+ sess = ort.InferenceSession(
406
+ model_with_outputs.SerializeToString(),
407
+ sess_options=sess_options,
408
+ providers=["CPUExecutionProvider"],
409
+ )
410
+ output_arrays = sess.run(None, self._options.testbench_inputs)
411
+ except Exception:
412
+ return graph
413
+
414
+ shapes_by_name: dict[str, tuple[int, ...]] = {
415
+ name: tuple(int(dim) for dim in array.shape)
416
+ for name, array in zip(output_names, output_arrays)
417
+ }
418
+ for name, array in self._options.testbench_inputs.items():
419
+ shapes_by_name[name] = tuple(int(dim) for dim in array.shape)
420
+
421
+ def concretize_value(value: Value) -> Value:
422
+ shape = shapes_by_name.get(value.name)
423
+ if shape is None:
424
+ return value
425
+ return Value(
426
+ name=value.name,
427
+ type=TensorType(
428
+ dtype=value.type.dtype,
429
+ shape=shape,
430
+ dim_params=(None,) * len(shape),
431
+ ),
432
+ )
433
+
434
+ return Graph(
435
+ inputs=tuple(concretize_value(value) for value in graph.inputs),
436
+ outputs=tuple(concretize_value(value) for value in graph.outputs),
437
+ nodes=graph.nodes,
438
+ initializers=graph.initializers,
439
+ values=tuple(concretize_value(value) for value in graph.values),
440
+ opset_imports=graph.opset_imports,
441
+ )
442
+
285
443
  def _validate_graph(self, graph: Graph) -> None:
286
444
  if not graph.outputs:
287
445
  raise UnsupportedOpError("Graph must have at least one output")
288
446
  if not graph.nodes:
289
447
  raise UnsupportedOpError("Graph must contain at least one node")
290
448
  for value in graph.outputs:
291
- element_count = shape_product(value.type.shape)
292
- if element_count <= 0:
293
- raise ShapeInferenceError("Output shape must be fully defined")
449
+ shape_product(value.type.shape)
294
450
 
295
451
  def _collect_io_specs(
296
452
  self, graph: Graph
@@ -330,11 +486,14 @@
             | UnaryOp
             | ClipOp
             | CastOp
+            | QuantizeLinearOp
             | MatMulOp
             | GemmOp
             | AttentionOp
             | ConvOp
+            | ConvTransposeOp
             | AveragePoolOp
+            | LpPoolOp
             | BatchNormOp
             | LpNormalizationOp
             | InstanceNormalizationOp
@@ -346,12 +505,15 @@
             | LstmOp
             | SoftmaxOp
             | LogSoftmaxOp
+            | HardmaxOp
             | NegativeLogLikelihoodLossOp
             | SoftmaxCrossEntropyLossOp
             | MaxPoolOp
             | ConcatOp
             | GatherElementsOp
             | GatherOp
+            | GatherNDOp
+            | ScatterNDOp
             | TransposeOp
             | ConstantOfShapeOp
             | ReshapeOp
@@ -362,9 +524,11 @@
             | ArgReduceOp
             | ShapeOp
             | PadOp
+            | NonZeroOp
             | ExpandOp
             | CumSumOp
             | RangeOp
+            | OneHotOp
             | SplitOp
         ],
         list[NodeInfo],
@@ -375,11 +539,14 @@
             | UnaryOp
             | ClipOp
             | CastOp
+            | QuantizeLinearOp
             | MatMulOp
             | GemmOp
             | AttentionOp
             | ConvOp
+            | ConvTransposeOp
             | AveragePoolOp
+            | LpPoolOp
             | BatchNormOp
             | LpNormalizationOp
             | InstanceNormalizationOp
@@ -391,12 +558,14 @@
             | LstmOp
             | SoftmaxOp
             | LogSoftmaxOp
+            | HardmaxOp
             | NegativeLogLikelihoodLossOp
             | SoftmaxCrossEntropyLossOp
             | MaxPoolOp
             | ConcatOp
             | GatherElementsOp
             | GatherOp
+            | GatherNDOp
             | TransposeOp
             | ConstantOfShapeOp
             | ReshapeOp
@@ -406,9 +575,11 @@
             | ArgReduceOp
             | ShapeOp
             | PadOp
+            | NonZeroOp
             | ExpandOp
             | CumSumOp
             | RangeOp
+            | OneHotOp
             | SplitOp
             | WhereOp
         ] = []
@@ -515,6 +686,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
         op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
         if op_spec is None:
             raise UnsupportedOpError("Unsupported op BitShift")
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
         output_shape = value_shape(graph, node.outputs[0], node)
         return BinaryOp(
             input0=node.inputs[0],
@@ -522,6 +695,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
             output=node.outputs[0],
             function=function,
             operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
             shape=output_shape,
             dtype=op_dtype,
             input_dtype=op_dtype,
@@ -555,6 +730,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
             raise UnsupportedOpError(
                 f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
             )
+        input0_shape = value_shape(graph, node.inputs[0], node)
+        input1_shape = value_shape(graph, node.inputs[1], node)
         output_shape = value_shape(graph, node.outputs[0], node)
         return BinaryOp(
             input0=node.inputs[0],
@@ -562,6 +739,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
             output=node.outputs[0],
             function=function,
             operator_kind=op_spec.kind,
+            input0_shape=input0_shape,
+            input1_shape=input1_shape,
             shape=output_shape,
             dtype=output_dtype,
             input_dtype=input_dtype,
@@ -576,6 +755,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
         raise UnsupportedOpError(
             f"{node.op_type} must have 2 inputs and 1 output"
         )
+    input0_shape = value_shape(graph, node.inputs[0], node)
+    input1_shape = value_shape(graph, node.inputs[1], node)
     output_shape = value_shape(graph, node.outputs[0], node)
     return BinaryOp(
         input0=node.inputs[0],
@@ -583,6 +764,8 @@ def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
         output=node.outputs[0],
         function=function,
         operator_kind=op_spec.kind,
+        input0_shape=input0_shape,
+        input1_shape=input1_shape,
         shape=output_shape,
         dtype=op_dtype,
         input_dtype=op_dtype,
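Every lowered `BinaryOp` now records `input0_shape` and `input1_shape` alongside the output shape, which is what an emitter needs to index broadcast operands correctly. For orientation, a generic sketch of ONNX multidirectional broadcasting semantics (an illustration only, not this package's emitter logic):

def broadcast_shape(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
    # ONNX multidirectional broadcasting: right-align the shapes, pad the
    # shorter one with 1s, and let size-1 dimensions stretch to match.
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + a
    b = (1,) * (rank - len(b)) + b
    out = []
    for da, db in zip(a, b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes do not broadcast: {da} vs {db}")
        out.append(max(da, db))
    return tuple(out)

assert broadcast_shape((2, 1, 4), (3, 1)) == (2, 3, 4)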
emx_onnx_cgen/ir/model.py CHANGED
@@ -44,6 +44,7 @@ class Graph:
     nodes: tuple[Node, ...]
     initializers: tuple[Initializer, ...]
     values: tuple[Value, ...] = ()
+    opset_imports: tuple[tuple[str, int], ...] = ()
 
     def find_value(self, name: str) -> Value:
         for value in self.inputs + self.outputs + self.values:
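`Graph` now carries the model's opset imports as `(domain, version)` pairs. A minimal sketch of how downstream code could read the default-domain opset version (a hypothetical helper, not part of this release):

def default_opset_version(graph: Graph) -> int | None:
    # The default ONNX operator set is registered under the empty-string
    # domain; "ai.onnx" is an accepted alias.
    for domain, version in graph.opset_imports:
        if domain in ("", "ai.onnx"):
            return version
    return None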