emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (42)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +340 -59
  4. emx_onnx_cgen/codegen/c_emitter.py +2369 -111
  5. emx_onnx_cgen/compiler.py +188 -5
  6. emx_onnx_cgen/ir/model.py +1 -0
  7. emx_onnx_cgen/lowering/common.py +379 -2
  8. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  9. emx_onnx_cgen/lowering/einsum.py +153 -0
  10. emx_onnx_cgen/lowering/gather_elements.py +1 -3
  11. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  12. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  13. emx_onnx_cgen/lowering/hardmax.py +53 -0
  14. emx_onnx_cgen/lowering/identity.py +6 -5
  15. emx_onnx_cgen/lowering/logsoftmax.py +5 -1
  16. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  17. emx_onnx_cgen/lowering/matmul.py +6 -7
  18. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
  19. emx_onnx_cgen/lowering/nonzero.py +42 -0
  20. emx_onnx_cgen/lowering/one_hot.py +120 -0
  21. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  22. emx_onnx_cgen/lowering/reduce.py +5 -6
  23. emx_onnx_cgen/lowering/reshape.py +223 -51
  24. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  25. emx_onnx_cgen/lowering/softmax.py +5 -1
  26. emx_onnx_cgen/lowering/squeeze.py +5 -5
  27. emx_onnx_cgen/lowering/topk.py +116 -0
  28. emx_onnx_cgen/lowering/trilu.py +89 -0
  29. emx_onnx_cgen/lowering/unsqueeze.py +5 -5
  30. emx_onnx_cgen/onnx_import.py +4 -0
  31. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  32. emx_onnx_cgen/ops.py +4 -0
  33. emx_onnx_cgen/runtime/evaluator.py +460 -42
  34. emx_onnx_cgen/testbench.py +23 -0
  35. emx_onnx_cgen/verification.py +61 -0
  36. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
  37. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
  38. shared/scalar_functions.py +49 -17
  39. shared/ulp.py +48 -0
  40. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
  41. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/softmax.py CHANGED
@@ -4,6 +4,7 @@ from ..codegen.c_emitter import SoftmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
+from .common import onnx_opset_version as _onnx_opset_version
 from .common import shape_product as _shape_product
 from .common import value_shape as _value_shape
 from .registry import register_lowering
@@ -23,8 +24,11 @@ def lower_softmax(graph: Graph, node: Node) -> SoftmaxOp:
     input_shape = _value_shape(graph, node.inputs[0], node)
     output_shape = _value_shape(graph, node.outputs[0], node)
     ensure_output_shape_matches_input(node, input_shape, output_shape)
+    opset_version = _onnx_opset_version(graph)
+    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+    axis_attr = node.attrs.get("axis", default_axis)
     axis = _normalize_axis(
-        int(node.attrs.get("axis", -1)),
+        int(axis_attr),
         input_shape,
         node,
     )
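
Note on the Softmax hunk above: ONNX opset 13 changed the default Softmax axis from 1 to -1, which is what the new default_axis logic encodes. The _onnx_opset_version helper itself is not shown in this diff; the following is only a minimal sketch, under the assumption that it reads the default-domain entry of the new Graph.opset_imports field (the name and behavior here are assumptions, not the package's actual implementation):

    def onnx_opset_version(graph):
        # Hypothetical sketch: return the default-domain ("" or "ai.onnx") opset
        # version recorded in graph.opset_imports, or None when it is absent.
        for domain, version in getattr(graph, "opset_imports", ()) or ():
            if domain in ("", "ai.onnx"):
                return int(version)
        return None
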
emx_onnx_cgen/lowering/squeeze.py CHANGED
@@ -95,11 +95,11 @@ def _validate_output_shape_for_unknown_axes(
     for dim in input_shape:
         if output_index < len(output_shape) and dim == output_shape[output_index]:
             output_index += 1
-            continue
-        if dim != 1:
-            raise ShapeInferenceError(
-                "Squeeze output shape must remove only dimensions of size 1"
-            )
+        else:
+            if dim != 1:
+                raise ShapeInferenceError(
+                    "Squeeze output shape must remove only dimensions of size 1"
+                )
     if output_index != len(output_shape):
         raise ShapeInferenceError(
             "Squeeze output shape must preserve input order while removing size-1 axes"
emx_onnx_cgen/lowering/topk.py ADDED
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import TopKOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import shape_product, value_dtype, value_shape
+from ..validation import normalize_axis
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _read_k(graph: Graph, name: str, node: Node) -> int:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        raise UnsupportedOpError(
+            f"{node.op_type} k input must be a constant initializer"
+        )
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            f"{node.op_type} k input must be int64 or int32"
+        )
+    data = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    if data.size != 1:
+        raise ShapeInferenceError(
+            f"{node.op_type} k input must contain a single value"
+        )
+    k = int(data[0])
+    if k <= 0:
+        raise ShapeInferenceError(
+            f"{node.op_type} k must be a positive value, got {k}"
+        )
+    return k
+
+
+def _topk_dtype_supported(dtype: ScalarType) -> bool:
+    return not dtype.is_bool
+
+
+def lower_topk(graph: Graph, node: Node) -> TopKOp:
+    if node.op_type != "TopK":
+        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+    if len(node.inputs) != 2 or len(node.outputs) != 2:
+        raise UnsupportedOpError(
+            f"{node.op_type} must have 2 inputs and 2 outputs"
+        )
+    input_name = node.inputs[0]
+    k_name = node.inputs[1]
+    output_values = node.outputs[0]
+    output_indices = node.outputs[1]
+    input_shape = value_shape(graph, input_name, node)
+    shape_product(input_shape)
+    axis = int(node.attrs.get("axis", -1))
+    axis = normalize_axis(axis, input_shape, node)
+    k = _read_k(graph, k_name, node)
+    axis_dim = input_shape[axis]
+    if k > axis_dim:
+        raise ShapeInferenceError(
+            f"{node.op_type} k {k} exceeds axis dimension {axis_dim}"
+        )
+    output_shape_expected = list(input_shape)
+    output_shape_expected[axis] = k
+    output_shape = tuple(output_shape_expected)
+    values_shape = value_shape(graph, output_values, node)
+    if values_shape != output_shape:
+        raise ShapeInferenceError(
+            f"{node.op_type} values output shape must be {output_shape}, got {values_shape}"
+        )
+    indices_shape = value_shape(graph, output_indices, node)
+    if indices_shape != output_shape:
+        raise ShapeInferenceError(
+            f"{node.op_type} indices output shape must be {output_shape}, got {indices_shape}"
+        )
+    input_dtype = value_dtype(graph, input_name, node)
+    if not _topk_dtype_supported(input_dtype):
+        raise UnsupportedOpError(
+            f"{node.op_type} does not support dtype {input_dtype.onnx_name}"
+        )
+    values_dtype = value_dtype(graph, output_values, node)
+    if values_dtype != input_dtype:
+        raise UnsupportedOpError(
+            f"{node.op_type} values output dtype must be {input_dtype.onnx_name}"
+        )
+    indices_dtype = value_dtype(graph, output_indices, node)
+    if indices_dtype != ScalarType.I64:
+        raise UnsupportedOpError(
+            f"{node.op_type} indices output dtype must be int64"
+        )
+    largest = bool(int(node.attrs.get("largest", 1)))
+    sorted_output = bool(int(node.attrs.get("sorted", 1)))
+    return TopKOp(
+        input0=input_name,
+        output_values=output_values,
+        output_indices=output_indices,
+        input_shape=input_shape,
+        output_shape=output_shape,
+        axis=axis,
+        k=k,
+        largest=largest,
+        sorted=sorted_output,
+        input_dtype=input_dtype,
+        output_values_dtype=values_dtype,
+        output_indices_dtype=indices_dtype,
+    )
+
+
+register_lowering("TopK")(lower_topk)
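
For intuition about what the lowered TopK kernel has to produce, here is a rough NumPy reference of the semantics this lowering validates (constant k, int64 indices, k no larger than the axis dimension). It is illustrative only, not code from the package, and it does not reproduce ONNX tie-breaking exactly:

    import numpy as np

    def topk_reference(x, k, axis=-1, largest=True):
        # Sort along the axis (descending when largest=True), keep the first k
        # indices, then gather the matching values; indices are int64 as required.
        order = np.argsort(-x if largest else x, axis=axis, kind="stable")
        idx = np.take(order, np.arange(k), axis=axis)
        values = np.take_along_axis(x, idx, axis=axis)
        return values, idx.astype(np.int64)
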
emx_onnx_cgen/lowering/trilu.py ADDED
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import numpy as np
+
+from shared.scalar_types import ScalarType
+
+from ..codegen.c_emitter import TriluOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Initializer, Node
+from ..lowering.common import optional_name, value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+def _read_k_initializer(initializer: Initializer, node: Node) -> int:
+    if initializer.type.dtype != ScalarType.I64:
+        raise UnsupportedOpError(
+            f"{node.op_type} k input must be int64"
+        )
+    data = np.array(initializer.data, dtype=np.int64).reshape(-1)
+    if data.size != 1:
+        raise UnsupportedOpError(f"{node.op_type} k input must be scalar")
+    return int(data[0])
+
+
+@register_lowering("Trilu")
+def lower_trilu(graph: Graph, node: Node) -> TriluOp:
+    if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
+        raise UnsupportedOpError("Trilu must have 1 or 2 inputs and 1 output")
+    input_name = node.inputs[0]
+    if not input_name:
+        raise UnsupportedOpError("Trilu input must be provided")
+    input_shape = value_shape(graph, input_name, node)
+    output_shape = value_shape(graph, node.outputs[0], node)
+    if input_shape != output_shape:
+        raise ShapeInferenceError("Trilu input and output shapes must match")
+    if len(output_shape) < 2:
+        raise UnsupportedOpError("Trilu expects input rank >= 2")
+    input_dtype = value_dtype(graph, input_name, node)
+    output_dtype = value_dtype(graph, node.outputs[0], node)
+    if input_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "Trilu expects matching input/output dtypes, "
+            f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    upper_attr = node.attrs.get("upper", 1)
+    upper = bool(int(upper_attr))
+    k_input = optional_name(node.inputs, 1)
+    k_value = 0
+    k_input_name = None
+    k_input_shape = None
+    k_input_dtype = None
+    if k_input:
+        k_initializer = _find_initializer(graph, k_input)
+        if k_initializer is not None:
+            k_value = _read_k_initializer(k_initializer, node)
+        else:
+            k_shape = value_shape(graph, k_input, node)
+            if not _is_scalar_shape(k_shape):
+                raise UnsupportedOpError("Trilu k input must be scalar")
+            k_dtype = value_dtype(graph, k_input, node)
+            if k_dtype != ScalarType.I64:
+                raise UnsupportedOpError("Trilu k input must be int64")
+            k_input_name = k_input
+            k_input_shape = k_shape
+            k_input_dtype = k_dtype
+    return TriluOp(
+        input0=input_name,
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        upper=upper,
+        k_value=k_value,
+        k_input=k_input_name,
+        k_input_shape=k_input_shape,
+        k_input_dtype=k_input_dtype,
+        dtype=output_dtype,
+        input_dtype=input_dtype,
+    )
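
The Trilu behavior this lowering targets corresponds to NumPy's triu/tril applied to the trailing two axes, with k as the diagonal offset. A small reference sketch for orientation (illustrative only, not part of the package):

    import numpy as np

    def trilu_reference(x, k=0, upper=True):
        # np.triu / np.tril already operate on the final two axes of an N-D
        # array, so the leading batch dimensions are handled automatically.
        return np.triu(x, k) if upper else np.tril(x, k)
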
emx_onnx_cgen/lowering/unsqueeze.py CHANGED
@@ -131,11 +131,11 @@ def lower_unsqueeze(graph: Graph, node: Node) -> ReshapeOp:
     for dim in output_shape:
         if input_index < len(input_shape) and dim == input_shape[input_index]:
             input_index += 1
-            continue
-        if dim != 1:
-            raise ShapeInferenceError(
-                "Unsqueeze output shape must insert ones only"
-            )
+        else:
+            if dim != 1:
+                raise ShapeInferenceError(
+                    "Unsqueeze output shape must insert ones only"
+                )
     if input_index != len(input_shape):
         raise ShapeInferenceError(
             "Unsqueeze output shape must contain input shape in order"
emx_onnx_cgen/onnx_import.py CHANGED
@@ -212,6 +212,9 @@ def import_onnx(model: onnx.ModelProto) -> Graph:
     dim_param_by_name = _collect_dim_params(
         tuple(model.graph.input) + tuple(model.graph.output)
     )
+    opset_imports = tuple(
+        (opset.domain, opset.version) for opset in model.opset_import
+    )
     try:
         model = shape_inference.infer_shapes(model, data_prop=True)
     except Exception as exc: # pragma: no cover - onnx inference errors
@@ -258,4 +261,5 @@ def import_onnx(model: onnx.ModelProto) -> Graph:
         nodes=nodes,
         initializers=initializers,
         values=values,
+        opset_imports=opset_imports,
     )
emx_onnx_cgen/onnxruntime_utils.py ADDED
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from typing import Any
+
+
+def make_deterministic_session_options(ort: Any) -> Any:
+    options = ort.SessionOptions()
+    options.intra_op_num_threads = 1
+    options.inter_op_num_threads = 1
+    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+    return options
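
A minimal usage sketch for the new helper, assuming onnxruntime is installed; the model path and the "input" feed name below are placeholders chosen for illustration, not values taken from this package:

    import numpy as np
    import onnxruntime as ort

    from emx_onnx_cgen.onnxruntime_utils import make_deterministic_session_options

    options = make_deterministic_session_options(ort)
    session = ort.InferenceSession("model.onnx", sess_options=options)
    outputs = session.run(None, {"input": np.zeros((1, 3), dtype=np.float32)})
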
emx_onnx_cgen/ops.py CHANGED
@@ -87,6 +87,7 @@ UNARY_OP_TYPES = {
     "Identity",
     "LeakyRelu",
     "Log",
+    "Mish",
     "Neg",
     "Not",
     "Reciprocal",
@@ -177,6 +178,7 @@ UNARY_SYMBOLS_DOUBLE = {
     ScalarFunction.LEAKY_RELU: "leaky_relu",
     ScalarFunction.POSITIVE: "identity",
     ScalarFunction.LOG: "log",
+    ScalarFunction.MISH: "mish",
     ScalarFunction.NEG: "neg",
     ScalarFunction.RECIPROCAL: "reciprocal",
     ScalarFunction.RELU: "relu",
@@ -215,6 +217,7 @@ UNARY_SYMBOLS_FLOAT = {
     ScalarFunction.LEAKY_RELU: "leaky_relu",
     ScalarFunction.POSITIVE: "identity",
     ScalarFunction.LOG: "logf",
+    ScalarFunction.MISH: "mish",
     ScalarFunction.NEG: "neg",
     ScalarFunction.RECIPROCAL: "reciprocal",
     ScalarFunction.RELU: "relu",
@@ -457,6 +460,7 @@ UNARY_APPLY_FUNCS = {
     "thresholded_relu": lambda value: np.where(
         value > 1.0, value, 0.0
     ),
+    "mish": lambda value: value * np.tanh(np.log1p(np.exp(value))),
     "atanhf": np.arctanh,
     "atanh": np.arctanh,
 }
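
The "mish" reference lambda above implements Mish(x) = x * tanh(softplus(x)) via log1p(exp(x)). As a side note, an overflow-safe softplus can be expressed with np.logaddexp; this is an illustrative alternative, not what the package emits:

    import numpy as np

    def mish_reference(x):
        # logaddexp(0, x) == log(1 + exp(x)) but stays finite for large x.
        return x * np.tanh(np.logaddexp(0.0, x))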