emx-onnx-cgen 0.2.0 (emx_onnx_cgen-0.2.0-py3-none-any.whl)

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/group_normalization.py
@@ -0,0 +1,68 @@
+ from __future__ import annotations
+
+ from ..codegen.c_emitter import GroupNormalizationOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..validation import ensure_output_shape_matches_input
+ from .common import node_dtype, shape_product, value_shape
+ from .registry import register_lowering
+
+
+ @register_lowering("GroupNormalization")
+ def lower_group_normalization(
+     graph: Graph, node: Node
+ ) -> GroupNormalizationOp:
+     if len(node.inputs) != 3 or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             "GroupNormalization must have 3 inputs and 1 output"
+         )
+     op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "GroupNormalization supports float16, float, and double inputs only"
+         )
+     input_shape = value_shape(graph, node.inputs[0], node)
+     output_shape = value_shape(graph, node.outputs[0], node)
+     ensure_output_shape_matches_input(node, input_shape, output_shape)
+     if len(input_shape) < 3:
+         raise ShapeInferenceError(
+             "GroupNormalization expects input rank of at least 3"
+         )
+     channels = input_shape[1]
+     num_groups_attr = node.attrs.get("num_groups")
+     if num_groups_attr is None:
+         raise UnsupportedOpError("GroupNormalization requires num_groups")
+     num_groups = int(num_groups_attr)
+     if num_groups <= 0:
+         raise ShapeInferenceError("GroupNormalization num_groups must be > 0")
+     if channels % num_groups != 0:
+         raise ShapeInferenceError(
+             "GroupNormalization num_groups must divide the channel dimension"
+         )
+     scale_shape = value_shape(graph, node.inputs[1], node)
+     bias_shape = value_shape(graph, node.inputs[2], node)
+     if scale_shape != (channels,) or bias_shape != (channels,):
+         raise ShapeInferenceError(
+             "GroupNormalization scale and bias must be 1D with length C"
+         )
+     spatial_size = shape_product(input_shape[2:])
+     group_size = channels // num_groups
+     epsilon = float(node.attrs.get("epsilon", 1e-5))
+     stash_type = int(node.attrs.get("stash_type", 1))
+     if stash_type != 1:
+         raise UnsupportedOpError(
+             "GroupNormalization supports stash_type=1 only"
+         )
+     return GroupNormalizationOp(
+         input0=node.inputs[0],
+         scale=node.inputs[1],
+         bias=node.inputs[2],
+         output=node.outputs[0],
+         shape=input_shape,
+         channels=channels,
+         num_groups=num_groups,
+         group_size=group_size,
+         spatial_size=spatial_size,
+         epsilon=epsilon,
+         dtype=op_dtype,
+     )
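For reference, the quantities collected above (channels, num_groups, group_size, spatial_size, epsilon) correspond to the ONNX GroupNormalization semantics: per-group mean and variance over the grouped channel and spatial dimensions, followed by a per-channel scale and bias. A rough NumPy sketch of that arithmetic; the function name is illustrative and not part of the package:

import numpy as np

def group_norm_reference(x, scale, bias, num_groups, epsilon=1e-5):
    # x: (N, C, *spatial); scale, bias: (C,), matching the (channels,) check above
    n, c = x.shape[0], x.shape[1]
    group_size = c // num_groups
    g = x.reshape(n, num_groups, group_size, -1)   # trailing dim == spatial_size
    mean = g.mean(axis=(2, 3), keepdims=True)
    var = g.var(axis=(2, 3), keepdims=True)
    y = ((g - mean) / np.sqrt(var + epsilon)).reshape(x.shape)
    bshape = (1, c) + (1,) * (x.ndim - 2)          # broadcast scale/bias per channel
    return y * scale.reshape(bshape) + bias.reshape(bshape)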
emx_onnx_cgen/lowering/identity.py
@@ -0,0 +1,43 @@
+ from __future__ import annotations
+
+ from ..codegen.c_emitter import IdentityOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import value_dtype, value_shape
+ from .registry import register_lowering
+
+
+ @register_lowering("Identity")
+ def lower_identity(graph: Graph, node: Node) -> IdentityOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Identity must have 1 input and 1 output")
+     input_shape = value_shape(graph, node.inputs[0], node)
+     output_shape = value_shape(graph, node.outputs[0], node)
+     input_dim_params = graph.find_value(node.inputs[0]).type.dim_params
+     output_dim_params = graph.find_value(node.outputs[0]).type.dim_params
+     resolved_shape = output_shape or input_shape
+     if input_shape and output_shape:
+         if len(input_shape) != len(output_shape):
+             raise ShapeInferenceError("Identity input and output shapes must match")
+         for index, (input_dim, output_dim) in enumerate(
+             zip(input_shape, output_shape)
+         ):
+             if input_dim == output_dim:
+                 continue
+             if input_dim_params[index] or output_dim_params[index]:
+                 continue
+             raise ShapeInferenceError("Identity input and output shapes must match")
+     input_dtype = value_dtype(graph, node.inputs[0], node)
+     output_dtype = value_dtype(graph, node.outputs[0], node)
+     if input_dtype != output_dtype:
+         raise UnsupportedOpError(
+             "Identity expects matching input/output dtypes, "
+             f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     return IdentityOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         shape=resolved_shape,
+         dtype=output_dtype,
+         input_dtype=input_dtype,
+     )
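The shape check above tolerates mismatched static dims whenever either side carries a symbolic dim_param. A minimal stand-alone sketch of that rule, with hypothetical shapes and a helper name that is not the package's Graph API:

def identity_shapes_compatible(in_shape, out_shape, in_params, out_params):
    # Ranks must match; each dim must either agree numerically or be symbolic
    # (non-empty dim_param) on at least one side.
    if len(in_shape) != len(out_shape):
        return False
    return all(
        i == o or in_params[k] or out_params[k]
        for k, (i, o) in enumerate(zip(in_shape, out_shape))
    )

# A symbolic batch dim keeps Identity lowerable even if static values differ:
assert identity_shapes_compatible((1, 3), (4, 3), ("N", ""), ("", "")) is True
assert identity_shapes_compatible((1, 3), (4, 3), ("", ""), ("", "")) is False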
emx_onnx_cgen/lowering/instance_normalization.py
@@ -0,0 +1,50 @@
+ from __future__ import annotations
+
+ from ..codegen.c_emitter import InstanceNormalizationOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..validation import ensure_output_shape_matches_input
+ from .common import node_dtype, shape_product, value_shape
+ from .registry import register_lowering
+
+
+ @register_lowering("InstanceNormalization")
+ def lower_instance_normalization(
+     graph: Graph, node: Node
+ ) -> InstanceNormalizationOp:
+     if len(node.inputs) != 3 or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             "InstanceNormalization must have 3 inputs and 1 output"
+         )
+     op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "InstanceNormalization supports float16, float, and double inputs only"
+         )
+     input_shape = value_shape(graph, node.inputs[0], node)
+     output_shape = value_shape(graph, node.outputs[0], node)
+     ensure_output_shape_matches_input(node, input_shape, output_shape)
+     if len(input_shape) < 3:
+         raise ShapeInferenceError(
+             "InstanceNormalization expects input rank of at least 3"
+         )
+     channels = input_shape[1]
+     scale_shape = value_shape(graph, node.inputs[1], node)
+     bias_shape = value_shape(graph, node.inputs[2], node)
+     if scale_shape != (channels,) or bias_shape != (channels,):
+         raise ShapeInferenceError(
+             "InstanceNormalization scale and bias must be 1D with length C"
+         )
+     spatial_size = shape_product(input_shape[2:])
+     epsilon = float(node.attrs.get("epsilon", 1e-5))
+     return InstanceNormalizationOp(
+         input0=node.inputs[0],
+         scale=node.inputs[1],
+         bias=node.inputs[2],
+         output=node.outputs[0],
+         shape=input_shape,
+         channels=channels,
+         spatial_size=spatial_size,
+         epsilon=epsilon,
+         dtype=op_dtype,
+     )
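As with GroupNormalization, the lowering only collects shapes and attributes; the ONNX InstanceNormalization arithmetic it targets normalizes each (batch, channel) slice over its spatial elements. A rough NumPy sketch, with an illustrative function name that is not part of the package:

import numpy as np

def instance_norm_reference(x, scale, bias, epsilon=1e-5):
    # x: (N, C, *spatial); normalize every (n, c) slice over its spatial dims
    axes = tuple(range(2, x.ndim))                 # spatial_size elements per slice
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    bshape = (1, x.shape[1]) + (1,) * (x.ndim - 2)
    return (x - mean) / np.sqrt(var + epsilon) * scale.reshape(bshape) + bias.reshape(bshape)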
emx_onnx_cgen/lowering/layer_normalization.py
@@ -0,0 +1,110 @@
+ from __future__ import annotations
+
+ from ..codegen.c_emitter import LayerNormalizationOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..validation import ensure_output_shape_matches_input
+ from ..validation import normalize_axis
+ from .common import node_dtype, shape_product, value_dtype, value_shape
+ from .registry import register_lowering
+
+
+ def _ensure_broadcastable(
+     name: str,
+     shape: tuple[int, ...],
+     normalized_shape: tuple[int, ...],
+ ) -> None:
+     if len(shape) != len(normalized_shape):
+         raise ShapeInferenceError(
+             f"LayerNormalization {name} rank must match normalized rank"
+         )
+     for dim, expected in zip(shape, normalized_shape):
+         if dim not in {1, expected}:
+             raise ShapeInferenceError(
+                 f"LayerNormalization {name} shape {shape} must be broadcastable "
+                 f"to {normalized_shape}"
+             )
+
+
+ @register_lowering("LayerNormalization")
+ def lower_layer_normalization(
+     graph: Graph, node: Node
+ ) -> LayerNormalizationOp:
+     if len(node.inputs) < 2 or len(node.inputs) > 3:
+         raise UnsupportedOpError(
+             "LayerNormalization must have 2 or 3 inputs"
+         )
+     if len(node.outputs) < 1 or len(node.outputs) > 3:
+         raise UnsupportedOpError(
+             "LayerNormalization must have 1 to 3 outputs"
+         )
+     op_dtype = node_dtype(graph, node, *node.inputs, node.outputs[0])
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "LayerNormalization supports float16, float, and double inputs only"
+         )
+     input_shape = value_shape(graph, node.inputs[0], node)
+     output_shape = value_shape(graph, node.outputs[0], node)
+     ensure_output_shape_matches_input(node, input_shape, output_shape)
+     axis = normalize_axis(int(node.attrs.get("axis", -1)), input_shape, node)
+     normalized_shape = input_shape[axis:]
+     scale_shape = value_shape(graph, node.inputs[1], node)
+     _ensure_broadcastable("scale", scale_shape, normalized_shape)
+     bias_input = node.inputs[2] if len(node.inputs) > 2 and node.inputs[2] else None
+     bias_shape = None
+     if bias_input is not None:
+         bias_shape = value_shape(graph, bias_input, node)
+         _ensure_broadcastable("bias", bias_shape, normalized_shape)
+     epsilon = float(node.attrs.get("epsilon", 1e-5))
+     stash_type = int(node.attrs.get("stash_type", 1))
+     if stash_type != 1:
+         raise UnsupportedOpError(
+             "LayerNormalization supports stash_type=1 only"
+         )
+     mean_output = node.outputs[1] if len(node.outputs) > 1 else None
+     invstd_output = node.outputs[2] if len(node.outputs) > 2 else None
+     if mean_output is not None:
+         mean_dtype = value_dtype(graph, mean_output, node)
+         if mean_dtype != op_dtype:
+             raise UnsupportedOpError(
+                 "LayerNormalization expects mean output dtype to match input"
+             )
+         expected_mean_shape = input_shape[:axis] + (1,) * len(normalized_shape)
+         mean_shape = value_shape(graph, mean_output, node)
+         if mean_shape != expected_mean_shape:
+             raise ShapeInferenceError(
+                 "LayerNormalization mean output shape must be "
+                 f"{expected_mean_shape}, got {mean_shape}"
+             )
+     if invstd_output is not None:
+         invstd_dtype = value_dtype(graph, invstd_output, node)
+         if invstd_dtype != op_dtype:
+             raise UnsupportedOpError(
+                 "LayerNormalization expects invstd output dtype to match input"
+             )
+         expected_invstd_shape = input_shape[:axis] + (1,) * len(normalized_shape)
+         invstd_shape = value_shape(graph, invstd_output, node)
+         if invstd_shape != expected_invstd_shape:
+             raise ShapeInferenceError(
+                 "LayerNormalization invstd output shape must be "
+                 f"{expected_invstd_shape}, got {invstd_shape}"
+             )
+     outer = shape_product(input_shape[:axis]) if axis > 0 else 1
+     inner = shape_product(normalized_shape)
+     return LayerNormalizationOp(
+         input0=node.inputs[0],
+         scale=node.inputs[1],
+         bias=bias_input,
+         output=node.outputs[0],
+         mean_output=mean_output,
+         invstd_output=invstd_output,
+         shape=input_shape,
+         normalized_shape=normalized_shape,
+         scale_shape=scale_shape,
+         bias_shape=bias_shape,
+         outer=outer,
+         inner=inner,
+         axis=axis,
+         epsilon=epsilon,
+         dtype=op_dtype,
+     )
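The outer/inner products computed above flatten the input into an (outer, inner) matrix that is normalized row by row, and the optional mean/invstd outputs keep the leading dims with the normalized dims collapsed to 1. A rough NumPy sketch of those semantics; the function name is illustrative and not part of the package:

import numpy as np

def layer_norm_reference(x, scale, bias=None, axis=-1, epsilon=1e-5):
    axis = axis % x.ndim
    outer = int(np.prod(x.shape[:axis], dtype=np.int64)) if axis > 0 else 1
    inner = int(np.prod(x.shape[axis:], dtype=np.int64))
    flat = x.reshape(outer, inner)
    mean = flat.mean(axis=1, keepdims=True)
    invstd = 1.0 / np.sqrt(flat.var(axis=1, keepdims=True) + epsilon)
    y = ((flat - mean) * invstd).reshape(x.shape) * scale  # scale broadcasts over input_shape[axis:]
    if bias is not None:
        y = y + bias
    stat_shape = x.shape[:axis] + (1,) * (x.ndim - axis)   # expected mean/invstd shape above
    return y, mean.reshape(stat_shape), invstd.reshape(stat_shape)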
emx_onnx_cgen/lowering/logsoftmax.py
@@ -0,0 +1,47 @@
+ from __future__ import annotations
+
+ from ..codegen.c_emitter import LogSoftmaxOp
+ from ..errors import UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import node_dtype as _node_dtype
+ from .common import shape_product as _shape_product
+ from .common import value_shape as _value_shape
+ from .registry import register_lowering
+ from ..validation import ensure_output_shape_matches_input
+ from ..validation import normalize_axis as _normalize_axis
+
+
+ @register_lowering("LogSoftmax")
+ def lower_logsoftmax(graph: Graph, node: Node) -> LogSoftmaxOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("LogSoftmax must have 1 input and 1 output")
+     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "LogSoftmax supports float16, float, and double inputs only"
+         )
+     input_shape = _value_shape(graph, node.inputs[0], node)
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     ensure_output_shape_matches_input(node, input_shape, output_shape)
+     axis = _normalize_axis(
+         int(node.attrs.get("axis", -1)),
+         input_shape,
+         node,
+     )
+     outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
+     axis_size = input_shape[axis]
+     inner = (
+         _shape_product(input_shape[axis + 1 :])
+         if axis + 1 < len(input_shape)
+         else 1
+     )
+     return LogSoftmaxOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         outer=outer,
+         axis_size=axis_size,
+         inner=inner,
+         axis=axis,
+         shape=input_shape,
+         dtype=op_dtype,
+     )
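The (outer, axis_size, inner) triple above describes how the tensor is walked; the value computed along axis_size is a standard numerically stable log-softmax. A rough NumPy sketch, with an illustrative function name that is not part of the package:

import numpy as np

def logsoftmax_reference(x, axis=-1):
    # Subtract the per-slice max before exponentiating to avoid overflow.
    m = x.max(axis=axis, keepdims=True)
    shifted = x - m
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))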
emx_onnx_cgen/lowering/lp_normalization.py
@@ -0,0 +1,45 @@
+ from __future__ import annotations
+
+ from ..codegen.c_emitter import LpNormalizationOp
+ from ..errors import UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..validation import ensure_output_shape_matches_input
+ from ..validation import normalize_axis
+ from .common import node_dtype, shape_product, value_shape
+ from .registry import register_lowering
+
+
+ @register_lowering("LpNormalization")
+ def lower_lp_normalization(graph: Graph, node: Node) -> LpNormalizationOp:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("LpNormalization must have 1 input and 1 output")
+     op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "LpNormalization supports float16, float, and double inputs only"
+         )
+     input_shape = value_shape(graph, node.inputs[0], node)
+     output_shape = value_shape(graph, node.outputs[0], node)
+     ensure_output_shape_matches_input(node, input_shape, output_shape)
+     axis = normalize_axis(int(node.attrs.get("axis", -1)), input_shape, node)
+     p = int(node.attrs.get("p", 2))
+     if p not in {1, 2}:
+         raise UnsupportedOpError("LpNormalization only supports p=1 or p=2")
+     outer = shape_product(input_shape[:axis]) if axis > 0 else 1
+     axis_size = input_shape[axis]
+     inner = (
+         shape_product(input_shape[axis + 1 :])
+         if axis + 1 < len(input_shape)
+         else 1
+     )
+     return LpNormalizationOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         shape=input_shape,
+         axis=axis,
+         p=p,
+         outer=outer,
+         axis_size=axis_size,
+         inner=inner,
+         dtype=op_dtype,
+     )
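The lowering accepts only p=1 and p=2, matching the ONNX LpNormalization definition of dividing by the L1 or L2 norm along the chosen axis. A rough NumPy sketch (illustrative function name, not part of the package; zero norms are not guarded here):

import numpy as np

def lp_norm_reference(x, axis=-1, p=2):
    if p == 1:
        norm = np.abs(x).sum(axis=axis, keepdims=True)
    else:  # p == 2, the only other value the lowering accepts
        norm = np.sqrt((x * x).sum(axis=axis, keepdims=True))
    return x / norm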
emx_onnx_cgen/lowering/lrn.py
@@ -0,0 +1,104 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ from ..codegen.c_emitter import LrnOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .registry import register_lowering
+
+
+ @dataclass(frozen=True)
+ class LrnSpec:
+     shape: tuple[int, ...]
+     channels: int
+     size: int
+     half: int
+     alpha: float
+     beta: float
+     bias: float
+
+
+ def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
+     try:
+         return graph.find_value(name).type.shape
+     except KeyError as exc:
+         raise ShapeInferenceError(
+             f"Missing shape for value '{name}' in op {node.op_type}. "
+             "Hint: run ONNX shape inference or export with static shapes."
+         ) from exc
+
+
+ def _value_dtype(graph: Graph, name: str, node: Node) -> str:
+     try:
+         return graph.find_value(name).type.dtype
+     except KeyError as exc:
+         raise ShapeInferenceError(
+             f"Missing dtype for value '{name}' in op {node.op_type}. "
+             "Hint: run ONNX shape inference or export with static shapes."
+         ) from exc
+
+
+ def _node_dtype(graph: Graph, node: Node, *names: str) -> str:
+     dtypes = {_value_dtype(graph, name, node) for name in names}
+     if len(dtypes) != 1:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects matching dtypes, got {', '.join(sorted(dtypes))}"
+         )
+     return next(iter(dtypes))
+
+
+ def resolve_lrn_spec(graph: Graph, node: Node) -> LrnSpec:
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError("LRN must have 1 input and 1 output")
+     supported_attrs = {"alpha", "beta", "bias", "size"}
+     if set(node.attrs) - supported_attrs:
+         raise UnsupportedOpError("LRN has unsupported attributes")
+     size = int(node.attrs.get("size", 0))
+     if size <= 0:
+         raise UnsupportedOpError("LRN size must be a positive integer")
+     if size % 2 == 0:
+         raise UnsupportedOpError("LRN size must be odd")
+     alpha = float(node.attrs.get("alpha", 0.0001))
+     beta = float(node.attrs.get("beta", 0.75))
+     bias = float(node.attrs.get("bias", 1.0))
+     input_shape = _value_shape(graph, node.inputs[0], node)
+     if len(input_shape) < 2:
+         raise UnsupportedOpError("LRN expects input rank of at least 2")
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     if output_shape != input_shape:
+         raise ShapeInferenceError(
+             "LRN output shape must match input shape, "
+             f"got {output_shape} for input {input_shape}"
+         )
+     return LrnSpec(
+         shape=input_shape,
+         channels=input_shape[1],
+         size=size,
+         half=size // 2,
+         alpha=alpha,
+         beta=beta,
+         bias=bias,
+     )
+
+
+ @register_lowering("LRN")
+ def lower_lrn(graph: Graph, node: Node) -> LrnOp:
+     op_dtype = _node_dtype(graph, node, *node.inputs, *node.outputs)
+     if not op_dtype.is_float:
+         raise UnsupportedOpError(
+             "LRN supports float16, float, and double inputs only"
+         )
+     spec = resolve_lrn_spec(graph, node)
+     return LrnOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         shape=spec.shape,
+         channels=spec.channels,
+         size=spec.size,
+         half=spec.half,
+         alpha=spec.alpha,
+         beta=spec.beta,
+         bias=spec.bias,
+         dtype=op_dtype,
+     )
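The LrnSpec fields map directly onto the ONNX LRN formula: each channel is divided by (bias + alpha/size * square-sum over a window of `size` channels centred on it) ** beta, with half = size // 2 giving the window radius. A rough NumPy sketch, with an illustrative function name that is not part of the package:

import numpy as np

def lrn_reference(x, size, alpha=0.0001, beta=0.75, bias=1.0):
    # x: (N, C, *spatial); `size` is odd, as resolve_lrn_spec enforces above
    half = size // 2
    channels = x.shape[1]
    sq = x * x
    out = np.empty_like(x)
    for c in range(channels):
        lo, hi = max(0, c - half), min(channels, c + half + 1)
        window = sq[:, lo:hi].sum(axis=1)            # square-sum across the channel window
        out[:, c] = x[:, c] / (bias + alpha / size * window) ** beta
    return out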