emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (99)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +372 -64
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +169 -343
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +193 -0
  11. emx_onnx_cgen/ir/op_context.py +65 -0
  12. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  13. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  14. emx_onnx_cgen/ir/ops/misc.py +421 -0
  15. emx_onnx_cgen/ir/ops/nn.py +580 -0
  16. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  17. emx_onnx_cgen/lowering/__init__.py +79 -1
  18. emx_onnx_cgen/lowering/adagrad.py +114 -0
  19. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  20. emx_onnx_cgen/lowering/attention.py +1 -1
  21. emx_onnx_cgen/lowering/average_pool.py +1 -1
  22. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  23. emx_onnx_cgen/lowering/cast.py +1 -1
  24. emx_onnx_cgen/lowering/common.py +406 -11
  25. emx_onnx_cgen/lowering/concat.py +1 -1
  26. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  27. emx_onnx_cgen/lowering/conv.py +1 -1
  28. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  29. emx_onnx_cgen/lowering/cumsum.py +1 -1
  30. emx_onnx_cgen/lowering/depth_space.py +1 -1
  31. emx_onnx_cgen/lowering/dropout.py +1 -1
  32. emx_onnx_cgen/lowering/einsum.py +153 -0
  33. emx_onnx_cgen/lowering/elementwise.py +152 -4
  34. emx_onnx_cgen/lowering/expand.py +1 -1
  35. emx_onnx_cgen/lowering/eye_like.py +1 -1
  36. emx_onnx_cgen/lowering/flatten.py +1 -1
  37. emx_onnx_cgen/lowering/gather.py +1 -1
  38. emx_onnx_cgen/lowering/gather_elements.py +2 -4
  39. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  40. emx_onnx_cgen/lowering/gemm.py +1 -1
  41. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  42. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  43. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  44. emx_onnx_cgen/lowering/hardmax.py +53 -0
  45. emx_onnx_cgen/lowering/identity.py +7 -6
  46. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  48. emx_onnx_cgen/lowering/logsoftmax.py +6 -2
  49. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  50. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  51. emx_onnx_cgen/lowering/lrn.py +1 -1
  52. emx_onnx_cgen/lowering/lstm.py +1 -1
  53. emx_onnx_cgen/lowering/matmul.py +7 -8
  54. emx_onnx_cgen/lowering/maxpool.py +1 -1
  55. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  56. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
  57. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  58. emx_onnx_cgen/lowering/nonzero.py +42 -0
  59. emx_onnx_cgen/lowering/one_hot.py +120 -0
  60. emx_onnx_cgen/lowering/pad.py +1 -1
  61. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  62. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  63. emx_onnx_cgen/lowering/range.py +1 -1
  64. emx_onnx_cgen/lowering/reduce.py +6 -7
  65. emx_onnx_cgen/lowering/registry.py +24 -5
  66. emx_onnx_cgen/lowering/reshape.py +224 -52
  67. emx_onnx_cgen/lowering/resize.py +1 -1
  68. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  69. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  70. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  71. emx_onnx_cgen/lowering/shape.py +6 -25
  72. emx_onnx_cgen/lowering/size.py +1 -1
  73. emx_onnx_cgen/lowering/slice.py +1 -1
  74. emx_onnx_cgen/lowering/softmax.py +6 -2
  75. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  76. emx_onnx_cgen/lowering/split.py +1 -1
  77. emx_onnx_cgen/lowering/squeeze.py +6 -6
  78. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  79. emx_onnx_cgen/lowering/tile.py +1 -1
  80. emx_onnx_cgen/lowering/topk.py +134 -0
  81. emx_onnx_cgen/lowering/transpose.py +1 -1
  82. emx_onnx_cgen/lowering/trilu.py +89 -0
  83. emx_onnx_cgen/lowering/unsqueeze.py +6 -6
  84. emx_onnx_cgen/lowering/variadic.py +1 -1
  85. emx_onnx_cgen/lowering/where.py +1 -1
  86. emx_onnx_cgen/onnx_import.py +4 -0
  87. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  88. emx_onnx_cgen/ops.py +4 -0
  89. emx_onnx_cgen/runtime/evaluator.py +785 -43
  90. emx_onnx_cgen/testbench.py +23 -0
  91. emx_onnx_cgen/verification.py +31 -0
  92. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
  93. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  94. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  95. shared/scalar_functions.py +60 -17
  96. shared/ulp.py +65 -0
  97. emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
  98. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  99. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
@@ -2,32 +2,13 @@ from __future__ import annotations
 
  from shared.scalar_types import ScalarType
 
- from ..codegen.c_emitter import ShapeOp
+ from ..ir.ops import ShapeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
+ from .common import value_dtype, value_shape
  from .registry import register_lowering
 
 
- def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
-     try:
-         return graph.find_value(name).type.shape
-     except KeyError as exc:
-         raise ShapeInferenceError(
-             f"Missing shape for value '{name}' in op {node.op_type}. "
-             "Hint: run ONNX shape inference or export with static shapes."
-         ) from exc
-
-
- def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
-     try:
-         return graph.find_value(name).type.dtype
-     except KeyError as exc:
-         raise ShapeInferenceError(
-             f"Missing dtype for value '{name}' in op {node.op_type}. "
-             "Hint: run ONNX shape inference or export with static shapes."
-         ) from exc
-
-
  def _normalize_slice_bounds(
      rank: int, *, start: int | None, end: int | None
  ) -> tuple[int, int]:
@@ -46,14 +27,14 @@ def _normalize_slice_bounds(
  def lower_shape(graph: Graph, node: Node) -> ShapeOp:
      if len(node.inputs) != 1 or len(node.outputs) != 1:
          raise UnsupportedOpError("Shape must have 1 input and 1 output")
-     input_shape = _value_shape(graph, node.inputs[0], node)
-     output_shape = _value_shape(graph, node.outputs[0], node)
+     input_shape = value_shape(graph, node.inputs[0], node)
+     output_shape = value_shape(graph, node.outputs[0], node)
      if len(output_shape) != 1:
          raise ShapeInferenceError("Shape output must be 1D")
      if output_shape[0] < 0:
          raise ShapeInferenceError("Shape output length must be non-negative")
-     input_dtype = _value_dtype(graph, node.inputs[0], node)
-     output_dtype = _value_dtype(graph, node.outputs[0], node)
+     input_dtype = value_dtype(graph, node.inputs[0], node)
+     output_dtype = value_dtype(graph, node.outputs[0], node)
      if output_dtype != ScalarType.I64:
          raise UnsupportedOpError("Shape output dtype must be int64")
      start = node.attrs.get("start")
@@ -2,7 +2,7 @@ from __future__ import annotations
 
  from shared.scalar_types import ScalarType
 
- from ..codegen.c_emitter import SizeOp
+ from ..ir.ops import SizeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import shape_product, value_dtype, value_shape
@@ -6,7 +6,7 @@ import numpy as np
 
  from shared.scalar_types import ScalarType
 
- from ..codegen.c_emitter import SliceOp
+ from ..ir.ops import SliceOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Initializer, Node
  from ..lowering.common import value_dtype, value_shape
@@ -1,9 +1,10 @@
  from __future__ import annotations
 
- from ..codegen.c_emitter import SoftmaxOp
+ from ..ir.ops import SoftmaxOp
  from ..errors import UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import node_dtype as _node_dtype
+ from .common import onnx_opset_version as _onnx_opset_version
  from .common import shape_product as _shape_product
  from .common import value_shape as _value_shape
  from .registry import register_lowering
@@ -23,8 +24,11 @@ def lower_softmax(graph: Graph, node: Node) -> SoftmaxOp:
      input_shape = _value_shape(graph, node.inputs[0], node)
      output_shape = _value_shape(graph, node.outputs[0], node)
      ensure_output_shape_matches_input(node, input_shape, output_shape)
+     opset_version = _onnx_opset_version(graph)
+     default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+     axis_attr = node.attrs.get("axis", default_axis)
      axis = _normalize_axis(
-         int(node.attrs.get("axis", -1)),
+         int(axis_attr),
          input_shape,
          node,
      )
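Note: the three added lines only adjust the default. ONNX defines Softmax's default axis as 1 up to opset 12 and -1 from opset 13 onward, so the lowering now consults the model's opset before falling back. Below is a minimal numpy sketch of how that default plays out; it is illustrative only, does not reproduce the opset < 13 coerce-to-2D semantics, and the helper name is not from the package.

import numpy as np

def softmax_reference(x: np.ndarray, opset_version: int | None) -> np.ndarray:
    # Mirror the default-axis rule added above: 1 before opset 13, -1 afterwards.
    axis = 1 if opset_version is not None and opset_version < 13 else -1
    shifted = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)

x = np.random.default_rng(0).standard_normal((2, 3, 4))
# The two defaults agree for 2-D inputs but diverge for rank >= 3.
print(np.allclose(softmax_reference(x, 12), softmax_reference(x, 13)))  # False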
@@ -2,7 +2,7 @@ from __future__ import annotations
 
  from shared.scalar_types import ScalarType
 
- from ..codegen.c_emitter import SoftmaxCrossEntropyLossOp
+ from ..ir.ops import SoftmaxCrossEntropyLossOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import shape_product as _shape_product
@@ -4,7 +4,7 @@ import numpy as np
 
  from shared.scalar_types import ScalarType
 
- from ..codegen.c_emitter import SplitOp
+ from ..ir.ops import SplitOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Initializer, Node
  from ..lowering.common import optional_name, value_dtype, value_shape
@@ -2,7 +2,7 @@ from __future__ import annotations
 
  from shared.scalar_types import ScalarType
 
- from ..codegen.c_emitter import ReshapeOp
+ from ..ir.ops import ReshapeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Initializer, Node
  from .registry import register_lowering
@@ -95,11 +95,11 @@ def _validate_output_shape_for_unknown_axes(
      for dim in input_shape:
          if output_index < len(output_shape) and dim == output_shape[output_index]:
              output_index += 1
-             continue
-         if dim != 1:
-             raise ShapeInferenceError(
-                 "Squeeze output shape must remove only dimensions of size 1"
-             )
+         else:
+             if dim != 1:
+                 raise ShapeInferenceError(
+                     "Squeeze output shape must remove only dimensions of size 1"
+                 )
      if output_index != len(output_shape):
          raise ShapeInferenceError(
              "Squeeze output shape must preserve input order while removing size-1 axes"
@@ -0,0 +1,110 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..ir.ops import TensorScatterOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..validation import normalize_axis
+ from .common import optional_name, value_dtype, value_shape
+ from .registry import register_lowering
+
+ _ALLOWED_MODES = {"linear", "circular"}
+
+
+ @register_lowering("TensorScatter")
+ def lower_tensor_scatter(graph: Graph, node: Node) -> TensorScatterOp:
+     if len(node.inputs) not in {2, 3} or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             "TensorScatter must have 2 or 3 inputs and 1 output"
+         )
+     past_cache_name = node.inputs[0]
+     update_name = node.inputs[1]
+     write_indices_name = optional_name(node.inputs, 2)
+     output_name = node.outputs[0]
+     past_cache_shape = value_shape(graph, past_cache_name, node)
+     update_shape = value_shape(graph, update_name, node)
+     output_shape = value_shape(graph, output_name, node)
+     if output_shape != past_cache_shape:
+         raise ShapeInferenceError(
+             "TensorScatter output shape must match past_cache shape, "
+             f"got {output_shape} vs {past_cache_shape}"
+         )
+     if len(update_shape) != len(past_cache_shape):
+         raise ShapeInferenceError(
+             "TensorScatter update shape rank must match past_cache rank, "
+             f"got {len(update_shape)} vs {len(past_cache_shape)}"
+         )
+     axis = normalize_axis(int(node.attrs.get("axis", -2)), past_cache_shape, node)
+     if axis == 0:
+         raise UnsupportedOpError(
+             "TensorScatter axis cannot be 0 (batch dimension)"
+         )
+     for dim_index, (past_dim, update_dim) in enumerate(
+         zip(past_cache_shape, update_shape)
+     ):
+         if dim_index == axis:
+             if update_dim > past_dim:
+                 raise ShapeInferenceError(
+                     "TensorScatter update sequence length must be <= "
+                     "past_cache sequence length, "
+                     f"got {update_dim} vs {past_dim}"
+                 )
+         elif update_dim != past_dim:
+             raise ShapeInferenceError(
+                 "TensorScatter update shape must match past_cache shape "
+                 f"outside axis {axis}, got {update_shape} vs {past_cache_shape}"
+             )
+     mode = node.attrs.get("mode", "linear")
+     if isinstance(mode, bytes):
+         mode = mode.decode("utf-8")
+     if mode not in _ALLOWED_MODES:
+         raise UnsupportedOpError(
+             "TensorScatter mode must be one of "
+             f"{sorted(_ALLOWED_MODES)}, got {mode}"
+         )
+     dtype = value_dtype(graph, past_cache_name, node)
+     update_dtype = value_dtype(graph, update_name, node)
+     output_dtype = value_dtype(graph, output_name, node)
+     if update_dtype != dtype or output_dtype != dtype:
+         raise UnsupportedOpError(
+             "TensorScatter expects past_cache, update, and output "
+             "to share the same dtype, "
+             f"got {dtype.onnx_name}, {update_dtype.onnx_name}, "
+             f"{output_dtype.onnx_name}"
+         )
+     write_indices_shape = None
+     write_indices_dtype = None
+     if write_indices_name is not None:
+         write_indices_shape = value_shape(graph, write_indices_name, node)
+         if len(write_indices_shape) != 1:
+             raise ShapeInferenceError(
+                 "TensorScatter write_indices must be a 1D tensor"
+             )
+         if write_indices_shape[0] != past_cache_shape[0]:
+             raise ShapeInferenceError(
+                 "TensorScatter write_indices length must match batch size, "
+                 f"got {write_indices_shape[0]} vs {past_cache_shape[0]}"
+             )
+         write_indices_dtype = value_dtype(
+             graph, write_indices_name, node
+         )
+         if write_indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+             raise UnsupportedOpError(
+                 "TensorScatter write_indices must be int32 or int64, "
+                 f"got {write_indices_dtype.onnx_name}"
+             )
+     return TensorScatterOp(
+         past_cache=past_cache_name,
+         update=update_name,
+         write_indices=write_indices_name,
+         output=output_name,
+         past_cache_shape=past_cache_shape,
+         update_shape=update_shape,
+         output_shape=output_shape,
+         write_indices_shape=write_indices_shape,
+         axis=axis,
+         mode=mode,
+         dtype=dtype,
+         write_indices_dtype=write_indices_dtype,
+     )
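Note: for orientation, a rough numpy sketch of the cache-update semantics this lowering appears to assume, restricted to a rank-3 cache and the default axis of -2; the function name and shapes are illustrative, and circular mode simply wraps write positions around the cache length.

import numpy as np

def tensor_scatter_reference(past_cache, update, write_indices=None, mode="linear"):
    # past_cache: (batch, max_seq, hidden); update: (batch, new_seq, hidden)
    batch, max_seq, _ = past_cache.shape
    new_seq = update.shape[1]
    if write_indices is None:
        write_indices = np.zeros(batch, dtype=np.int64)
    out = past_cache.copy()
    for b in range(batch):
        positions = write_indices[b] + np.arange(new_seq)
        if mode == "circular":
            positions %= max_seq  # wrap around the cache instead of overflowing
        out[b, positions, :] = update[b]
    return out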
@@ -4,7 +4,7 @@ import numpy as np
 
  from shared.scalar_types import ScalarType
 
- from ..codegen.c_emitter import TileOp
+ from ..ir.ops import TileOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Initializer, Node
  from ..lowering.common import value_dtype, value_shape
@@ -0,0 +1,134 @@
+ from __future__ import annotations
+
+ import numpy as np
+
+ from shared.scalar_types import ScalarType
+
+ from ..ir.ops import TopKOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Initializer, Node
+ from ..lowering.common import shape_product, value_dtype, value_shape
+ from ..validation import normalize_axis
+ from .registry import register_lowering
+
+
+ def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+     for initializer in graph.initializers:
+         if initializer.name == name:
+             return initializer
+     return None
+
+
+ def _read_k(graph: Graph, name: str, node: Node) -> int | None:
+     initializer = _find_initializer(graph, name)
+     if initializer is None:
+         return None
+     if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+         raise UnsupportedOpError(
+             f"{node.op_type} k input must be int64 or int32"
+         )
+     data = np.array(initializer.data, dtype=np.int64).reshape(-1)
+     if data.size != 1:
+         raise ShapeInferenceError(
+             f"{node.op_type} k input must contain a single value"
+         )
+     k = int(data[0])
+     if k <= 0:
+         raise ShapeInferenceError(
+             f"{node.op_type} k must be a positive value, got {k}"
+         )
+     return k
+
+
+ def _topk_dtype_supported(dtype: ScalarType) -> bool:
+     return not dtype.is_bool
+
+
+ def lower_topk(graph: Graph, node: Node) -> TopKOp:
+     if node.op_type != "TopK":
+         raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+     if len(node.inputs) != 2 or len(node.outputs) != 2:
+         raise UnsupportedOpError(
+             f"{node.op_type} must have 2 inputs and 2 outputs"
+         )
+     input_name = node.inputs[0]
+     k_name = node.inputs[1]
+     output_values = node.outputs[0]
+     output_indices = node.outputs[1]
+     input_shape = value_shape(graph, input_name, node)
+     shape_product(input_shape)
+     axis = int(node.attrs.get("axis", -1))
+     axis = normalize_axis(axis, input_shape, node)
+     k = _read_k(graph, k_name, node)
+     axis_dim = input_shape[axis]
+     values_shape = value_shape(graph, output_values, node)
+     indices_shape = value_shape(graph, output_indices, node)
+     if values_shape != indices_shape:
+         raise ShapeInferenceError(
+             f"{node.op_type} values and indices output shapes must match, "
+             f"got {values_shape} and {indices_shape}"
+         )
+     if k is None:
+         k_shape = value_shape(graph, k_name, node)
+         if len(k_shape) != 1 or k_shape[0] != 1:
+             raise ShapeInferenceError(
+                 f"{node.op_type} k input must be a 1-element tensor"
+             )
+         if axis >= len(values_shape):
+             raise ShapeInferenceError(
+                 f"{node.op_type} axis {axis} exceeds output rank {len(values_shape)}"
+             )
+         k = values_shape[axis]
+     if k <= 0:
+         raise ShapeInferenceError(
+             f"{node.op_type} k must be a positive value, got {k}"
+         )
+     if k > axis_dim:
+         raise ShapeInferenceError(
+             f"{node.op_type} k {k} exceeds axis dimension {axis_dim}"
+         )
+     output_shape_expected = list(input_shape)
+     output_shape_expected[axis] = k
+     output_shape = tuple(output_shape_expected)
+     if values_shape != output_shape:
+         raise ShapeInferenceError(
+             f"{node.op_type} values output shape must be {output_shape}, got {values_shape}"
+         )
+     if indices_shape != output_shape:
+         raise ShapeInferenceError(
+             f"{node.op_type} indices output shape must be {output_shape}, got {indices_shape}"
+         )
+     input_dtype = value_dtype(graph, input_name, node)
+     if not _topk_dtype_supported(input_dtype):
+         raise UnsupportedOpError(
+             f"{node.op_type} does not support dtype {input_dtype.onnx_name}"
+         )
+     values_dtype = value_dtype(graph, output_values, node)
+     if values_dtype != input_dtype:
+         raise UnsupportedOpError(
+             f"{node.op_type} values output dtype must be {input_dtype.onnx_name}"
+         )
+     indices_dtype = value_dtype(graph, output_indices, node)
+     if indices_dtype != ScalarType.I64:
+         raise UnsupportedOpError(
+             f"{node.op_type} indices output dtype must be int64"
+         )
+     largest = bool(int(node.attrs.get("largest", 1)))
+     sorted_output = bool(int(node.attrs.get("sorted", 1)))
+     return TopKOp(
+         input0=input_name,
+         output_values=output_values,
+         output_indices=output_indices,
+         input_shape=input_shape,
+         output_shape=output_shape,
+         axis=axis,
+         k=k,
+         largest=largest,
+         sorted=sorted_output,
+         input_dtype=input_dtype,
+         output_values_dtype=values_dtype,
+         output_indices_dtype=indices_dtype,
+     )
+
+
+ register_lowering("TopK")(lower_topk)
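Note: as a hedged point of comparison, a small numpy reference for the op being lowered, assuming a float input with the ONNX defaults largest=1 and sorted=1; tie-breaking and the sorted=0 case (where any order is allowed) are not modeled.

import numpy as np

def topk_reference(x: np.ndarray, k: int, axis: int = -1, largest: bool = True):
    # Return the k largest (or smallest) values along `axis`, sorted, together with
    # their int64 indices; the output shape is the input shape with shape[axis]
    # replaced by k, which is exactly the shape rule the lowering above checks.
    order = np.argsort(-x if largest else x, axis=axis, kind="stable")
    idx = np.take(order, np.arange(k), axis=axis)
    values = np.take_along_axis(x, idx, axis=axis)
    return values, idx.astype(np.int64)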
@@ -1,6 +1,6 @@
  from __future__ import annotations
 
- from ..codegen.c_emitter import TransposeOp
+ from ..ir.ops import TransposeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import node_dtype as _node_dtype
@@ -0,0 +1,89 @@
+ from __future__ import annotations
+
+ import numpy as np
+
+ from shared.scalar_types import ScalarType
+
+ from ..ir.ops import TriluOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Initializer, Node
+ from ..lowering.common import optional_name, value_dtype, value_shape
+ from .registry import register_lowering
+
+
+ def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+     for initializer in graph.initializers:
+         if initializer.name == name:
+             return initializer
+     return None
+
+
+ def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+     return shape == () or shape == (1,)
+
+
+ def _read_k_initializer(initializer: Initializer, node: Node) -> int:
+     if initializer.type.dtype != ScalarType.I64:
+         raise UnsupportedOpError(
+             f"{node.op_type} k input must be int64"
+         )
+     data = np.array(initializer.data, dtype=np.int64).reshape(-1)
+     if data.size != 1:
+         raise UnsupportedOpError(f"{node.op_type} k input must be scalar")
+     return int(data[0])
+
+
+ @register_lowering("Trilu")
+ def lower_trilu(graph: Graph, node: Node) -> TriluOp:
+     if len(node.inputs) not in {1, 2} or len(node.outputs) != 1:
+         raise UnsupportedOpError("Trilu must have 1 or 2 inputs and 1 output")
+     input_name = node.inputs[0]
+     if not input_name:
+         raise UnsupportedOpError("Trilu input must be provided")
+     input_shape = value_shape(graph, input_name, node)
+     output_shape = value_shape(graph, node.outputs[0], node)
+     if input_shape != output_shape:
+         raise ShapeInferenceError("Trilu input and output shapes must match")
+     if len(output_shape) < 2:
+         raise UnsupportedOpError("Trilu expects input rank >= 2")
+     input_dtype = value_dtype(graph, input_name, node)
+     output_dtype = value_dtype(graph, node.outputs[0], node)
+     if input_dtype != output_dtype:
+         raise UnsupportedOpError(
+             "Trilu expects matching input/output dtypes, "
+             f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     upper_attr = node.attrs.get("upper", 1)
+     upper = bool(int(upper_attr))
+     k_input = optional_name(node.inputs, 1)
+     k_value = 0
+     k_input_name = None
+     k_input_shape = None
+     k_input_dtype = None
+     if k_input:
+         k_initializer = _find_initializer(graph, k_input)
+         if k_initializer is not None:
+             k_value = _read_k_initializer(k_initializer, node)
+         else:
+             k_shape = value_shape(graph, k_input, node)
+             if not _is_scalar_shape(k_shape):
+                 raise UnsupportedOpError("Trilu k input must be scalar")
+             k_dtype = value_dtype(graph, k_input, node)
+             if k_dtype != ScalarType.I64:
+                 raise UnsupportedOpError("Trilu k input must be int64")
+             k_input_name = k_input
+             k_input_shape = k_shape
+             k_input_dtype = k_dtype
+     return TriluOp(
+         input0=input_name,
+         output=node.outputs[0],
+         input_shape=input_shape,
+         output_shape=output_shape,
+         upper=upper,
+         k_value=k_value,
+         k_input=k_input_name,
+         k_input_shape=k_input_shape,
+         k_input_dtype=k_input_dtype,
+         dtype=output_dtype,
+         input_dtype=input_dtype,
+     )
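Note: the lowering above only validates shapes, dtypes, and the optional k input; semantically, ONNX Trilu is numpy's triu/tril. A minimal sketch (np.triu and np.tril already apply to the last two axes for rank > 2, matching the rank >= 2 check above):

import numpy as np

def trilu_reference(x: np.ndarray, k: int = 0, upper: bool = True) -> np.ndarray:
    # Keep the upper (or lower) triangle relative to the k-th diagonal,
    # zeroing everything else; batch dimensions are handled by numpy itself.
    return np.triu(x, k=k) if upper else np.tril(x, k=k)

a = np.arange(12, dtype=np.float32).reshape(3, 4)
print(trilu_reference(a, k=1, upper=False))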
@@ -2,7 +2,7 @@ from __future__ import annotations
 
  from shared.scalar_types import ScalarType
 
- from ..codegen.c_emitter import ReshapeOp
+ from ..ir.ops import ReshapeOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Initializer, Node
  from .registry import register_lowering
@@ -131,11 +131,11 @@ def lower_unsqueeze(graph: Graph, node: Node) -> ReshapeOp:
      for dim in output_shape:
          if input_index < len(input_shape) and dim == input_shape[input_index]:
              input_index += 1
-             continue
-         if dim != 1:
-             raise ShapeInferenceError(
-                 "Unsqueeze output shape must insert ones only"
-             )
+         else:
+             if dim != 1:
+                 raise ShapeInferenceError(
+                     "Unsqueeze output shape must insert ones only"
+                 )
      if input_index != len(input_shape):
          raise ShapeInferenceError(
              "Unsqueeze output shape must contain input shape in order"
@@ -3,7 +3,7 @@ from __future__ import annotations
  from shared.scalar_functions import ScalarFunction
  from shared.scalar_types import ScalarType
 
- from ..codegen.c_emitter import MultiInputBinaryOp
+ from ..ir.ops import MultiInputBinaryOp
  from ..errors import UnsupportedOpError
  from ..ir.model import Graph, Node
  from ..lowering.common import node_dtype, value_dtype, value_shape
@@ -2,7 +2,7 @@ from __future__ import annotations
 
  from shared.scalar_types import ScalarType
 
- from ..codegen.c_emitter import WhereOp
+ from ..ir.ops import WhereOp
  from ..errors import ShapeInferenceError, UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import value_dtype as _value_dtype
@@ -212,6 +212,9 @@ def import_onnx(model: onnx.ModelProto) -> Graph:
      dim_param_by_name = _collect_dim_params(
          tuple(model.graph.input) + tuple(model.graph.output)
      )
+     opset_imports = tuple(
+         (opset.domain, opset.version) for opset in model.opset_import
+     )
      try:
          model = shape_inference.infer_shapes(model, data_prop=True)
      except Exception as exc:  # pragma: no cover - onnx inference errors
@@ -258,4 +261,5 @@ def import_onnx(model: onnx.ModelProto) -> Graph:
          nodes=nodes,
          initializers=initializers,
          values=values,
+         opset_imports=opset_imports,
      )
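Note: the captured (domain, version) pairs are what opset-sensitive lowerings (such as the Softmax default-axis change above) consult later. For orientation, a minimal sketch of reading the same information straight from an onnx ModelProto; the file path and the default_opset selection are illustrative.

import onnx

model = onnx.load("model.onnx")  # illustrative path
opset_imports = tuple((entry.domain, entry.version) for entry in model.opset_import)
# The default ONNX operator set is usually recorded under the empty domain "".
default_opset = next(
    (version for domain, version in opset_imports if domain in ("", "ai.onnx")),
    None,
)
print(opset_imports, default_opset)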
@@ -0,0 +1,11 @@
+ from __future__ import annotations
+
+ from typing import Any
+
+
+ def make_deterministic_session_options(ort: Any) -> Any:
+     options = ort.SessionOptions()
+     options.intra_op_num_threads = 1
+     options.inter_op_num_threads = 1
+     options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+     return options
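Note: a sketch of how this helper would typically be used when comparing generated code against onnxruntime; pinning both thread pools to 1 and forcing sequential execution keeps the reference outputs reproducible run-to-run. The model path below is illustrative.

import onnxruntime as ort

from emx_onnx_cgen.onnxruntime_utils import make_deterministic_session_options

options = make_deterministic_session_options(ort)
session = ort.InferenceSession("model.onnx", sess_options=options)  # illustrative path
# session.run(...) should then produce the same outputs on repeated runs.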
emx_onnx_cgen/ops.py CHANGED
@@ -87,6 +87,7 @@ UNARY_OP_TYPES = {
      "Identity",
      "LeakyRelu",
      "Log",
+     "Mish",
      "Neg",
      "Not",
      "Reciprocal",
@@ -177,6 +178,7 @@ UNARY_SYMBOLS_DOUBLE = {
      ScalarFunction.LEAKY_RELU: "leaky_relu",
      ScalarFunction.POSITIVE: "identity",
      ScalarFunction.LOG: "log",
+     ScalarFunction.MISH: "mish",
      ScalarFunction.NEG: "neg",
      ScalarFunction.RECIPROCAL: "reciprocal",
      ScalarFunction.RELU: "relu",
@@ -215,6 +217,7 @@ UNARY_SYMBOLS_FLOAT = {
      ScalarFunction.LEAKY_RELU: "leaky_relu",
      ScalarFunction.POSITIVE: "identity",
      ScalarFunction.LOG: "logf",
+     ScalarFunction.MISH: "mish",
      ScalarFunction.NEG: "neg",
      ScalarFunction.RECIPROCAL: "reciprocal",
      ScalarFunction.RELU: "relu",
@@ -457,6 +460,7 @@ UNARY_APPLY_FUNCS = {
      "thresholded_relu": lambda value: np.where(
          value > 1.0, value, 0.0
      ),
+     "mish": lambda value: value * np.tanh(np.log1p(np.exp(value))),
      "atanhf": np.arctanh,
      "atanh": np.arctanh,
  }
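Note: the new entries route the ONNX Mish op through the scalar-function tables, and the apply func above computes mish(x) = x * tanh(softplus(x)) with softplus written as log1p(exp(x)). A slightly more overflow-tolerant formulation of the same function, shown only as a sketch:

import numpy as np

def mish_reference(x: np.ndarray) -> np.ndarray:
    # np.logaddexp(0, x) is softplus(x) without the intermediate exp(x) overflow
    # (and RuntimeWarning) that log1p(exp(x)) can hit for large positive inputs.
    return x * np.tanh(np.logaddexp(0.0, x))

x = np.array([-20.0, -1.0, 0.0, 1.0, 20.0])
print(np.allclose(mish_reference(x), x * np.tanh(np.log1p(np.exp(x)))))  # True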