emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (42)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +340 -59
  4. emx_onnx_cgen/codegen/c_emitter.py +2369 -111
  5. emx_onnx_cgen/compiler.py +188 -5
  6. emx_onnx_cgen/ir/model.py +1 -0
  7. emx_onnx_cgen/lowering/common.py +379 -2
  8. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  9. emx_onnx_cgen/lowering/einsum.py +153 -0
  10. emx_onnx_cgen/lowering/gather_elements.py +1 -3
  11. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  12. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  13. emx_onnx_cgen/lowering/hardmax.py +53 -0
  14. emx_onnx_cgen/lowering/identity.py +6 -5
  15. emx_onnx_cgen/lowering/logsoftmax.py +5 -1
  16. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  17. emx_onnx_cgen/lowering/matmul.py +6 -7
  18. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +12 -12
  19. emx_onnx_cgen/lowering/nonzero.py +42 -0
  20. emx_onnx_cgen/lowering/one_hot.py +120 -0
  21. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  22. emx_onnx_cgen/lowering/reduce.py +5 -6
  23. emx_onnx_cgen/lowering/reshape.py +223 -51
  24. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  25. emx_onnx_cgen/lowering/softmax.py +5 -1
  26. emx_onnx_cgen/lowering/squeeze.py +5 -5
  27. emx_onnx_cgen/lowering/topk.py +116 -0
  28. emx_onnx_cgen/lowering/trilu.py +89 -0
  29. emx_onnx_cgen/lowering/unsqueeze.py +5 -5
  30. emx_onnx_cgen/onnx_import.py +4 -0
  31. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  32. emx_onnx_cgen/ops.py +4 -0
  33. emx_onnx_cgen/runtime/evaluator.py +460 -42
  34. emx_onnx_cgen/testbench.py +23 -0
  35. emx_onnx_cgen/verification.py +61 -0
  36. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/METADATA +31 -5
  37. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/RECORD +42 -25
  38. shared/scalar_functions.py +49 -17
  39. shared/ulp.py +48 -0
  40. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/WHEEL +0 -0
  41. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,153 @@
+ from __future__ import annotations
+
+ from ..codegen.c_emitter import EinsumKind, EinsumOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import node_dtype as _node_dtype
+ from .common import value_shape as _value_shape
+ from .registry import register_lowering
+
+
+ def _normalize_equation(equation: str) -> str:
+     return equation.replace(" ", "")
+
+
+ @register_lowering("Einsum")
+ def lower_einsum(graph: Graph, node: Node) -> EinsumOp:
+     if not node.inputs or len(node.outputs) != 1:
+         raise UnsupportedOpError("Einsum must have 1 output and at least 1 input")
+     equation_value = node.attrs.get("equation")
+     if equation_value is None:
+         raise UnsupportedOpError("Einsum equation attribute is required")
+     equation = (
+         equation_value.decode()
+         if isinstance(equation_value, (bytes, bytearray))
+         else str(equation_value)
+     )
+     normalized = _normalize_equation(equation)
+     input_shapes = tuple(
+         _value_shape(graph, name, node) for name in node.inputs
+     )
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+     if normalized == "->":
+         if len(node.inputs) != 1:
+             raise UnsupportedOpError("Einsum '->' must have 1 input")
+         if output_shape:
+             raise ShapeInferenceError(
+                 "Einsum '->' output must be scalar, "
+                 f"got shape {output_shape}"
+             )
+         kind = EinsumKind.REDUCE_ALL
+     elif normalized == "ij->i":
+         if len(node.inputs) != 1:
+             raise UnsupportedOpError("Einsum 'ij->i' must have 1 input")
+         input_shape = input_shapes[0]
+         if len(input_shape) != 2:
+             raise ShapeInferenceError(
+                 "Einsum 'ij->i' input must be 2D, "
+                 f"got shape {input_shape}"
+             )
+         expected = (input_shape[0],)
+         if output_shape != expected:
+             raise ShapeInferenceError(
+                 f"Einsum 'ij->i' output must match shape {expected}, "
+                 f"got {output_shape}"
+             )
+         kind = EinsumKind.SUM_J
+     elif normalized == "ij->ji":
+         if len(node.inputs) != 1:
+             raise UnsupportedOpError("Einsum 'ij->ji' must have 1 input")
+         input_shape = input_shapes[0]
+         if len(input_shape) != 2:
+             raise ShapeInferenceError(
+                 "Einsum 'ij->ji' input must be 2D, "
+                 f"got shape {input_shape}"
+             )
+         expected = (input_shape[1], input_shape[0])
+         if output_shape != expected:
+             raise ShapeInferenceError(
+                 f"Einsum 'ij->ji' output must match shape {expected}, "
+                 f"got {output_shape}"
+             )
+         kind = EinsumKind.TRANSPOSE
+     elif normalized in {"i,i", "i,i->"}:
+         if len(node.inputs) != 2:
+             raise UnsupportedOpError("Einsum 'i,i' must have 2 inputs")
+         left_shape, right_shape = input_shapes
+         if len(left_shape) != 1 or len(right_shape) != 1:
+             raise ShapeInferenceError(
+                 "Einsum 'i,i' inputs must be vectors, "
+                 f"got shapes {left_shape} and {right_shape}"
+             )
+         if left_shape[0] != right_shape[0]:
+             raise ShapeInferenceError(
+                 "Einsum 'i,i' inputs must have the same length, "
+                 f"got shapes {left_shape} and {right_shape}"
+             )
+         if output_shape:
+             raise ShapeInferenceError(
+                 "Einsum 'i,i' output must be scalar, "
+                 f"got shape {output_shape}"
+             )
+         kind = EinsumKind.DOT
+     elif normalized == "bij,bjk->bik":
+         if len(node.inputs) != 2:
+             raise UnsupportedOpError("Einsum 'bij,bjk->bik' must have 2 inputs")
+         left_shape, right_shape = input_shapes
+         if len(left_shape) != 3 or len(right_shape) != 3:
+             raise ShapeInferenceError(
+                 "Einsum 'bij,bjk->bik' inputs must be 3D, "
+                 f"got shapes {left_shape} and {right_shape}"
+             )
+         if left_shape[0] != right_shape[0]:
+             raise ShapeInferenceError(
+                 "Einsum 'bij,bjk->bik' batch dimensions must match, "
+                 f"got shapes {left_shape} and {right_shape}"
+             )
+         if left_shape[2] != right_shape[1]:
+             raise ShapeInferenceError(
+                 "Einsum 'bij,bjk->bik' contraction dimensions must match, "
+                 f"got shapes {left_shape} and {right_shape}"
+             )
+         expected = (left_shape[0], left_shape[1], right_shape[2])
+         if output_shape != expected:
+             raise ShapeInferenceError(
+                 f"Einsum 'bij,bjk->bik' output must match shape {expected}, "
+                 f"got {output_shape}"
+             )
+         kind = EinsumKind.BATCH_MATMUL
+     elif normalized == "...ii->...i":
+         if len(node.inputs) != 1:
+             raise UnsupportedOpError("Einsum '...ii->...i' must have 1 input")
+         input_shape = input_shapes[0]
+         if len(input_shape) < 2:
+             raise ShapeInferenceError(
+                 "Einsum '...ii->...i' input must be at least 2D, "
+                 f"got shape {input_shape}"
+             )
+         if input_shape[-1] != input_shape[-2]:
+             raise ShapeInferenceError(
+                 "Einsum '...ii->...i' requires last two dims to match, "
+                 f"got shape {input_shape}"
+             )
+         expected = (*input_shape[:-2], input_shape[-1])
+         if output_shape != expected:
+             raise ShapeInferenceError(
+                 f"Einsum '...ii->...i' output must match shape {expected}, "
+                 f"got {output_shape}"
+             )
+         kind = EinsumKind.BATCH_DIAGONAL
+     else:
+         raise UnsupportedOpError(
+             f"Unsupported Einsum equation '{equation}'"
+         )
+     return EinsumOp(
+         inputs=tuple(node.inputs),
+         output=node.outputs[0],
+         kind=kind,
+         input_shapes=input_shapes,
+         output_shape=output_shape,
+         dtype=op_dtype,
+         input_dtype=op_dtype,
+     )
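For orientation, the equation patterns accepted above map onto the following reference semantics. This is only an illustrative numpy sketch of what each supported pattern computes; numpy is used purely as a reference here and is not part of this package or its generated C:

    import numpy as np

    a = np.arange(6.0).reshape(2, 3)          # 2x3 matrix
    v = np.array([1.0, 2.0, 3.0])
    b = np.arange(24.0).reshape(2, 3, 4)      # batch of 3x4 matrices
    c = np.arange(24.0).reshape(2, 4, 3)      # batch of 4x3 matrices
    d = np.arange(18.0).reshape(2, 3, 3)      # batch of square matrices

    print(np.einsum("ij->i", a).shape)            # SUM_J: row sums -> (2,)
    print(np.einsum("ij->ji", a).shape)           # TRANSPOSE -> (3, 2)
    print(np.einsum("i,i->", v, v))               # DOT: scalar inner product -> 14.0
    print(np.einsum("bij,bjk->bik", b, c).shape)  # BATCH_MATMUL -> (2, 3, 3)
    print(np.einsum("...ii->...i", d).shape)      # BATCH_DIAGONAL -> (2, 3)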
@@ -33,9 +33,7 @@ def lower_gather_elements(graph: Graph, node: Node) -> GatherElementsOp:
      for dim_index, (data_dim, index_dim) in enumerate(
          zip(data_shape, indices_shape)
      ):
-         if dim_index == axis:
-             continue
-         if data_dim != index_dim:
+         if dim_index != axis and data_dim != index_dim:
              raise ShapeInferenceError(
                  "GatherElements inputs must match on non-axis dimensions, "
                  f"got {data_shape} and {indices_shape}"
@@ -0,0 +1,79 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import GatherNDOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import value_dtype as _value_dtype
+ from .common import value_shape as _value_shape
+ from .registry import register_lowering
+
+
+ @register_lowering("GatherND")
+ def lower_gather_nd(graph: Graph, node: Node) -> GatherNDOp:
+     if len(node.inputs) != 2 or len(node.outputs) != 1:
+         raise UnsupportedOpError("GatherND must have 2 inputs and 1 output")
+     data_name, indices_name = node.inputs
+     output_name = node.outputs[0]
+     data_shape = _value_shape(graph, data_name, node)
+     indices_shape = _value_shape(graph, indices_name, node)
+     output_shape = _value_shape(graph, output_name, node)
+     if len(indices_shape) < 1:
+         raise ShapeInferenceError("GatherND indices must have rank >= 1")
+     batch_dims = int(node.attrs.get("batch_dims", 0))
+     if batch_dims < 0:
+         raise ShapeInferenceError(
+             f"GatherND batch_dims must be >= 0, got {batch_dims}"
+         )
+     if batch_dims > len(indices_shape) - 1:
+         raise ShapeInferenceError(
+             "GatherND batch_dims must be <= indices rank - 1, "
+             f"got {batch_dims} vs {len(indices_shape) - 1}"
+         )
+     if batch_dims > len(data_shape):
+         raise ShapeInferenceError(
+             "GatherND batch_dims must be <= data rank, "
+             f"got {batch_dims} vs {len(data_shape)}"
+         )
+     if tuple(data_shape[:batch_dims]) != tuple(indices_shape[:batch_dims]):
+         raise ShapeInferenceError(
+             "GatherND batch_dims must match on data/indices, "
+             f"got {data_shape} vs {indices_shape}"
+         )
+     index_depth = indices_shape[-1]
+     if index_depth <= 0:
+         raise ShapeInferenceError(
+             "GatherND indices final dimension must be >= 1"
+         )
+     if index_depth > len(data_shape) - batch_dims:
+         raise ShapeInferenceError(
+             "GatherND indices final dimension must be <= data rank - "
+             f"batch_dims, got {index_depth} vs {len(data_shape) - batch_dims}"
+         )
+     expected_output_shape = indices_shape[:-1] + data_shape[
+         batch_dims + index_depth :
+     ]
+     if output_shape != expected_output_shape:
+         raise ShapeInferenceError(
+             "GatherND output shape must be "
+             f"{expected_output_shape}, got {output_shape}"
+         )
+     data_dtype = _value_dtype(graph, data_name, node)
+     indices_dtype = _value_dtype(graph, indices_name, node)
+     if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+         raise UnsupportedOpError(
+             "GatherND indices must be int32 or int64, "
+             f"got {indices_dtype.onnx_name}"
+         )
+     return GatherNDOp(
+         data=data_name,
+         indices=indices_name,
+         output=output_name,
+         batch_dims=batch_dims,
+         data_shape=data_shape,
+         indices_shape=indices_shape,
+         output_shape=output_shape,
+         dtype=data_dtype,
+         indices_dtype=indices_dtype,
+     )
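The output-shape rule enforced above, output_shape == indices_shape[:-1] + data_shape[batch_dims + index_depth:], can be checked with a small worked example; the shapes below are hypothetical:

    # Hypothetical GatherND shapes, plugged into the rule the lowering checks.
    data_shape = (2, 3, 4, 5)        # batch_dims = 1
    indices_shape = (2, 6, 2)        # index_depth = indices_shape[-1] = 2
    batch_dims = 1
    index_depth = indices_shape[-1]
    expected = indices_shape[:-1] + data_shape[batch_dims + index_depth:]
    print(expected)                  # (2, 6, 5)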
@@ -0,0 +1,59 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import ReduceOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import value_dtype as _value_dtype
+ from .common import value_shape as _value_shape
+ from .registry import register_lowering
+
+
+ @register_lowering("GlobalMaxPool")
+ def lower_global_max_pool(graph: Graph, node: Node) -> ReduceOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("GlobalMaxPool must have 1 input and 1 output")
+     if node.attrs:
+         raise UnsupportedOpError("GlobalMaxPool has unsupported attributes")
+     op_dtype = _value_dtype(graph, node.inputs[0], node)
+     output_dtype = _value_dtype(graph, node.outputs[0], node)
+     if op_dtype != output_dtype:
+         raise UnsupportedOpError(
+             "GlobalMaxPool expects matching input/output dtypes, "
+             f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     if op_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+         raise UnsupportedOpError(
+             "GlobalMaxPool supports float16, float, and double inputs only"
+         )
+     input_shape = _value_shape(graph, node.inputs[0], node)
+     if len(input_shape) < 3:
+         raise UnsupportedOpError(
+             "GlobalMaxPool expects input rank of at least 3"
+         )
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     expected_output_shape = (input_shape[0], input_shape[1]) + (
+         1,
+     ) * (len(input_shape) - 2)
+     if output_shape != expected_output_shape:
+         raise ShapeInferenceError(
+             "GlobalMaxPool output shape must be "
+             f"{expected_output_shape}, got {output_shape}"
+         )
+     axes = tuple(range(2, len(input_shape)))
+     return ReduceOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         input_shape=input_shape,
+         output_shape=output_shape,
+         axes=axes,
+         axes_input=None,
+         axes_input_shape=None,
+         axes_input_dtype=None,
+         keepdims=True,
+         noop_with_empty_axes=False,
+         reduce_kind="max",
+         reduce_count=None,
+         dtype=op_dtype,
+     )
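GlobalMaxPool is lowered to a max ReduceOp over all spatial axes with keepdims enabled. A minimal numpy sketch of the expected reference result (numpy is used only for illustration, not by the package):

    import numpy as np

    x = np.random.rand(1, 3, 8, 8).astype(np.float32)   # NCHW input
    axes = tuple(range(2, x.ndim))                       # spatial axes, as in the lowering
    y = np.max(x, axis=axes, keepdims=True)              # GlobalMaxPool reference result
    print(y.shape)                                       # (1, 3, 1, 1)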
@@ -0,0 +1,53 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import HardmaxOp
+ from ..errors import UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import node_dtype as _node_dtype
+ from .common import onnx_opset_version as _onnx_opset_version
+ from .common import shape_product as _shape_product
+ from .common import value_shape as _value_shape
+ from .registry import register_lowering
+ from ..validation import ensure_output_shape_matches_input
+ from ..validation import normalize_axis as _normalize_axis
+
+
+ @register_lowering("Hardmax")
+ def lower_hardmax(graph: Graph, node: Node) -> HardmaxOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Hardmax must have 1 input and 1 output")
+     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+     if op_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+         raise UnsupportedOpError(
+             "Hardmax supports float16, float, and double inputs only"
+         )
+     input_shape = _value_shape(graph, node.inputs[0], node)
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     ensure_output_shape_matches_input(node, input_shape, output_shape)
+     opset_version = _onnx_opset_version(graph)
+     default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+     axis_attr = node.attrs.get("axis", default_axis)
+     axis = _normalize_axis(
+         int(axis_attr),
+         input_shape,
+         node,
+     )
+     outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
+     axis_size = input_shape[axis]
+     inner = (
+         _shape_product(input_shape[axis + 1 :])
+         if axis + 1 < len(input_shape)
+         else 1
+     )
+     return HardmaxOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         outer=outer,
+         axis_size=axis_size,
+         inner=inner,
+         axis=axis,
+         shape=input_shape,
+         dtype=op_dtype,
+     )
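The outer/axis_size/inner values above flatten the tensor around the normalized axis so the emitter can loop over outer*inner slices of length axis_size. A small sketch of the same arithmetic, assuming _shape_product behaves like math.prod over a dimension tuple:

    from math import prod

    shape = (2, 3, 4, 5)       # hypothetical input shape
    axis = 2                   # normalized, non-negative axis
    outer = prod(shape[:axis]) if axis > 0 else 1                      # 6
    axis_size = shape[axis]                                            # 4
    inner = prod(shape[axis + 1:]) if axis + 1 < len(shape) else 1     # 5
    assert outer * axis_size * inner == prod(shape)                    # 6 * 4 * 5 == 120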
@@ -22,11 +22,12 @@ def lower_identity(graph: Graph, node: Node) -> IdentityOp:
      for index, (input_dim, output_dim) in enumerate(
          zip(input_shape, output_shape)
      ):
-         if input_dim == output_dim:
-             continue
-         if input_dim_params[index] or output_dim_params[index]:
-             continue
-         raise ShapeInferenceError("Identity input and output shapes must match")
+         if input_dim != output_dim and not (
+             input_dim_params[index] or output_dim_params[index]
+         ):
+             raise ShapeInferenceError(
+                 "Identity input and output shapes must match"
+             )
      input_dtype = value_dtype(graph, node.inputs[0], node)
      output_dtype = value_dtype(graph, node.outputs[0], node)
      if input_dtype != output_dtype:
@@ -4,6 +4,7 @@ from ..codegen.c_emitter import LogSoftmaxOp
  from ..errors import UnsupportedOpError
  from ..ir.model import Graph, Node
  from .common import node_dtype as _node_dtype
+ from .common import onnx_opset_version as _onnx_opset_version
  from .common import shape_product as _shape_product
  from .common import value_shape as _value_shape
  from .registry import register_lowering
@@ -23,8 +24,11 @@ def lower_logsoftmax(graph: Graph, node: Node) -> LogSoftmaxOp:
      input_shape = _value_shape(graph, node.inputs[0], node)
      output_shape = _value_shape(graph, node.outputs[0], node)
      ensure_output_shape_matches_input(node, input_shape, output_shape)
+     opset_version = _onnx_opset_version(graph)
+     default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+     axis_attr = node.attrs.get("axis", default_axis)
      axis = _normalize_axis(
-         int(node.attrs.get("axis", -1)),
+         int(axis_attr),
          input_shape,
          node,
      )
@@ -0,0 +1,141 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ from ..codegen.c_emitter import LpPoolOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .registry import register_lowering
+ from .common import value_dtype as _value_dtype, value_shape as _value_shape
+
+
+ @dataclass(frozen=True)
+ class LpPoolSpec:
+     batch: int
+     channels: int
+     in_h: int
+     in_w: int
+     out_h: int
+     out_w: int
+     kernel_h: int
+     kernel_w: int
+     stride_h: int
+     stride_w: int
+     pad_top: int
+     pad_left: int
+     pad_bottom: int
+     pad_right: int
+     p: int
+
+
+ def _resolve_lp_pool_spec(graph: Graph, node: Node) -> LpPoolSpec:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("LpPool must have 1 input and 1 output")
+     supported_attrs = {
+         "auto_pad",
+         "ceil_mode",
+         "dilations",
+         "kernel_shape",
+         "pads",
+         "p",
+         "strides",
+     }
+     if set(node.attrs) - supported_attrs:
+         raise UnsupportedOpError("LpPool has unsupported attributes")
+     auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+     if isinstance(auto_pad, bytes):
+         auto_pad = auto_pad.decode("utf-8", errors="ignore")
+     if auto_pad not in ("", "NOTSET"):
+         raise UnsupportedOpError("LpPool supports auto_pad=NOTSET only")
+     ceil_mode = int(node.attrs.get("ceil_mode", 0))
+     if ceil_mode != 0:
+         raise UnsupportedOpError("LpPool supports ceil_mode=0 only")
+     dilations = tuple(int(value) for value in node.attrs.get("dilations", (1, 1)))
+     if any(value != 1 for value in dilations):
+         raise UnsupportedOpError("LpPool supports dilations=1 only")
+     kernel_shape = node.attrs.get("kernel_shape")
+     if kernel_shape is None:
+         raise UnsupportedOpError("LpPool requires kernel_shape")
+     kernel_shape = tuple(int(value) for value in kernel_shape)
+     if len(kernel_shape) != 2:
+         raise UnsupportedOpError("LpPool expects 2D kernel_shape")
+     kernel_h, kernel_w = kernel_shape
+     strides = tuple(int(value) for value in node.attrs.get("strides", (1, 1)))
+     if len(strides) != 2:
+         raise UnsupportedOpError("LpPool expects 2D strides")
+     pads = tuple(int(value) for value in node.attrs.get("pads", (0, 0, 0, 0)))
+     if len(pads) != 4:
+         raise UnsupportedOpError("LpPool expects 4D pads")
+     pad_top, pad_left, pad_bottom, pad_right = pads
+     p = int(node.attrs.get("p", 2))
+     if p < 1:
+         raise UnsupportedOpError("LpPool p must be >= 1")
+     input_shape = _value_shape(graph, node.inputs[0], node)
+     if len(input_shape) != 4:
+         raise UnsupportedOpError("LpPool supports NCHW 2D inputs only")
+     batch, channels, in_h, in_w = input_shape
+     stride_h, stride_w = strides
+     out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
+     out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1
+     if out_h < 0 or out_w < 0:
+         raise ShapeInferenceError("LpPool output shape must be non-negative")
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     expected_output_shape = (batch, channels, out_h, out_w)
+     if output_shape != expected_output_shape:
+         raise ShapeInferenceError(
+             "LpPool output shape must be "
+             f"{expected_output_shape}, got {output_shape}"
+         )
+     return LpPoolSpec(
+         batch=batch,
+         channels=channels,
+         in_h=in_h,
+         in_w=in_w,
+         out_h=out_h,
+         out_w=out_w,
+         kernel_h=kernel_h,
+         kernel_w=kernel_w,
+         stride_h=stride_h,
+         stride_w=stride_w,
+         pad_top=pad_top,
+         pad_left=pad_left,
+         pad_bottom=pad_bottom,
+         pad_right=pad_right,
+         p=p,
+     )
+
+
+ @register_lowering("LpPool")
+ def lower_lp_pool(graph: Graph, node: Node) -> LpPoolOp:
+     op_dtype = _value_dtype(graph, node.inputs[0], node)
+     output_dtype = _value_dtype(graph, node.outputs[0], node)
+     if op_dtype != output_dtype:
+         raise UnsupportedOpError(
+             "LpPool expects matching input/output dtypes, "
+             f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "LpPool supports float16, float, and double inputs only"
+         )
+     spec = _resolve_lp_pool_spec(graph, node)
+     return LpPoolOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         batch=spec.batch,
+         channels=spec.channels,
+         in_h=spec.in_h,
+         in_w=spec.in_w,
+         out_h=spec.out_h,
+         out_w=spec.out_w,
+         kernel_h=spec.kernel_h,
+         kernel_w=spec.kernel_w,
+         stride_h=spec.stride_h,
+         stride_w=spec.stride_w,
+         pad_top=spec.pad_top,
+         pad_left=spec.pad_left,
+         pad_bottom=spec.pad_bottom,
+         pad_right=spec.pad_right,
+         p=spec.p,
+         dtype=op_dtype,
+     )
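The out_h/out_w expressions above are the standard floor-division pooling size (no ceil_mode, no dilation). A worked example with hypothetical values:

    # Hypothetical 2D LpPool configuration, plugged into the formula used above.
    in_h, in_w = 32, 32
    kernel_h, kernel_w = 3, 3
    stride_h, stride_w = 2, 2
    pad_top = pad_bottom = pad_left = pad_right = 1

    out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1   # (32 + 2 - 3) // 2 + 1 = 16
    out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1   # 16
    print(out_h, out_w)                                                # 16 16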
@@ -87,13 +87,12 @@ def _broadcast_batch_shapes(
      right_padded = (1,) * (max_rank - len(right)) + right
      broadcast_shape = []
      for left_dim, right_dim in zip(left_padded, right_padded):
-         if left_dim == right_dim or left_dim == 1 or right_dim == 1:
-             broadcast_shape.append(max(left_dim, right_dim))
-             continue
-         raise ShapeInferenceError(
-             "MatMul batch dimensions must be broadcastable, "
-             f"got {left} x {right}"
-         )
+         if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
+             raise ShapeInferenceError(
+                 "MatMul batch dimensions must be broadcastable, "
+                 f"got {left} x {right}"
+             )
+         broadcast_shape.append(max(left_dim, right_dim))
      return tuple(broadcast_shape), left_padded, right_padded


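The restructured loop keeps the same broadcasting rule as before: a pair of batch dims is compatible when they are equal or when either is 1. A small sketch of that rule, cross-checked against numpy.broadcast_shapes (illustrative only; the shapes are hypothetical):

    import numpy as np

    left_padded, right_padded = (2, 1, 4), (1, 5, 4)   # hypothetical batch shapes, already padded to equal rank
    broadcast_shape = []
    for left_dim, right_dim in zip(left_padded, right_padded):
        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
            raise ValueError(f"not broadcastable: {left_padded} x {right_padded}")
        broadcast_shape.append(max(left_dim, right_dim))
    print(tuple(broadcast_shape))                          # (2, 5, 4)
    print(np.broadcast_shapes(left_padded, right_padded))  # same result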
@@ -43,18 +43,18 @@ def _resolve_target_shape(
              raise ShapeInferenceError("Reshape allows only one -1 dimension")
          unknown_index = index
          output_dims.append(-1)
-         continue
-     if dim == 0:
-         if allowzero == 0:
-             if index >= len(input_shape):
-                 raise ShapeInferenceError(
-                     "Reshape zero dim must index into input shape"
-                 )
-             dim = input_shape[index]
-     if dim < 0:
-         raise ShapeInferenceError("Reshape dims must be >= -1")
-     output_dims.append(dim)
-     known_product *= dim
+     else:
+         if dim == 0:
+             if allowzero == 0:
+                 if index >= len(input_shape):
+                     raise ShapeInferenceError(
+                         "Reshape zero dim must index into input shape"
+                     )
+                 dim = input_shape[index]
+         if dim < 0:
+             raise ShapeInferenceError("Reshape dims must be >= -1")
+         output_dims.append(dim)
+         known_product *= dim
      input_product = _shape_product(input_shape)
      if unknown_index is not None:
          if known_product == 0 or input_product % known_product != 0:
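The new else branch preserves the existing semantics: with allowzero=0 a 0 entry copies the corresponding input dim, and a single -1 entry is inferred from the remaining product. A simplified sketch of that resolution with the error checks omitted and hypothetical shapes:

    # Simplified mirror of the target-shape resolution above (no validation).
    input_shape = (2, 3, 4)
    target = (0, -1)           # with allowzero=0, 0 copies the matching input dim
    allowzero = 0

    resolved = []
    known = 1
    unknown_index = None
    for index, dim in enumerate(target):
        if dim == -1:
            unknown_index = index
            resolved.append(-1)
        else:
            if dim == 0 and allowzero == 0:
                dim = input_shape[index]
            resolved.append(dim)
            known *= dim
    if unknown_index is not None:
        resolved[unknown_index] = (2 * 3 * 4) // known
    print(resolved)            # [2, 12]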
@@ -0,0 +1,42 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import NonZeroOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import value_dtype, value_shape
+ from .registry import register_lowering
+
+
+ @register_lowering("NonZero")
+ def lower_nonzero(graph: Graph, node: Node) -> NonZeroOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("NonZero must have 1 input and 1 output")
+     input_shape = value_shape(graph, node.inputs[0], node)
+     if len(input_shape) == 0:
+         raise UnsupportedOpError("NonZero does not support scalar inputs")
+     output_shape = value_shape(graph, node.outputs[0], node)
+     if len(output_shape) != 2:
+         raise ShapeInferenceError("NonZero output must be 2D")
+     if output_shape[0] != len(input_shape):
+         raise ShapeInferenceError(
+             "NonZero output shape must be "
+             f"({len(input_shape)}, N), got {output_shape}"
+         )
+     if output_shape[0] < 0 or output_shape[1] < 0:
+         raise ShapeInferenceError(
+             "NonZero output shape must be non-negative"
+         )
+     output_dtype = value_dtype(graph, node.outputs[0], node)
+     if output_dtype != ScalarType.I64:
+         raise UnsupportedOpError("NonZero output dtype must be int64")
+     input_dtype = value_dtype(graph, node.inputs[0], node)
+     return NonZeroOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         input_shape=input_shape,
+         output_shape=output_shape,
+         dtype=output_dtype,
+         input_dtype=input_dtype,
+     )
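The lowering enforces the ONNX NonZero output layout of (rank, N) int64 coordinates. For reference, the same layout as produced by numpy (illustrative only, not part of the package):

    import numpy as np

    x = np.array([[1, 0], [0, 3]])
    coords = np.array(np.nonzero(x), dtype=np.int64)   # ONNX NonZero layout: (rank, N)
    print(coords.shape)        # (2, 2) -> rank 2, two nonzero elements
    print(coords)              # [[0 1]
                               #  [0 1]] -> elements at (0, 0) and (1, 1)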