emx-onnx-cgen 0.2.0-py3-none-any.whl → 0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic.

Files changed (99)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +34 -0
  3. emx_onnx_cgen/cli.py +372 -64
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +169 -343
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +193 -0
  11. emx_onnx_cgen/ir/op_context.py +65 -0
  12. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  13. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  14. emx_onnx_cgen/ir/ops/misc.py +421 -0
  15. emx_onnx_cgen/ir/ops/nn.py +580 -0
  16. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  17. emx_onnx_cgen/lowering/__init__.py +79 -1
  18. emx_onnx_cgen/lowering/adagrad.py +114 -0
  19. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  20. emx_onnx_cgen/lowering/attention.py +1 -1
  21. emx_onnx_cgen/lowering/average_pool.py +1 -1
  22. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  23. emx_onnx_cgen/lowering/cast.py +1 -1
  24. emx_onnx_cgen/lowering/common.py +406 -11
  25. emx_onnx_cgen/lowering/concat.py +1 -1
  26. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  27. emx_onnx_cgen/lowering/conv.py +1 -1
  28. emx_onnx_cgen/lowering/conv_transpose.py +301 -0
  29. emx_onnx_cgen/lowering/cumsum.py +1 -1
  30. emx_onnx_cgen/lowering/depth_space.py +1 -1
  31. emx_onnx_cgen/lowering/dropout.py +1 -1
  32. emx_onnx_cgen/lowering/einsum.py +153 -0
  33. emx_onnx_cgen/lowering/elementwise.py +152 -4
  34. emx_onnx_cgen/lowering/expand.py +1 -1
  35. emx_onnx_cgen/lowering/eye_like.py +1 -1
  36. emx_onnx_cgen/lowering/flatten.py +1 -1
  37. emx_onnx_cgen/lowering/gather.py +1 -1
  38. emx_onnx_cgen/lowering/gather_elements.py +2 -4
  39. emx_onnx_cgen/lowering/gather_nd.py +79 -0
  40. emx_onnx_cgen/lowering/gemm.py +1 -1
  41. emx_onnx_cgen/lowering/global_max_pool.py +59 -0
  42. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  43. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  44. emx_onnx_cgen/lowering/hardmax.py +53 -0
  45. emx_onnx_cgen/lowering/identity.py +7 -6
  46. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  48. emx_onnx_cgen/lowering/logsoftmax.py +6 -2
  49. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  50. emx_onnx_cgen/lowering/lp_pool.py +141 -0
  51. emx_onnx_cgen/lowering/lrn.py +1 -1
  52. emx_onnx_cgen/lowering/lstm.py +1 -1
  53. emx_onnx_cgen/lowering/matmul.py +7 -8
  54. emx_onnx_cgen/lowering/maxpool.py +1 -1
  55. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  56. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
  57. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  58. emx_onnx_cgen/lowering/nonzero.py +42 -0
  59. emx_onnx_cgen/lowering/one_hot.py +120 -0
  60. emx_onnx_cgen/lowering/pad.py +1 -1
  61. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  62. emx_onnx_cgen/lowering/quantize_linear.py +126 -0
  63. emx_onnx_cgen/lowering/range.py +1 -1
  64. emx_onnx_cgen/lowering/reduce.py +6 -7
  65. emx_onnx_cgen/lowering/registry.py +24 -5
  66. emx_onnx_cgen/lowering/reshape.py +224 -52
  67. emx_onnx_cgen/lowering/resize.py +1 -1
  68. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  69. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  70. emx_onnx_cgen/lowering/scatter_nd.py +82 -0
  71. emx_onnx_cgen/lowering/shape.py +6 -25
  72. emx_onnx_cgen/lowering/size.py +1 -1
  73. emx_onnx_cgen/lowering/slice.py +1 -1
  74. emx_onnx_cgen/lowering/softmax.py +6 -2
  75. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  76. emx_onnx_cgen/lowering/split.py +1 -1
  77. emx_onnx_cgen/lowering/squeeze.py +6 -6
  78. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  79. emx_onnx_cgen/lowering/tile.py +1 -1
  80. emx_onnx_cgen/lowering/topk.py +134 -0
  81. emx_onnx_cgen/lowering/transpose.py +1 -1
  82. emx_onnx_cgen/lowering/trilu.py +89 -0
  83. emx_onnx_cgen/lowering/unsqueeze.py +6 -6
  84. emx_onnx_cgen/lowering/variadic.py +1 -1
  85. emx_onnx_cgen/lowering/where.py +1 -1
  86. emx_onnx_cgen/onnx_import.py +4 -0
  87. emx_onnx_cgen/onnxruntime_utils.py +11 -0
  88. emx_onnx_cgen/ops.py +4 -0
  89. emx_onnx_cgen/runtime/evaluator.py +785 -43
  90. emx_onnx_cgen/testbench.py +23 -0
  91. emx_onnx_cgen/verification.py +31 -0
  92. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
  93. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  94. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  95. shared/scalar_functions.py +60 -17
  96. shared/ulp.py +65 -0
  97. emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
  98. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  99. {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/lowering/__init__.py
@@ -1,3 +1,81 @@
+from __future__ import annotations
+
+import importlib
+
 from .registry import get_lowering, register_lowering
 
-__all__ = ["get_lowering", "register_lowering"]
+_LOWERING_MODULES = [
+    "adagrad",
+    "arg_reduce",
+    "attention",
+    "average_pool",
+    "batch_normalization",
+    "cast",
+    "concat",
+    "constant_of_shape",
+    "conv",
+    "conv_transpose",
+    "cumsum",
+    "depth_space",
+    "dropout",
+    "einsum",
+    "elementwise",
+    "expand",
+    "eye_like",
+    "flatten",
+    "gather",
+    "gather_elements",
+    "gather_nd",
+    "gemm",
+    "global_max_pool",
+    "grid_sample",
+    "group_normalization",
+    "hardmax",
+    "identity",
+    "instance_normalization",
+    "layer_normalization",
+    "logsoftmax",
+    "lp_normalization",
+    "lp_pool",
+    "lrn",
+    "lstm",
+    "matmul",
+    "maxpool",
+    "mean_variance_normalization",
+    "negative_log_likelihood_loss",
+    "non_max_suppression",
+    "nonzero",
+    "one_hot",
+    "pad",
+    "qlinear_matmul",
+    "quantize_linear",
+    "range",
+    "reduce",
+    "reshape",
+    "resize",
+    "rms_normalization",
+    "rotary_embedding",
+    "scatter_nd",
+    "shape",
+    "size",
+    "slice",
+    "softmax",
+    "softmax_cross_entropy_loss",
+    "split",
+    "squeeze",
+    "tensor_scatter",
+    "tile",
+    "topk",
+    "transpose",
+    "trilu",
+    "unsqueeze",
+    "variadic",
+    "where",
+]
+
+
+def load_lowering_registry() -> None:
+    for module_name in _LOWERING_MODULES:
+        importlib.import_module(f"{__name__}.{module_name}")
+
+__all__ = ["get_lowering", "register_lowering", "load_lowering_registry"]
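The new load_lowering_registry helper simply imports every module in _LOWERING_MODULES for its side effects, so that each @register_lowering decorator has run before lowerings are looked up. A minimal usage sketch (the call site and the keying of get_lowering by ONNX op type are assumptions, not shown in this diff):

    from emx_onnx_cgen.lowering import get_lowering, load_lowering_registry

    load_lowering_registry()        # import every module in _LOWERING_MODULES for side effects
    handler = get_lowering("Conv")  # assumed lookup by ONNX op type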
emx_onnx_cgen/lowering/adagrad.py
@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+from shared.scalar_types import ScalarType
+
+from ..ir.ops import AdagradOp
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from ..ir.model import Graph, Node
+from .common import value_dtype, value_shape
+from .registry import register_lowering
+
+
+def _is_scalar_shape(shape: tuple[int, ...]) -> bool:
+    return shape == () or shape == (1,)
+
+
+@register_lowering("Adagrad")
+def lower_adagrad(graph: Graph, node: Node) -> AdagradOp:
+    if len(node.inputs) < 5:
+        raise UnsupportedOpError("Adagrad must have at least 5 inputs")
+    if len(node.outputs) < 2:
+        raise UnsupportedOpError("Adagrad must have at least 2 outputs")
+    if (len(node.inputs) - 2) % 3 != 0:
+        raise UnsupportedOpError(
+            "Adagrad inputs must be R, T, Xs, Gs, Hs with matching counts"
+        )
+    tensor_count = (len(node.inputs) - 2) // 3
+    if len(node.outputs) != tensor_count * 2:
+        raise UnsupportedOpError(
+            "Adagrad outputs must be X_news followed by H_news"
+        )
+    rate_name = node.inputs[0]
+    timestep_name = node.inputs[1]
+    rate_shape = value_shape(graph, rate_name, node)
+    timestep_shape = value_shape(graph, timestep_name, node)
+    if not _is_scalar_shape(rate_shape):
+        raise UnsupportedOpError("Adagrad R input must be a scalar")
+    if not _is_scalar_shape(timestep_shape):
+        raise UnsupportedOpError("Adagrad T input must be a scalar")
+    rate_dtype = value_dtype(graph, rate_name, node)
+    if rate_dtype not in {ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError(
+            "Adagrad R input must be float or double"
+        )
+    timestep_dtype = value_dtype(graph, timestep_name, node)
+    if timestep_dtype != ScalarType.I64:
+        raise UnsupportedOpError("Adagrad T input must be int64")
+
+    inputs = node.inputs[2 : 2 + tensor_count]
+    gradients = node.inputs[2 + tensor_count : 2 + tensor_count * 2]
+    accumulators = node.inputs[2 + tensor_count * 2 : 2 + tensor_count * 3]
+    outputs = node.outputs[:tensor_count]
+    accumulator_outputs = node.outputs[tensor_count:]
+    if not inputs or not gradients or not accumulators:
+        raise UnsupportedOpError("Adagrad requires X, G, H inputs")
+    dtype = value_dtype(graph, inputs[0], node)
+    if dtype not in {ScalarType.F32, ScalarType.F64}:
+        raise UnsupportedOpError("Adagrad supports float and double tensors only")
+    if rate_dtype != dtype:
+        raise UnsupportedOpError(
+            "Adagrad R input dtype must match tensor dtype"
+        )
+    input_shapes: list[tuple[int, ...]] = []
+    output_shapes: list[tuple[int, ...]] = []
+    for index, (x_name, g_name, h_name, out_name, h_out_name) in enumerate(
+        zip(inputs, gradients, accumulators, outputs, accumulator_outputs)
+    ):
+        x_dtype = value_dtype(graph, x_name, node)
+        g_dtype = value_dtype(graph, g_name, node)
+        h_dtype = value_dtype(graph, h_name, node)
+        out_dtype = value_dtype(graph, out_name, node)
+        h_out_dtype = value_dtype(graph, h_out_name, node)
+        if {x_dtype, g_dtype, h_dtype, out_dtype, h_out_dtype} != {dtype}:
+            raise UnsupportedOpError(
+                "Adagrad inputs and outputs must share the same dtype"
+            )
+        x_shape = value_shape(graph, x_name, node)
+        g_shape = value_shape(graph, g_name, node)
+        h_shape = value_shape(graph, h_name, node)
+        out_shape = value_shape(graph, out_name, node)
+        h_out_shape = value_shape(graph, h_out_name, node)
+        if x_shape != g_shape or x_shape != h_shape:
+            raise ShapeInferenceError(
+                f"Adagrad inputs X/G/H shapes must match for tensor {index}"
+            )
+        if out_shape != x_shape or h_out_shape != x_shape:
+            raise ShapeInferenceError(
+                f"Adagrad outputs must match X shape for tensor {index}"
+            )
+        input_shapes.append(x_shape)
+        output_shapes.append(out_shape)
+
+    norm_coefficient = float(node.attrs.get("norm_coefficient", 0.0))
+    epsilon = float(node.attrs.get("epsilon", 0.0))
+    decay_factor = float(node.attrs.get("decay_factor", 0.0))
+
+    return AdagradOp(
+        rate=rate_name,
+        timestep=timestep_name,
+        inputs=tuple(inputs),
+        gradients=tuple(gradients),
+        accumulators=tuple(accumulators),
+        outputs=tuple(outputs),
+        accumulator_outputs=tuple(accumulator_outputs),
+        rate_shape=rate_shape,
+        timestep_shape=timestep_shape,
+        tensor_shapes=tuple(input_shapes),
+        output_shapes=tuple(output_shapes),
+        dtype=dtype,
+        rate_dtype=rate_dtype,
+        timestep_dtype=timestep_dtype,
+        norm_coefficient=norm_coefficient,
+        epsilon=epsilon,
+        decay_factor=decay_factor,
+    )
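For reference, the ONNX Adagrad operator that this lowering validates performs, per tensor, the update below; this is a NumPy sketch of the formula published in the ONNX specification, not code from this package:

    import numpy as np

    def adagrad_step(R, T, X, G, H, norm_coefficient=0.0, epsilon=0.0, decay_factor=0.0):
        r = R / (1.0 + T * decay_factor)   # decayed learning rate
        g = norm_coefficient * X + G       # L2-regularized gradient
        H_new = H + g * g                  # accumulated squared gradients
        X_new = X - r * g / (np.sqrt(H_new) + epsilon)
        return X_new, H_new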
emx_onnx_cgen/lowering/arg_reduce.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import ArgReduceOp
+from ..ir.ops import ArgReduceOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import shape_product, value_dtype, value_shape
emx_onnx_cgen/lowering/attention.py
@@ -5,7 +5,7 @@ from dataclasses import dataclass
 
 from shared.scalar_types import ScalarType
 
-from ..codegen.c_emitter import AttentionOp
+from ..ir.ops import AttentionOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype
emx_onnx_cgen/lowering/average_pool.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..codegen.c_emitter import AveragePoolOp
+from ..ir.ops import AveragePoolOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/batch_normalization.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 
-from ..codegen.c_emitter import BatchNormOp
+from ..ir.ops import BatchNormOp
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from .registry import register_lowering
emx_onnx_cgen/lowering/cast.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import onnx
 
-from ..codegen.c_emitter import CastOp
+from ..ir.ops import CastOp
 from ..dtypes import scalar_type_from_onnx
 from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
emx_onnx_cgen/lowering/common.py
@@ -5,7 +5,8 @@ from collections.abc import Sequence
 from shared.scalar_types import ScalarType
 
 from ..errors import ShapeInferenceError, UnsupportedOpError
-from ..ir.model import Graph, Node
+from ..ir.context import GraphContext
+from ..ir.model import Graph, Initializer, Node
 
 
 def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
@@ -14,7 +15,24 @@ def ensure_supported_dtype(dtype: ScalarType) -> ScalarType:
     return dtype
 
 
-def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType:
+def onnx_opset_version(graph: Graph | GraphContext, domain: str = "") -> int | None:
+    if isinstance(graph, GraphContext):
+        return graph.opset_version(domain)
+    if domain in {"", "ai.onnx"}:
+        domains = {"", "ai.onnx"}
+    else:
+        domains = {domain}
+    for opset_domain, version in graph.opset_imports:
+        if opset_domain in domains:
+            return int(version)
+    return None
+
+
+def value_dtype(
+    graph: Graph | GraphContext, name: str, node: Node | None = None
+) -> ScalarType:
+    if isinstance(graph, GraphContext):
+        return graph.dtype(name, node)
     try:
         value = graph.find_value(name)
     except KeyError as exc:
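As a quick illustration of the default-domain handling in onnx_opset_version: the empty domain and "ai.onnx" are treated as the same opset entry. The duck-typed stand-in and import path below are assumptions for demonstration only:

    from emx_onnx_cgen.lowering.common import onnx_opset_version

    class FakeGraph:
        # stand-in exposing only the opset_imports pairs the helper reads
        opset_imports = [("", 18), ("com.microsoft", 1)]

    assert onnx_opset_version(FakeGraph(), "ai.onnx") == 18       # "" and "ai.onnx" match
    assert onnx_opset_version(FakeGraph(), "com.microsoft") == 1
    assert onnx_opset_version(FakeGraph(), "ai.onnx.ml") is None  # unknown domain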
@@ -26,18 +44,395 @@ def value_dtype(graph: Graph, name: str, node: Node | None = None) -> ScalarType
     return ensure_supported_dtype(value.type.dtype)
 
 
-def value_shape(graph: Graph, name: str, node: Node | None = None) -> tuple[int, ...]:
+def value_shape(
+    graph: Graph | GraphContext, name: str, node: Node | None = None
+) -> tuple[int, ...]:
+    if isinstance(graph, GraphContext):
+        shape = graph.shape(name, node)
+        value = graph.find_value(name)
+    else:
+        try:
+            value = graph.find_value(name)
+        except KeyError as exc:
+            op_type = node.op_type if node is not None else "unknown"
+            raise ShapeInferenceError(
+                f"Missing shape for value '{name}' in op {op_type}. "
+                "Hint: run ONNX shape inference or export with static shapes."
+            ) from exc
+        shape = value.type.shape
+    if any(value.type.dim_params):
+        resolved = _resolve_value_shape(graph, name, node)
+        if resolved is not None:
+            return resolved
+        return value.type.shape
+    return shape
+
+
+def _find_initializer(graph: Graph | GraphContext, name: str) -> Initializer | None:
+    if isinstance(graph, GraphContext):
+        return graph.initializer(name)
+    for initializer in graph.initializers:
+        if initializer.name == name:
+            return initializer
+    return None
+
+
+def _find_node_by_output(graph: Graph | GraphContext, name: str) -> Node | None:
+    if isinstance(graph, GraphContext):
+        return graph.producer(name)
+    for node in graph.nodes:
+        if name in node.outputs:
+            return node
+    return None
+
+
+def _shape_values_from_shape_node(
+    graph: Graph | GraphContext, shape_node: Node, node: Node | None
+) -> list[int]:
+    if len(shape_node.inputs) != 1 or len(shape_node.outputs) != 1:
+        raise UnsupportedOpError("Shape must have 1 input and 1 output")
+    source_shape = value_shape(graph, shape_node.inputs[0], node)
+    start = int(shape_node.attrs.get("start", 0))
+    end = int(shape_node.attrs.get("end", len(source_shape)))
+    if start < 0:
+        start += len(source_shape)
+    if end < 0:
+        end += len(source_shape)
+    start = max(start, 0)
+    end = min(end, len(source_shape))
+    if start > end:
+        return []
+    return list(source_shape[start:end])
+
+
+def _shape_values_from_initializer(
+    graph: Graph | GraphContext,
+    name: str,
+) -> list[int] | None:
+    initializer = _find_initializer(graph, name)
+    if initializer is None:
+        return None
+    if initializer.type.dtype not in {ScalarType.I64, ScalarType.I32}:
+        raise UnsupportedOpError(
+            "Reshape expects int64 or int32 shape input, "
+            f"got {initializer.type.dtype.onnx_name}"
+        )
+    return [int(value) for value in initializer.data.reshape(-1)]
+
+
+def _shape_values_from_input(
+    graph: Graph | GraphContext,
+    name: str,
+    node: Node | None,
+    *,
+    _visited: set[str] | None = None,
+) -> list[int] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
     try:
-        return graph.find_value(name).type.shape
-    except KeyError as exc:
-        op_type = node.op_type if node is not None else "unknown"
-        raise ShapeInferenceError(
-            f"Missing shape for value '{name}' in op {op_type}. "
-            "Hint: run ONNX shape inference or export with static shapes."
-        ) from exc
+        shape_values = _shape_values_from_initializer(graph, name)
+        if shape_values is not None:
+            return shape_values
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Shape":
+            return _shape_values_from_shape_node(graph, source_node, node)
+        if source_node.op_type == "Concat":
+            axis = int(source_node.attrs.get("axis", 0))
+            if axis != 0:
+                raise UnsupportedOpError("Reshape shape concat must use axis 0")
+            values: list[int] = []
+            for input_name in source_node.inputs:
+                input_values = _shape_values_from_input(
+                    graph,
+                    input_name,
+                    node,
+                    _visited=_visited,
+                )
+                if input_values is None:
+                    return None
+                values.extend(input_values)
+            return values
+        if source_node.op_type == "Cast":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Cast must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Unsqueeze":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Unsqueeze must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type == "Identity":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Identity must have 1 input and 1 output")
+            return _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+        if source_node.op_type in {"Equal", "And", "Or", "Div", "Mod"}:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            if len(left) == 1 and len(right) != 1:
+                left = left * len(right)
+            if len(right) == 1 and len(left) != 1:
+                right = right * len(left)
+            if len(left) != len(right):
+                return None
+            if source_node.op_type == "Equal":
+                return [1 if l == r else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "And":
+                return [1 if (l and r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Or":
+                return [1 if (l or r) else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Div":
+                return [int(l / r) if r != 0 else 0 for l, r in zip(left, right)]
+            if source_node.op_type == "Mod":
+                return [l % r if r != 0 else 0 for l, r in zip(left, right)]
+        if source_node.op_type == "Not":
+            if len(source_node.inputs) != 1 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Not must have 1 input and 1 output")
+            values = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if values is None:
+                return None
+            return [0 if value else 1 for value in values]
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            condition = _shape_values_from_input(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if condition is None:
+                return None
+            on_true = _shape_values_from_input(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _shape_values_from_input(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            if len(condition) == 1:
+                condition = condition * max(len(on_true), len(on_false))
+            if len(on_true) == 1 and len(condition) != 1:
+                on_true = on_true * len(condition)
+            if len(on_false) == 1 and len(condition) != 1:
+                on_false = on_false * len(condition)
+            if not (len(condition) == len(on_true) == len(on_false)):
+                return None
+            return [
+                t if cond else f
+                for cond, t, f in zip(condition, on_true, on_false)
+            ]
+        return None
+    finally:
+        _visited.remove(name)
+
+
+def _broadcast_shapes(
+    left: tuple[int, ...],
+    right: tuple[int, ...],
+) -> tuple[int, ...] | None:
+    result = []
+    left_rev = list(reversed(left))
+    right_rev = list(reversed(right))
+    for index in range(max(len(left_rev), len(right_rev))):
+        left_dim = left_rev[index] if index < len(left_rev) else 1
+        right_dim = right_rev[index] if index < len(right_rev) else 1
+        if left_dim == right_dim:
+            result.append(left_dim)
+        elif left_dim == 1:
+            result.append(right_dim)
+        elif right_dim == 1:
+            result.append(left_dim)
+        else:
+            return None
+    return tuple(reversed(result))
+
+
+def _resolve_value_shape(
+    graph: Graph | GraphContext,
+    name: str,
+    node: Node | None,
+    *,
+    _visited: set[str] | None = None,
+) -> tuple[int, ...] | None:
+    if _visited is None:
+        _visited = set()
+    if name in _visited:
+        return None
+    _visited.add(name)
+    try:
+        value = graph.find_value(name)
+        shape = value.type.shape
+        if not any(value.type.dim_params):
+            return shape
+        source_node = _find_node_by_output(graph, name)
+        if source_node is None:
+            return None
+        if source_node.op_type == "Expand":
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Expand must have 2 inputs and 1 output")
+            shape_values = _shape_values_from_input(
+                graph, source_node.inputs[1], node
+            )
+            if shape_values is not None and all(dim >= 0 for dim in shape_values):
+                return tuple(shape_values)
+            return None
+        if source_node.op_type == "Reshape":
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Reshape must have 2 inputs and 1 output")
+            shape_values = _shape_values_from_input(
+                graph, source_node.inputs[1], node
+            )
+            if shape_values is None:
+                return None
+            allowzero = int(source_node.attrs.get("allowzero", 0))
+            input_shape = _resolve_value_shape(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            if input_shape is None:
+                return None
+            output_dims: list[int] = []
+            unknown_index: int | None = None
+            known_product = 1
+            contains_zero = False
+            for index, dim in enumerate(shape_values):
+                if dim == -1:
+                    if unknown_index is not None:
+                        return None
+                    unknown_index = len(output_dims)
+                    output_dims.append(-1)
+                else:
+                    if dim == 0:
+                        contains_zero = True
+                        if allowzero == 0:
+                            if index >= len(input_shape):
+                                return None
+                            dim = input_shape[index]
+                    if dim < 0:
+                        return None
+                    output_dims.append(dim)
+                    known_product *= dim
+            if allowzero == 1 and contains_zero and unknown_index is not None:
+                return None
+            input_product = shape_product(input_shape)
+            if unknown_index is not None:
+                if known_product == 0:
+                    if input_product != 0:
+                        return None
+                    output_dims[unknown_index] = 0
+                else:
+                    if input_product % known_product != 0:
+                        return None
+                    output_dims[unknown_index] = input_product // known_product
+            return tuple(output_dims)
+        if source_node.op_type in {
+            "Add",
+            "Sub",
+            "Mul",
+            "Div",
+            "Pow",
+            "Mod",
+            "And",
+            "Or",
+            "Xor",
+            "Equal",
+            "Greater",
+            "Less",
+            "GreaterOrEqual",
+            "LessOrEqual",
+        }:
+            if len(source_node.inputs) != 2 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError(
+                    f"{source_node.op_type} must have 2 inputs and 1 output"
+                )
+            left = _resolve_value_shape(
+                graph,
+                source_node.inputs[0],
+                node,
+                _visited=_visited,
+            )
+            right = _resolve_value_shape(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            if left is None or right is None:
+                return None
+            return _broadcast_shapes(left, right)
+        if source_node.op_type == "Where":
+            if len(source_node.inputs) != 3 or len(source_node.outputs) != 1:
+                raise UnsupportedOpError("Where must have 3 inputs and 1 output")
+            on_true = _resolve_value_shape(
+                graph,
+                source_node.inputs[1],
+                node,
+                _visited=_visited,
+            )
+            on_false = _resolve_value_shape(
+                graph,
+                source_node.inputs[2],
+                node,
+                _visited=_visited,
+            )
+            if on_true is None or on_false is None:
+                return None
+            return _broadcast_shapes(on_true, on_false)
+        return None
+    finally:
+        _visited.remove(name)
 
 
-def node_dtype(graph: Graph, node: Node, *names: str) -> ScalarType:
+def node_dtype(graph: Graph | GraphContext, node: Node, *names: str) -> ScalarType:
     filtered = [name for name in names if name]
     if not filtered:
         raise UnsupportedOpError(
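Taken together, these helpers let value_shape recover a concrete shape for values whose ONNX type still carries symbolic dim_params, by constant-folding Shape/Concat/Cast/Reshape producers and re-applying broadcasting. The broadcasting helper follows the usual NumPy-style trailing-dimension rules; a small illustration (importing the private helper and the module path are assumptions, for demonstration only):

    from emx_onnx_cgen.lowering.common import _broadcast_shapes

    assert _broadcast_shapes((3, 1, 5), (4, 5)) == (3, 4, 5)  # size-1 dims stretch
    assert _broadcast_shapes((2, 1), (1, 6)) == (2, 6)
    assert _broadcast_shapes((3, 2), (4, 2)) is None          # 3 vs 4 cannot broadcast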
emx_onnx_cgen/lowering/concat.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from ..codegen.c_emitter import ConcatOp
+from ..ir.ops import ConcatOp
 from ..errors import UnsupportedOpError
 from ..ir.model import Graph, Node
 from .common import node_dtype as _node_dtype