emx_onnx_cgen-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries, and is provided for informational purposes only.

This version of emx-onnx-cgen has been flagged as potentially problematic.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/lowering/unsqueeze.py
@@ -0,0 +1,157 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import ReshapeOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Initializer, Node
+ from .registry import register_lowering
+
+
+ def _value_shape(graph: Graph, name: str, node: Node) -> tuple[int, ...]:
+     try:
+         return graph.find_value(name).type.shape
+     except KeyError as exc:
+         raise ShapeInferenceError(
+             f"Missing shape for value '{name}' in op {node.op_type}. "
+             "Hint: run ONNX shape inference or export with static shapes."
+         ) from exc
+
+
+ def _value_dtype(graph: Graph, name: str, node: Node) -> ScalarType:
+     try:
+         return graph.find_value(name).type.dtype
+     except KeyError as exc:
+         raise ShapeInferenceError(
+             f"Missing dtype for value '{name}' in op {node.op_type}. "
+             "Hint: run ONNX shape inference or export with static shapes."
+         ) from exc
+
+
+ def _find_initializer(graph: Graph, name: str) -> Initializer | None:
+     for initializer in graph.initializers:
+         if initializer.name == name:
+             return initializer
+     return None
+
+
+ def _validate_shape(shape: tuple[int, ...], node: Node, label: str) -> None:
+     for dim in shape:
+         if dim < 0:
+             raise ShapeInferenceError(
+                 f"{node.op_type} does not support dynamic dims in {label}"
+             )
+
+
+ def _normalize_axes(
+     axes: list[int], output_rank: int, node: Node
+ ) -> tuple[int, ...]:
+     normalized: list[int] = []
+     for axis in axes:
+         if axis < 0:
+             axis += output_rank
+         if axis < 0 or axis >= output_rank:
+             raise ShapeInferenceError(
+                 f"{node.op_type} axis {axis} is out of range for rank {output_rank}"
+             )
+         normalized.append(axis)
+     if len(set(normalized)) != len(normalized):
+         raise ShapeInferenceError(f"{node.op_type} axes must be unique")
+     return tuple(sorted(normalized))
+
+
+ def _resolve_axes(graph: Graph, node: Node) -> tuple[int, ...] | None:
+     axes_attr = node.attrs.get("axes")
+     axes_values: list[int] | None = None
+     if len(node.inputs) == 2:
+         axes_initializer = _find_initializer(graph, node.inputs[1])
+         if axes_initializer is not None:
+             if axes_initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+                 raise UnsupportedOpError(
+                     "Unsqueeze axes input must be int64 or int32, "
+                     f"got {axes_initializer.type.dtype.onnx_name}"
+                 )
+             axes_values = [int(value) for value in axes_initializer.data.reshape(-1)]
+     elif axes_attr is not None:
+         axes_values = [int(value) for value in axes_attr]
+     if axes_values is None and axes_attr is None and len(node.inputs) != 2:
+         raise UnsupportedOpError("Unsqueeze requires axes")
+     if axes_values is None:
+         return None
+     if not axes_values:
+         raise UnsupportedOpError("Unsqueeze requires non-empty axes")
+     return tuple(axes_values)
+
+
+ def _expected_output_shape(
+     input_shape: tuple[int, ...], axes: tuple[int, ...], node: Node
+ ) -> tuple[int, ...]:
+     output_rank = len(input_shape) + len(axes)
+     normalized_axes = _normalize_axes(list(axes), output_rank, node)
+     output_dims: list[int] = []
+     input_index = 0
+     for axis in range(output_rank):
+         if axis in normalized_axes:
+             output_dims.append(1)
+         else:
+             output_dims.append(input_shape[input_index])
+             input_index += 1
+     return tuple(output_dims)
+
+
+ @register_lowering("Unsqueeze")
+ def lower_unsqueeze(graph: Graph, node: Node) -> ReshapeOp:
+     if len(node.outputs) != 1 or len(node.inputs) not in {1, 2}:
+         raise UnsupportedOpError("Unsqueeze must have 1 or 2 inputs and 1 output")
+     input_shape = _value_shape(graph, node.inputs[0], node)
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     _validate_shape(input_shape, node, "input")
+     _validate_shape(output_shape, node, "output")
+     input_dtype = _value_dtype(graph, node.inputs[0], node)
+     output_dtype = _value_dtype(graph, node.outputs[0], node)
+     if input_dtype != output_dtype:
+         raise UnsupportedOpError(
+             "Unsqueeze expects matching input/output dtypes, "
+             f"got {input_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     axes = _resolve_axes(graph, node)
+     if axes is None:
+         if len(node.inputs) == 2:
+             axes_dtype = _value_dtype(graph, node.inputs[1], node)
+             if axes_dtype not in {ScalarType.I64, ScalarType.I32}:
+                 raise UnsupportedOpError(
+                     "Unsqueeze axes input must be int64 or int32, "
+                     f"got {axes_dtype.onnx_name}"
+                 )
+         if len(output_shape) <= len(input_shape):
+             raise ShapeInferenceError(
+                 "Unsqueeze output rank must exceed input rank"
+             )
+         input_index = 0
+         for dim in output_shape:
+             if input_index < len(input_shape) and dim == input_shape[input_index]:
+                 input_index += 1
+                 continue
+             if dim != 1:
+                 raise ShapeInferenceError(
+                     "Unsqueeze output shape must insert ones only"
+                 )
+         if input_index != len(input_shape):
+             raise ShapeInferenceError(
+                 "Unsqueeze output shape must contain input shape in order"
+             )
+     else:
+         expected_shape = _expected_output_shape(input_shape, axes, node)
+         if expected_shape != output_shape:
+             raise ShapeInferenceError(
+                 "Unsqueeze output shape must be "
+                 f"{expected_shape}, got {output_shape}"
+             )
+     return ReshapeOp(
+         input0=node.inputs[0],
+         output=node.outputs[0],
+         input_shape=input_shape,
+         output_shape=output_shape,
+         dtype=input_dtype,
+         input_dtype=input_dtype,
+     )
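
Because the axes only say where length-1 dimensions are inserted, Unsqueeze lowers to a plain ReshapeOp over the same buffer. A minimal standalone sketch of the shape computation mirrored by _expected_output_shape above (the node plumbing and error types are omitted; unsqueeze_shape is an illustrative name, not part of the package):

def unsqueeze_shape(input_shape: tuple[int, ...], axes: list[int]) -> tuple[int, ...]:
    """Insert a length-1 dim at each axis, after normalizing negatives."""
    rank = len(input_shape) + len(axes)
    normalized = {a + rank if a < 0 else a for a in axes}
    dims: list[int] = []
    source = iter(input_shape)
    for axis in range(rank):
        dims.append(1 if axis in normalized else next(source))
    return tuple(dims)

assert unsqueeze_shape((3, 4), [0, -1]) == (1, 3, 4, 1)
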
emx_onnx_cgen/lowering/variadic.py
@@ -0,0 +1,95 @@
+ from __future__ import annotations
+
+ from shared.scalar_functions import ScalarFunction
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import MultiInputBinaryOp
+ from ..errors import UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from ..lowering.common import node_dtype, value_dtype, value_shape
+ from ..lowering.registry import register_lowering
+ from ..ops import binary_op_symbol
+
+ VARIADIC_OP_FUNCTIONS: dict[str, ScalarFunction] = {
+     "Sum": ScalarFunction.ADD,
+     "Mean": ScalarFunction.MEAN,
+     "Max": ScalarFunction.MAXIMUM,
+     "Min": ScalarFunction.MINIMUM,
+     "And": ScalarFunction.LOGICAL_AND,
+     "Or": ScalarFunction.LOGICAL_OR,
+     "Xor": ScalarFunction.LOGICAL_XOR,
+     "BitwiseAnd": ScalarFunction.BITWISE_AND,
+     "BitwiseOr": ScalarFunction.BITWISE_OR,
+     "BitwiseXor": ScalarFunction.BITWISE_XOR,
+ }
+
+ BINARY_ONLY_OPS = {
+     "And",
+     "Or",
+     "Xor",
+     "BitwiseAnd",
+     "BitwiseOr",
+     "BitwiseXor",
+ }
+
+
+ def _validate_inputs(
+     graph: Graph, node: Node, *, function: ScalarFunction
+ ) -> tuple[ScalarType, tuple[int, ...]]:
+     if len(node.outputs) != 1:
+         raise UnsupportedOpError(f"{node.op_type} must have 1 output")
+     if node.op_type in BINARY_ONLY_OPS:
+         if len(node.inputs) != 2:
+             raise UnsupportedOpError(
+                 f"{node.op_type} must have exactly 2 inputs"
+             )
+     elif len(node.inputs) < 2:
+         raise UnsupportedOpError(
+             f"{node.op_type} must have at least 2 inputs"
+         )
+     for name in node.inputs:
+         if not name:
+             raise UnsupportedOpError(f"{node.op_type} input must be provided")
+     op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
+     output_dtype = value_dtype(graph, node.outputs[0], node)
+     if op_dtype != output_dtype:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects matching input/output dtypes, "
+             f"got {op_dtype.onnx_name} and {output_dtype.onnx_name}"
+         )
+     output_shape = value_shape(graph, node.outputs[0], node)
+     for name in node.inputs:
+         input_shape = value_shape(graph, name, node)
+         if input_shape != output_shape:
+             raise UnsupportedOpError(
+                 f"{node.op_type} expects identical input/output shapes"
+             )
+     op_spec = binary_op_symbol(function, dtype=op_dtype, validate_attrs=False)
+     if op_spec is None:
+         raise UnsupportedOpError(
+             f"{node.op_type} does not support dtype {op_dtype.onnx_name}"
+         )
+     return op_dtype, output_shape
+
+
+ def _lower_variadic(graph: Graph, node: Node) -> MultiInputBinaryOp:
+     function = VARIADIC_OP_FUNCTIONS[node.op_type]
+     op_dtype, output_shape = _validate_inputs(graph, node, function=function)
+     op_spec = binary_op_symbol(function, dtype=op_dtype, validate_attrs=False)
+     if op_spec is None:
+         raise UnsupportedOpError(
+             f"{node.op_type} does not support dtype {op_dtype.onnx_name}"
+         )
+     return MultiInputBinaryOp(
+         inputs=tuple(node.inputs),
+         output=node.outputs[0],
+         function=function,
+         operator_kind=op_spec.kind,
+         shape=output_shape,
+         dtype=op_dtype,
+         input_dtype=op_dtype,
+     )
+
+
+ for _op_type in VARIADIC_OP_FUNCTIONS:
+     register_lowering(_op_type)(_lower_variadic)
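
A single lowering serves all ten op types here; the closing loop simply applies the register_lowering decorator by hand for each name in VARIADIC_OP_FUNCTIONS. Assuming the emitted MultiInputBinaryOp folds its inputs pairwise (an assumption about the C emitter, not confirmed by this diff), the reference semantics for an op like Sum amount to a left fold, sketched with numpy:

from functools import reduce

import numpy as np

# Hypothetical reference for Sum over N same-shape inputs: fold the
# scalar function (np.add here) left to right across the input list.
def variadic_sum(*arrays: np.ndarray) -> np.ndarray:
    return reduce(np.add, arrays)

a, b, c = (np.full((2, 2), k, dtype=np.float32) for k in (1, 2, 3))
assert (variadic_sum(a, b, c) == 6).all()
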
emx_onnx_cgen/lowering/where.py
@@ -0,0 +1,73 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import WhereOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import value_dtype as _value_dtype
+ from .common import value_shape as _value_shape
+ from .registry import register_lowering
+
+
+ def _broadcast_shape(shapes: tuple[tuple[int, ...], ...], node: Node) -> tuple[int, ...]:
+     if not shapes:
+         return ()
+     max_rank = max(len(shape) for shape in shapes)
+     padded = [
+         (1,) * (max_rank - len(shape)) + shape
+         for shape in shapes
+     ]
+     broadcast: list[int] = []
+     for dims in zip(*padded):
+         dim = max(dims)
+         if any(item not in (1, dim) for item in dims):
+             raise ShapeInferenceError(
+                 f"{node.op_type} inputs must be broadcastable, got {shapes}"
+             )
+         broadcast.append(dim)
+     return tuple(broadcast)
+
+
+ @register_lowering("Where")
+ def lower_where(graph: Graph, node: Node) -> WhereOp:
+     if len(node.inputs) != 3 or len(node.outputs) != 1:
+         raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+     condition_name, x_name, y_name = node.inputs
+     output_name = node.outputs[0]
+     condition_dtype = _value_dtype(graph, condition_name, node)
+     if condition_dtype != ScalarType.BOOL:
+         raise UnsupportedOpError(
+             f"Where expects bool condition, got {condition_dtype.onnx_name}"
+         )
+     x_dtype = _value_dtype(graph, x_name, node)
+     y_dtype = _value_dtype(graph, y_name, node)
+     output_dtype = _value_dtype(graph, output_name, node)
+     if x_dtype != y_dtype or output_dtype != x_dtype:
+         raise UnsupportedOpError(
+             "Where expects matching input/output dtypes, "
+             f"got {x_dtype.onnx_name}, {y_dtype.onnx_name}, {output_dtype.onnx_name}"
+         )
+     condition_shape = _value_shape(graph, condition_name, node)
+     x_shape = _value_shape(graph, x_name, node)
+     y_shape = _value_shape(graph, y_name, node)
+     output_shape = _value_shape(graph, output_name, node)
+     broadcast_shape = _broadcast_shape(
+         (condition_shape, x_shape, y_shape),
+         node,
+     )
+     if output_shape != broadcast_shape:
+         raise ShapeInferenceError(
+             f"Where output shape must be {broadcast_shape}, got {output_shape}"
+         )
+     return WhereOp(
+         condition=condition_name,
+         input_x=x_name,
+         input_y=y_name,
+         output=output_name,
+         condition_shape=condition_shape,
+         x_shape=x_shape,
+         y_shape=y_shape,
+         output_shape=output_shape,
+         dtype=output_dtype,
+     )
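
_broadcast_shape implements numpy-style multidirectional broadcasting: shapes are right-aligned and padded with ones, each result dim is the max across inputs, and every input dim must be 1 or that max. A quick standalone cross-check against numpy, which follows the same rule ONNX Where specifies:

import numpy as np

# (4,) right-aligns and pads to (1, 1, 4); each dim is 1 or the target.
shapes = ((1, 3, 1), (4,), (2, 1, 4))
assert np.broadcast_shapes(*shapes) == (2, 3, 4)
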
emx_onnx_cgen/onnx_import.py
@@ -0,0 +1,261 @@
+ from __future__ import annotations
+
+ from typing import Iterable, Mapping
+
+ import onnx
+ import numpy as np
+ from onnx import helper, numpy_helper, shape_inference
+
+ from shared.scalar_types import ScalarType
+
+ from .dtypes import scalar_type_from_onnx
+ from .errors import ShapeInferenceError, UnsupportedOpError
+ from .ir.model import Graph, Initializer, Node, TensorType, Value
+
+
+ def _normalize_initializer_data(dtype: ScalarType, data: object) -> np.ndarray:
+     if isinstance(data, (onnx.TensorProto, onnx.SparseTensorProto)):
+         array = numpy_helper.to_array(data)
+     elif isinstance(data, np.ndarray):
+         array = data
+     else:
+         array = np.array(data)
+     return array.astype(dtype.np_dtype, copy=False)
+
+
+ def _format_elem_type(elem_type: int) -> str:
+     try:
+         name = onnx.TensorProto.DataType.Name(elem_type)
+     except ValueError:
+         name = "UNKNOWN"
+     return f"{elem_type} ({name})"
+
+
+ def _unsupported_value_type(value_info: onnx.ValueInfoProto) -> UnsupportedOpError:
+     value_kind = value_info.type.WhichOneof("value")
+     if value_kind is None:
+         value_kind = "unknown"
+     return UnsupportedOpError(
+         f"Unsupported value type '{value_kind}' for '{value_info.name}'. "
+         "Hint: export the model with tensor inputs/outputs."
+     )
+
+
+ def _tensor_type(
+     value_info: onnx.ValueInfoProto,
+     *,
+     dim_param_override: tuple[str | None, ...] | None = None,
+ ) -> TensorType:
+     if value_info.type.WhichOneof("value") != "tensor_type":
+         raise _unsupported_value_type(value_info)
+     tensor_type = value_info.type.tensor_type
+     if not tensor_type.HasField("elem_type"):
+         raise ShapeInferenceError(f"Missing elem_type for tensor '{value_info.name}'")
+     dtype = scalar_type_from_onnx(tensor_type.elem_type)
+     if dtype is None:
+         raise UnsupportedOpError(
+             "Unsupported elem_type "
+             f"{_format_elem_type(tensor_type.elem_type)} for tensor '{value_info.name}'."
+         )
+     shape = []
+     dim_params = []
+     for dim_index, dim in enumerate(tensor_type.shape.dim):
+         dim_param = dim.dim_param if dim.HasField("dim_param") else ""
+         if (
+             dim_param_override is not None
+             and dim_index < len(dim_param_override)
+             and dim_param_override[dim_index]
+         ):
+             dim_param = dim_param_override[dim_index] or ""
+         dim_params.append(dim_param or None)
+         if not dim.HasField("dim_value"):
+             if dim_param:
+                 shape.append(1)
+                 continue
+             raise ShapeInferenceError(f"Dynamic dim for tensor '{value_info.name}'")
+         shape.append(dim.dim_value)
+     return TensorType(
+         dtype=dtype,
+         shape=tuple(shape),
+         dim_params=tuple(dim_params),
+     )
+
+
+ def _values(
+     value_infos: Iterable[onnx.ValueInfoProto],
+     *,
+     dim_param_by_name: Mapping[str, tuple[str | None, ...]] | None = None,
+ ) -> tuple[Value, ...]:
+     dim_param_by_name = dim_param_by_name or {}
+     return tuple(
+         Value(
+             name=vi.name,
+             type=_tensor_type(
+                 vi, dim_param_override=dim_param_by_name.get(vi.name)
+             ),
+         )
+         for vi in value_infos
+     )
+
+
+ def _collect_dim_params(
+     value_infos: Iterable[onnx.ValueInfoProto],
+ ) -> dict[str, tuple[str | None, ...]]:
+     dim_params: dict[str, tuple[str | None, ...]] = {}
+     for value_info in value_infos:
+         dims = []
+         for dim in value_info.type.tensor_type.shape.dim:
+             dim_param = dim.dim_param if dim.HasField("dim_param") else ""
+             dims.append(dim_param or None)
+         if any(dims):
+             dim_params[value_info.name] = tuple(dims)
+     return dim_params
+
+
+ def _initializer(value: onnx.TensorProto) -> Initializer:
+     dtype = scalar_type_from_onnx(value.data_type)
+     if dtype is None:
+         raise UnsupportedOpError(
+             "Unsupported elem_type "
+             f"{_format_elem_type(value.data_type)} for initializer '{value.name}'. "
+             "Hint: export the model with float32 initializers."
+         )
+     data = _normalize_initializer_data(dtype, value)
+     return Initializer(
+         name=value.name,
+         type=TensorType(
+             dtype=dtype,
+             shape=tuple(data.shape),
+             dim_params=(None,) * len(data.shape),
+         ),
+         data=data,
+     )
+
+
+ def _node_attrs(node: onnx.NodeProto) -> dict[str, object]:
+     return {attr.name: helper.get_attribute_value(attr) for attr in node.attribute}
+
+
+ def _constant_initializer(node: onnx.NodeProto) -> Initializer:
+     if len(node.output) != 1:
+         raise UnsupportedOpError("Constant must have exactly one output")
+     attrs = _node_attrs(node)
+     output_name = node.output[0]
+     if "value" in attrs:
+         tensor = attrs["value"]
+         dtype = scalar_type_from_onnx(tensor.data_type)
+         if dtype is None:
+             raise UnsupportedOpError(
+                 "Unsupported elem_type "
+                 f"{_format_elem_type(tensor.data_type)} for Constant '{output_name}'."
+             )
+         data = _normalize_initializer_data(dtype, tensor)
+         return Initializer(
+             name=output_name,
+             type=TensorType(
+                 dtype=dtype,
+                 shape=tuple(data.shape),
+                 dim_params=(None,) * len(data.shape),
+             ),
+             data=data,
+         )
+     if "sparse_value" in attrs:
+         tensor = attrs["sparse_value"]
+         dtype = scalar_type_from_onnx(tensor.values.data_type)
+         if dtype is None:
+             raise UnsupportedOpError(
+                 "Unsupported elem_type "
+                 f"{_format_elem_type(tensor.values.data_type)} for Constant '{output_name}'."
+             )
+         data = _normalize_initializer_data(dtype, tensor)
+         return Initializer(
+             name=output_name,
+             type=TensorType(
+                 dtype=dtype,
+                 shape=tuple(data.shape),
+                 dim_params=(None,) * len(data.shape),
+             ),
+             data=data,
+         )
+     if "value_float" in attrs or "value_floats" in attrs:
+         values = attrs.get("value_floats", attrs.get("value_float"))
+         data = _normalize_initializer_data(ScalarType.F32, values)
+         return Initializer(
+             name=output_name,
+             type=TensorType(
+                 dtype=ScalarType.F32,
+                 shape=tuple(data.shape),
+                 dim_params=(None,) * len(data.shape),
+             ),
+             data=data,
+         )
+     if "value_int" in attrs or "value_ints" in attrs:
+         values = attrs.get("value_ints", attrs.get("value_int"))
+         data = _normalize_initializer_data(ScalarType.I64, values)
+         return Initializer(
+             name=output_name,
+             type=TensorType(
+                 dtype=ScalarType.I64,
+                 shape=tuple(data.shape),
+                 dim_params=(None,) * len(data.shape),
+             ),
+             data=data,
+         )
+     if "value_string" in attrs or "value_strings" in attrs:
+         raise UnsupportedOpError(
+             f"Constant '{output_name}' has unsupported string values"
+         )
+     raise UnsupportedOpError(f"Constant '{output_name}' requires a value attribute")
+
+
+ def import_onnx(model: onnx.ModelProto) -> Graph:
+     dim_param_by_name = _collect_dim_params(
+         tuple(model.graph.input) + tuple(model.graph.output)
+     )
+     try:
+         model = shape_inference.infer_shapes(model, data_prop=True)
+     except Exception as exc:  # pragma: no cover - onnx inference errors
+         raise ShapeInferenceError("ONNX shape inference failed") from exc
+     graph = model.graph
+     base_initializers = [_initializer(value) for value in graph.initializer]
+     constant_initializers: list[Initializer] = []
+     input_names = {value_info.name for value_info in graph.input}
+     output_names = {value_info.name for value_info in graph.output}
+     nodes: list[Node] = []
+     for node in graph.node:
+         if node.op_type == "Constant":
+             constant_initializers.append(_constant_initializer(node))
+             continue
+         nodes.append(
+             Node(
+                 op_type=node.op_type,
+                 name=node.name or None,
+                 inputs=tuple(node.input),
+                 outputs=tuple(node.output),
+                 attrs=_node_attrs(node),
+             )
+         )
+     initializers = tuple(base_initializers + constant_initializers)
+     initializer_names = {initializer.name for initializer in initializers}
+     inputs = _values(
+         (
+             value_info
+             for value_info in graph.input
+             if value_info.name not in initializer_names
+         ),
+         dim_param_by_name=dim_param_by_name,
+     )
+     outputs = _values(graph.output, dim_param_by_name=dim_param_by_name)
+     values = _values(
+         value_info
+         for value_info in graph.value_info
+         if value_info.name
+         not in initializer_names | input_names | output_names
+     )
+     return Graph(
+         inputs=inputs,
+         outputs=outputs,
+         nodes=nodes,
+         initializers=initializers,
+         values=values,
+     )
+ )