emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (99)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +372 -64
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +169 -343
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +193 -0
  11. emx_onnx_cgen/ir/op_context.py +65 -0
  12. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  13. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  14. emx_onnx_cgen/ir/ops/misc.py +421 -0
  15. emx_onnx_cgen/ir/ops/nn.py +580 -0
  16. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  17. emx_onnx_cgen/lowering/__init__.py +79 -1
  18. emx_onnx_cgen/lowering/adagrad.py +114 -0
  19. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  20. emx_onnx_cgen/lowering/attention.py +1 -1
  21. emx_onnx_cgen/lowering/average_pool.py +1 -1
  22. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  23. emx_onnx_cgen/lowering/cast.py +1 -1
  24. emx_onnx_cgen/lowering/common.py +406 -11
  25. emx_onnx_cgen/lowering/concat.py +1 -1
  26. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  27. emx_onnx_cgen/lowering/conv.py +1 -1
  28. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  29. emx_onnx_cgen/lowering/cumsum.py +1 -1
  30. emx_onnx_cgen/lowering/depth_space.py +1 -1
  31. emx_onnx_cgen/lowering/dropout.py +1 -1
  32. emx_onnx_cgen/lowering/einsum.py +153 -0
  33. emx_onnx_cgen/lowering/elementwise.py +152 -4
  34. emx_onnx_cgen/lowering/expand.py +1 -1
  35. emx_onnx_cgen/lowering/eye_like.py +1 -1
  36. emx_onnx_cgen/lowering/flatten.py +1 -1
  37. emx_onnx_cgen/lowering/gather.py +1 -1
  38. emx_onnx_cgen/lowering/gather_elements.py +2 -4
  39. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  40. emx_onnx_cgen/lowering/gemm.py +1 -1
  41. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  42. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  43. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  44. emx_onnx_cgen/lowering/hardmax.py +53 -0
  45. emx_onnx_cgen/lowering/identity.py +7 -6
  46. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  48. emx_onnx_cgen/lowering/logsoftmax.py +6 -2
  49. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  50. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  51. emx_onnx_cgen/lowering/lrn.py +1 -1
  52. emx_onnx_cgen/lowering/lstm.py +1 -1
  53. emx_onnx_cgen/lowering/matmul.py +7 -8
  54. emx_onnx_cgen/lowering/maxpool.py +1 -1
  55. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  56. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
  57. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  58. emx_onnx_cgen/lowering/nonzero.py +42 -0
  59. emx_onnx_cgen/lowering/one_hot.py +120 -0
  60. emx_onnx_cgen/lowering/pad.py +1 -1
  61. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  62. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  63. emx_onnx_cgen/lowering/range.py +1 -1
  64. emx_onnx_cgen/lowering/reduce.py +6 -7
  65. emx_onnx_cgen/lowering/registry.py +24 -5
  66. emx_onnx_cgen/lowering/reshape.py +224 -52
  67. emx_onnx_cgen/lowering/resize.py +1 -1
  68. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  69. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  70. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  71. emx_onnx_cgen/lowering/shape.py +6 -25
  72. emx_onnx_cgen/lowering/size.py +1 -1
  73. emx_onnx_cgen/lowering/slice.py +1 -1
  74. emx_onnx_cgen/lowering/softmax.py +6 -2
  75. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  76. emx_onnx_cgen/lowering/split.py +1 -1
  77. emx_onnx_cgen/lowering/squeeze.py +6 -6
  78. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  79. emx_onnx_cgen/lowering/tile.py +1 -1
  80. emx_onnx_cgen/lowering/topk.py +134 -0
  81. emx_onnx_cgen/lowering/transpose.py +1 -1
  82. emx_onnx_cgen/lowering/trilu.py +89 -0
  83. emx_onnx_cgen/lowering/unsqueeze.py +6 -6
  84. emx_onnx_cgen/lowering/variadic.py +1 -1
  85. emx_onnx_cgen/lowering/where.py +1 -1
  86. emx_onnx_cgen/onnx_import.py +4 -0
  87. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  88. emx_onnx_cgen/ops.py +4 -0
  89. emx_onnx_cgen/runtime/evaluator.py +785 -43
  90. emx_onnx_cgen/testbench.py +23 -0
  91. emx_onnx_cgen/verification.py +31 -0
  92. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
  93. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  94. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  95. shared/scalar_functions.py +60 -17
  96. shared/ulp.py +65 -0
  97. emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
  98. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  99. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,7 @@ import numpy as np
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ExpandOp
+from ..ir.ops import ExpandOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from ..lowering.common import value_dtype, value_shape
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import EyeLikeOp
+from ..ir.ops import EyeLikeOp
 from ..dtypes import scalar_type_from_onnx
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import ReshapeOp
+from ..ir.ops import ReshapeOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import shape_product, value_dtype, value_shape
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import GatherOp
+from ..ir.ops import GatherOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import normalize_axis
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import GatherElementsOp
+from ..ir.ops import GatherElementsOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import normalize_axis
@@ -33,9 +33,7 @@ def lower_gather_elements(graph: Graph, node: Node) -> GatherElementsOp:
     for dim_index, (data_dim, index_dim) in enumerate(
         zip(data_shape, indices_shape)
    ):
-        if dim_index == axis:
-            continue
-        if data_dim != index_dim:
+        if dim_index != axis and data_dim != index_dim:
            raise ShapeInferenceError(
                "GatherElements inputs must match on non-axis dimensions, "
                f"got {data_shape} and {indices_shape}"
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import GatherNDOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@register_lowering("GatherND")
+def lower_gather_nd(graph: Graph, node: Node) -> GatherNDOp:
+    if len(node.inputs) != 2 or len(node.outputs) != 1:
+        raise UnsupportedOpError("GatherND must have 2 inputs and 1 output")
+    data_name, indices_name = node.inputs
+    output_name = node.outputs[0]
+    data_shape = _value_shape(graph, data_name, node)
+    indices_shape = _value_shape(graph, indices_name, node)
+    output_shape = _value_shape(graph, output_name, node)
+    if len(indices_shape) < 1:
+        raise ShapeInferenceError("GatherND indices must have rank >= 1")
+    batch_dims = int(node.attrs.get("batch_dims", 0))
+    if batch_dims < 0:
+        raise ShapeInferenceError(
+            f"GatherND batch_dims must be >= 0, got {batch_dims}"
+        )
+    if batch_dims > len(indices_shape) - 1:
+        raise ShapeInferenceError(
+            "GatherND batch_dims must be <= indices rank - 1, "
+            f"got {batch_dims} vs {len(indices_shape) - 1}"
+        )
+    if batch_dims > len(data_shape):
+        raise ShapeInferenceError(
+            "GatherND batch_dims must be <= data rank, "
+            f"got {batch_dims} vs {len(data_shape)}"
+        )
+    if tuple(data_shape[:batch_dims]) != tuple(indices_shape[:batch_dims]):
+        raise ShapeInferenceError(
+            "GatherND batch_dims must match on data/indices, "
+            f"got {data_shape} vs {indices_shape}"
+        )
+    index_depth = indices_shape[-1]
+    if index_depth <= 0:
+        raise ShapeInferenceError(
+            "GatherND indices final dimension must be >= 1"
+        )
+    if index_depth > len(data_shape) - batch_dims:
+        raise ShapeInferenceError(
+            "GatherND indices final dimension must be <= data rank - "
+            f"batch_dims, got {index_depth} vs {len(data_shape) - batch_dims}"
+        )
+    expected_output_shape = indices_shape[:-1] + data_shape[
+        batch_dims + index_depth :
+    ]
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "GatherND output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    data_dtype = _value_dtype(graph, data_name, node)
+    indices_dtype = _value_dtype(graph, indices_name, node)
+    if indices_dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "GatherND indices must be int32 or int64, "
+            f"got {indices_dtype.onnx_name}"
+        )
+    return GatherNDOp(
+        data=data_name,
+        indices=indices_name,
+        output=output_name,
+        batch_dims=batch_dims,
+        data_shape=data_shape,
+        indices_shape=indices_shape,
+        output_shape=output_shape,
+        dtype=data_dtype,
+        indices_dtype=indices_dtype,
+    )
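
The shape rule this new lowering enforces is output_shape == indices_shape[:-1] + data_shape[batch_dims + index_depth:]. A quick NumPy sketch (illustrative only, not part of the package) checks that rule for a small batch_dims=0 case:

import numpy as np

# data_shape = (2, 3, 4); indices_shape = (2, 2) => index_depth = 2, batch_dims = 0
data = np.arange(2 * 3 * 4).reshape(2, 3, 4)
indices = np.array([[0, 1], [1, 2]])
index_depth = indices.shape[-1]
expected_shape = indices.shape[:-1] + data.shape[index_depth:]   # (2, 4)

# Naive GatherND for batch_dims = 0: each index row selects a slice of data.
gathered = np.stack([data[tuple(row)] for row in indices])
assert gathered.shape == expected_shape
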
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import GemmOp
+from ..ir.ops import GemmOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import ReduceOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype as _value_dtype
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+
+
+@register_lowering("GlobalMaxPool")
+def lower_global_max_pool(graph: Graph, node: Node) -> ReduceOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("GlobalMaxPool must have 1 input and 1 output")
+    if node.attrs:
+        raise UnsupportedOpError("GlobalMaxPool has unsupported attributes")
+    op_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if op_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "GlobalMaxPool expects matching input/output dtypes, "
+            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if op_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError(
+            "GlobalMaxPool supports float16, float, and double inputs only"
+        )
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    if len(input_shape) < 3:
+        raise UnsupportedOpError(
+            "GlobalMaxPool expects input rank of at least 3"
+        )
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (input_shape[0], input_shape[1]) + (
+        1,
+    ) * (len(input_shape) - 2)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "GlobalMaxPool output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    axes = tuple(range(2, len(input_shape)))
+    return ReduceOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        input_shape=input_shape,
+        output_shape=output_shape,
+        axes=axes,
+        axes_input=None,
+        axes_input_shape=None,
+        axes_input_dtype=None,
+        keepdims=True,
+        noop_with_empty_axes=False,
+        reduce_kind="max",
+        reduce_count=None,
+        dtype=op_dtype,
+    )
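
GlobalMaxPool gets no dedicated kernel here; it is lowered onto the existing ReduceOp as a keepdims max-reduction over every axis after N and C. A NumPy sketch of the intended numerics (illustration only, not package code):

import numpy as np

x = np.random.rand(2, 3, 5, 7).astype(np.float32)   # (N, C, H, W)
axes = tuple(range(2, x.ndim))                       # spatial axes (2, 3)
y = x.max(axis=axes, keepdims=True)
assert y.shape == (2, 3, 1, 1)                       # matches expected_output_shape above
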
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import GridSampleOp
+from ..ir.ops import GridSampleOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype, value_shape
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import GroupNormalizationOp
+from ..ir.ops import GroupNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import HardmaxOp
+from ..errors import UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import node_dtype as _node_dtype
+from .common import onnx_opset_version as _onnx_opset_version
+from .common import shape_product as _shape_product
+from .common import value_shape as _value_shape
+from .registry import register_lowering
+from ..validation import ensure_output_shape_matches_input
+from ..validation import normalize_axis as _normalize_axis
+
+
+@register_lowering("Hardmax")
+def lower_hardmax(graph: Graph, node: Node) -> HardmaxOp:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("Hardmax must have 1 input and 1 output")
+    op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+    if op_dtype not in {ScalarType.F16, ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError(
+            "Hardmax supports float16, float, and double inputs only"
+        )
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    ensure_output_shape_matches_input(node, input_shape, output_shape)
+    opset_version = _onnx_opset_version(graph)
+    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+    axis_attr = node.attrs.get("axis", default_axis)
+    axis = _normalize_axis(
+        int(axis_attr),
+        input_shape,
+        node,
+    )
+    outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
+    axis_size = input_shape[axis]
+    inner = (
+        _shape_product(input_shape[axis + 1 :])
+        if axis + 1 < len(input_shape)
+        else 1
+    )
+    return HardmaxOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        outer=outer,
+        axis_size=axis_size,
+        inner=inner,
+        axis=axis,
+        shape=input_shape,
+        dtype=op_dtype,
+    )
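
The outer/axis_size/inner fields computed above are the usual flattening of an N-D tensor around the reduction axis. A rough NumPy reference for the Hardmax semantics expressed through that decomposition (an illustrative sketch; hardmax_ref is not a package function):

import numpy as np

def hardmax_ref(x: np.ndarray, axis: int = -1) -> np.ndarray:
    axis = axis % x.ndim
    outer = int(np.prod(x.shape[:axis], dtype=np.int64))
    axis_size = x.shape[axis]
    inner = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
    flat = x.reshape(outer, axis_size, inner)
    out = np.zeros_like(flat)
    winners = flat.argmax(axis=1)                     # first maximum wins, as in ONNX
    out[np.arange(outer)[:, None], winners, np.arange(inner)[None, :]] = 1
    return out.reshape(x.shape)

x = np.random.rand(2, 3, 4).astype(np.float32)
assert np.all(hardmax_ref(x, axis=1).sum(axis=1) == 1)
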
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import IdentityOp
+from ..ir.ops import IdentityOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import value_dtype, value_shape
@@ -22,11 +22,12 @@ def lower_identity(graph: Graph, node: Node) -> IdentityOp:
     for index, (input_dim, output_dim) in enumerate(
         zip(input_shape, output_shape)
     ):
-        if input_dim == output_dim:
-            continue
-        if input_dim_params[index] or output_dim_params[index]:
-            continue
-        raise ShapeInferenceError("Identity input and output shapes must match")
+        if input_dim != output_dim and not (
+            input_dim_params[index] or output_dim_params[index]
+        ):
+            raise ShapeInferenceError(
+                "Identity input and output shapes must match"
+            )
     input_dtype = value_dtype(graph, node.inputs[0], node)
     output_dtype = value_dtype(graph, node.outputs[0], node)
     if input_dtype != output_dtype:
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import InstanceNormalizationOp
+from ..ir.ops import InstanceNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import LayerNormalizationOp
+from ..ir.ops import LayerNormalizationOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import LogSoftmaxOp
+from ..ir.ops import LogSoftmaxOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
+from .common import onnx_opset_version as _onnx_opset_version
 from .common import shape_product as _shape_product
 from .common import value_shape as _value_shape
 from .registry import register_lowering
@@ -23,8 +24,11 @@ def lower_logsoftmax(graph: Graph, node: Node) -> LogSoftmaxOp:
     input_shape = _value_shape(graph, node.inputs[0], node)
     output_shape = _value_shape(graph, node.outputs[0], node)
     ensure_output_shape_matches_input(node, input_shape, output_shape)
+    opset_version = _onnx_opset_version(graph)
+    default_axis = 1 if opset_version is not None and opset_version < 13 else -1
+    axis_attr = node.attrs.get("axis", default_axis)
     axis = _normalize_axis(
-        int(node.attrs.get("axis", -1)),
+        int(axis_attr),
         input_shape,
         node,
    )
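
The extra opset lookup matters because LogSoftmax (like Softmax and Hardmax) changed its default axis from 1 to -1 at opset 13, so older models must not silently pick up the new default. A minimal illustration of the selection rule used above (default_softmax_axis is a hypothetical helper, not package code):

def default_softmax_axis(opset_version):
    # Opset < 13 defaulted to axis=1; opset 13+ (or unknown) defaults to axis=-1.
    return 1 if opset_version is not None and opset_version < 13 else -1

assert default_softmax_axis(11) == 1
assert default_softmax_axis(13) == -1
assert default_softmax_axis(None) == -1
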
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import LpNormalizationOp
+from ..ir.ops import LpNormalizationOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
@@ -0,0 +1,141 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from ..ir.ops import LpPoolOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .registry import register_lowering
+from .common import value_dtype as _value_dtype, value_shape as _value_shape
+
+
+@dataclass(frozen=True)
+class LpPoolSpec:
+    batch: int
+    channels: int
+    in_h: int
+    in_w: int
+    out_h: int
+    out_w: int
+    kernel_h: int
+    kernel_w: int
+    stride_h: int
+    stride_w: int
+    pad_top: int
+    pad_left: int
+    pad_bottom: int
+    pad_right: int
+    p: int
+
+
+def _resolve_lp_pool_spec(graph: Graph, node: Node) -> LpPoolSpec:
+    if len(node.inputs) != 1 or len(node.outputs) != 1:
+        raise UnsupportedOpError("LpPool must have 1 input and 1 output")
+    supported_attrs = {
+        "auto_pad",
+        "ceil_mode",
+        "dilations",
+        "kernel_shape",
+        "pads",
+        "p",
+        "strides",
+    }
+    if set(node.attrs) - supported_attrs:
+        raise UnsupportedOpError("LpPool has unsupported attributes")
+    auto_pad = node.attrs.get("auto_pad", b"NOTSET")
+    if isinstance(auto_pad, bytes):
+        auto_pad = auto_pad.decode("utf-8", errors="ignore")
+    if auto_pad not in ("", "NOTSET"):
+        raise UnsupportedOpError("LpPool supports auto_pad=NOTSET only")
+    ceil_mode = int(node.attrs.get("ceil_mode", 0))
+    if ceil_mode != 0:
+        raise UnsupportedOpError("LpPool supports ceil_mode=0 only")
+    dilations = tuple(int(value) for value in node.attrs.get("dilations", (1, 1)))
+    if any(value != 1 for value in dilations):
+        raise UnsupportedOpError("LpPool supports dilations=1 only")
+    kernel_shape = node.attrs.get("kernel_shape")
+    if kernel_shape is None:
+        raise UnsupportedOpError("LpPool requires kernel_shape")
+    kernel_shape = tuple(int(value) for value in kernel_shape)
+    if len(kernel_shape) != 2:
+        raise UnsupportedOpError("LpPool expects 2D kernel_shape")
+    kernel_h, kernel_w = kernel_shape
+    strides = tuple(int(value) for value in node.attrs.get("strides", (1, 1)))
+    if len(strides) != 2:
+        raise UnsupportedOpError("LpPool expects 2D strides")
+    pads = tuple(int(value) for value in node.attrs.get("pads", (0, 0, 0, 0)))
+    if len(pads) != 4:
+        raise UnsupportedOpError("LpPool expects 4D pads")
+    pad_top, pad_left, pad_bottom, pad_right = pads
+    p = int(node.attrs.get("p", 2))
+    if p < 1:
+        raise UnsupportedOpError("LpPool p must be >= 1")
+    input_shape = _value_shape(graph, node.inputs[0], node)
+    if len(input_shape) != 4:
+        raise UnsupportedOpError("LpPool supports NCHW 2D inputs only")
+    batch, channels, in_h, in_w = input_shape
+    stride_h, stride_w = strides
+    out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1
+    out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1
+    if out_h < 0 or out_w < 0:
+        raise ShapeInferenceError("LpPool output shape must be non-negative")
+    output_shape = _value_shape(graph, node.outputs[0], node)
+    expected_output_shape = (batch, channels, out_h, out_w)
+    if output_shape != expected_output_shape:
+        raise ShapeInferenceError(
+            "LpPool output shape must be "
+            f"{expected_output_shape}, got {output_shape}"
+        )
+    return LpPoolSpec(
+        batch=batch,
+        channels=channels,
+        in_h=in_h,
+        in_w=in_w,
+        out_h=out_h,
+        out_w=out_w,
+        kernel_h=kernel_h,
+        kernel_w=kernel_w,
+        stride_h=stride_h,
+        stride_w=stride_w,
+        pad_top=pad_top,
+        pad_left=pad_left,
+        pad_bottom=pad_bottom,
+        pad_right=pad_right,
+        p=p,
+    )
+
+
+@register_lowering("LpPool")
+def lower_lp_pool(graph: Graph, node: Node) -> LpPoolOp:
+    op_dtype = _value_dtype(graph, node.inputs[0], node)
+    output_dtype = _value_dtype(graph, node.outputs[0], node)
+    if op_dtype != output_dtype:
+        raise UnsupportedOpError(
+            "LpPool expects matching input/output dtypes, "
+            f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+        )
+    if not op_dtype.is_float:
+        raise UnsupportedOpError(
+            "LpPool supports float16, float, and double inputs only"
+        )
+    spec = _resolve_lp_pool_spec(graph, node)
+    return LpPoolOp(
+        input0=node.inputs[0],
+        output=node.outputs[0],
+        batch=spec.batch,
+        channels=spec.channels,
+        in_h=spec.in_h,
+        in_w=spec.in_w,
+        out_h=spec.out_h,
+        out_w=spec.out_w,
+        kernel_h=spec.kernel_h,
+        kernel_w=spec.kernel_w,
+        stride_h=spec.stride_h,
+        stride_w=spec.stride_w,
+        pad_top=spec.pad_top,
+        pad_left=spec.pad_left,
+        pad_bottom=spec.pad_bottom,
+        pad_right=spec.pad_right,
+        p=spec.p,
+        dtype=op_dtype,
+    )
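
With auto_pad=NOTSET, ceil_mode=0, and unit dilations, the spatial output sizes above follow the standard floor-division pooling formula. A small worked check (illustration only, not package code):

# out = (in + pad_begin + pad_end - kernel) // stride + 1
in_h, in_w = 32, 32
kernel_h = kernel_w = 3
stride_h = stride_w = 2
pad_top = pad_left = pad_bottom = pad_right = 1

out_h = (in_h + pad_top + pad_bottom - kernel_h) // stride_h + 1   # (32 + 2 - 3) // 2 + 1 = 16
out_w = (in_w + pad_left + pad_right - kernel_w) // stride_w + 1   # 16
assert (out_h, out_w) == (16, 16)
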
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..codegen.c_emitter import LrnOp
+from ..ir.ops import LrnOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .registry import register_lowering
@@ -323,7 +323,7 @@ def resolve_lstm_spec(graph: Graph, node: Node) -> LstmSpec:
 
 @register_lowering("LSTM")
 def lower_lstm(graph: Graph, node: Node) -> "LstmOp":
-    from ..codegen.c_emitter import LstmOp
+    from ..ir.ops import LstmOp
 
     spec = resolve_lstm_spec(graph, node)
     return LstmOp(
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..codegen.c_emitter import MatMulOp
+from ..ir.ops import MatMulOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
@@ -87,13 +87,12 @@ def _broadcast_batch_shapes(
     right_padded = (1,) * (max_rank - len(right)) + right
     broadcast_shape = []
     for left_dim, right_dim in zip(left_padded, right_padded):
-        if left_dim == right_dim or left_dim == 1 or right_dim == 1:
-            broadcast_shape.append(max(left_dim, right_dim))
-            continue
-        raise ShapeInferenceError(
-            "MatMul batch dimensions must be broadcastable, "
-            f"got {left} x {right}"
-        )
+        if not (left_dim == right_dim or left_dim == 1 or right_dim == 1):
+            raise ShapeInferenceError(
+                "MatMul batch dimensions must be broadcastable, "
+                f"got {left} x {right}"
+            )
+        broadcast_shape.append(max(left_dim, right_dim))
     return tuple(broadcast_shape), left_padded, right_padded
 
 
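
The reordered loop keeps the NumPy-style rule: two batch dimensions are compatible when they are equal or either one is 1, and the result takes the larger of the two. A standalone restatement for illustration (broadcast_batch_dims is a hypothetical helper, not the package function):

def broadcast_batch_dims(left, right):
    max_rank = max(len(left), len(right))
    left_p = (1,) * (max_rank - len(left)) + tuple(left)
    right_p = (1,) * (max_rank - len(right)) + tuple(right)
    out = []
    for l_dim, r_dim in zip(left_p, right_p):
        if not (l_dim == r_dim or l_dim == 1 or r_dim == 1):
            raise ValueError(f"not broadcastable: {left} x {right}")
        out.append(max(l_dim, r_dim))
    return tuple(out)

assert broadcast_batch_dims((2, 1, 4), (3, 1)) == (2, 3, 4)
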
@@ -5,7 +5,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import MaxPoolOp
+from ..ir.ops import MaxPoolOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import MeanVarianceNormalizationOp
+from ..ir.ops import MeanVarianceNormalizationOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..validation import ensure_output_shape_matches_input
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import NegativeLogLikelihoodLossOp
+from ..ir.ops import NegativeLogLikelihoodLossOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Initializer, Node
 from .common import shape_product as _shape_product
@@ -43,18 +43,18 @@ def _resolve_target_shape(
                 raise ShapeInferenceError("Reshape allows only one -1 dimension")
             unknown_index = index
             output_dims.append(-1)
-            continue
-        if dim == 0:
-            if allowzero == 0:
-                if index >= len(input_shape):
-                    raise ShapeInferenceError(
-                        "Reshape zero dim must index into input shape"
-                    )
-                dim = input_shape[index]
-        if dim < 0:
-            raise ShapeInferenceError("Reshape dims must be >= -1")
-        output_dims.append(dim)
-        known_product *= dim
+        else:
+            if dim == 0:
+                if allowzero == 0:
+                    if index >= len(input_shape):
+                        raise ShapeInferenceError(
+                            "Reshape zero dim must index into input shape"
+                        )
+                    dim = input_shape[index]
+            if dim < 0:
+                raise ShapeInferenceError("Reshape dims must be >= -1")
+            output_dims.append(dim)
+            known_product *= dim
     input_product = _shape_product(input_shape)
     if unknown_index is not None:
         if known_product == 0 or input_product % known_product != 0: