emx-onnx-cgen 0.2.0 (emx_onnx_cgen-0.2.0-py3-none-any.whl)

This diff shows the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in its public registry.

Potentially problematic release: this version of emx-onnx-cgen has been flagged for review.

Files changed (76)
  1. emx_onnx_cgen/__init__.py +6 -0
  2. emx_onnx_cgen/__main__.py +9 -0
  3. emx_onnx_cgen/_build_info.py +3 -0
  4. emx_onnx_cgen/cli.py +328 -0
  5. emx_onnx_cgen/codegen/__init__.py +25 -0
  6. emx_onnx_cgen/codegen/c_emitter.py +9044 -0
  7. emx_onnx_cgen/compiler.py +601 -0
  8. emx_onnx_cgen/dtypes.py +40 -0
  9. emx_onnx_cgen/errors.py +14 -0
  10. emx_onnx_cgen/ir/__init__.py +3 -0
  11. emx_onnx_cgen/ir/model.py +55 -0
  12. emx_onnx_cgen/lowering/__init__.py +3 -0
  13. emx_onnx_cgen/lowering/arg_reduce.py +99 -0
  14. emx_onnx_cgen/lowering/attention.py +421 -0
  15. emx_onnx_cgen/lowering/average_pool.py +229 -0
  16. emx_onnx_cgen/lowering/batch_normalization.py +116 -0
  17. emx_onnx_cgen/lowering/cast.py +70 -0
  18. emx_onnx_cgen/lowering/common.py +72 -0
  19. emx_onnx_cgen/lowering/concat.py +31 -0
  20. emx_onnx_cgen/lowering/constant_of_shape.py +85 -0
  21. emx_onnx_cgen/lowering/conv.py +192 -0
  22. emx_onnx_cgen/lowering/cumsum.py +118 -0
  23. emx_onnx_cgen/lowering/depth_space.py +114 -0
  24. emx_onnx_cgen/lowering/dropout.py +46 -0
  25. emx_onnx_cgen/lowering/elementwise.py +164 -0
  26. emx_onnx_cgen/lowering/expand.py +151 -0
  27. emx_onnx_cgen/lowering/eye_like.py +43 -0
  28. emx_onnx_cgen/lowering/flatten.py +60 -0
  29. emx_onnx_cgen/lowering/gather.py +48 -0
  30. emx_onnx_cgen/lowering/gather_elements.py +60 -0
  31. emx_onnx_cgen/lowering/gemm.py +139 -0
  32. emx_onnx_cgen/lowering/grid_sample.py +149 -0
  33. emx_onnx_cgen/lowering/group_normalization.py +68 -0
  34. emx_onnx_cgen/lowering/identity.py +43 -0
  35. emx_onnx_cgen/lowering/instance_normalization.py +50 -0
  36. emx_onnx_cgen/lowering/layer_normalization.py +110 -0
  37. emx_onnx_cgen/lowering/logsoftmax.py +47 -0
  38. emx_onnx_cgen/lowering/lp_normalization.py +45 -0
  39. emx_onnx_cgen/lowering/lrn.py +104 -0
  40. emx_onnx_cgen/lowering/lstm.py +355 -0
  41. emx_onnx_cgen/lowering/matmul.py +120 -0
  42. emx_onnx_cgen/lowering/maxpool.py +195 -0
  43. emx_onnx_cgen/lowering/mean_variance_normalization.py +49 -0
  44. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +250 -0
  45. emx_onnx_cgen/lowering/pad.py +287 -0
  46. emx_onnx_cgen/lowering/range.py +104 -0
  47. emx_onnx_cgen/lowering/reduce.py +544 -0
  48. emx_onnx_cgen/lowering/registry.py +51 -0
  49. emx_onnx_cgen/lowering/reshape.py +188 -0
  50. emx_onnx_cgen/lowering/resize.py +445 -0
  51. emx_onnx_cgen/lowering/rms_normalization.py +67 -0
  52. emx_onnx_cgen/lowering/shape.py +78 -0
  53. emx_onnx_cgen/lowering/size.py +33 -0
  54. emx_onnx_cgen/lowering/slice.py +425 -0
  55. emx_onnx_cgen/lowering/softmax.py +47 -0
  56. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +129 -0
  57. emx_onnx_cgen/lowering/split.py +150 -0
  58. emx_onnx_cgen/lowering/squeeze.py +161 -0
  59. emx_onnx_cgen/lowering/tile.py +81 -0
  60. emx_onnx_cgen/lowering/transpose.py +46 -0
  61. emx_onnx_cgen/lowering/unsqueeze.py +157 -0
  62. emx_onnx_cgen/lowering/variadic.py +95 -0
  63. emx_onnx_cgen/lowering/where.py +73 -0
  64. emx_onnx_cgen/onnx_import.py +261 -0
  65. emx_onnx_cgen/ops.py +565 -0
  66. emx_onnx_cgen/runtime/__init__.py +1 -0
  67. emx_onnx_cgen/runtime/evaluator.py +2206 -0
  68. emx_onnx_cgen/validation.py +76 -0
  69. emx_onnx_cgen-0.2.0.dist-info/METADATA +128 -0
  70. emx_onnx_cgen-0.2.0.dist-info/RECORD +76 -0
  71. emx_onnx_cgen-0.2.0.dist-info/WHEEL +5 -0
  72. emx_onnx_cgen-0.2.0.dist-info/entry_points.txt +2 -0
  73. emx_onnx_cgen-0.2.0.dist-info/top_level.txt +2 -0
  74. shared/__init__.py +2 -0
  75. shared/scalar_functions.py +2405 -0
  76. shared/scalar_types.py +243 -0
emx_onnx_cgen/ir/model.py
@@ -0,0 +1,55 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Mapping, Sequence
+
+ import numpy as np
+
+ from shared.scalar_types import ScalarType
+
+
+ @dataclass(frozen=True)
+ class TensorType:
+     dtype: ScalarType
+     shape: tuple[int, ...]
+     dim_params: tuple[str | None, ...]
+
+
+ @dataclass(frozen=True)
+ class Value:
+     name: str
+     type: TensorType
+
+
+ @dataclass(frozen=True)
+ class Node:
+     op_type: str
+     name: str | None
+     inputs: tuple[str, ...]
+     outputs: tuple[str, ...]
+     attrs: Mapping[str, object]
+
+
+ @dataclass(frozen=True)
+ class Initializer:
+     name: str
+     type: TensorType
+     data: np.ndarray
+
+
+ @dataclass(frozen=True)
+ class Graph:
+     inputs: tuple[Value, ...]
+     outputs: tuple[Value, ...]
+     nodes: tuple[Node, ...]
+     initializers: tuple[Initializer, ...]
+     values: tuple[Value, ...] = ()
+
+     def find_value(self, name: str) -> Value:
+         for value in self.inputs + self.outputs + self.values:
+             if value.name == name:
+                 return value
+         for initializer in self.initializers:
+             if initializer.name == name:
+                 return Value(name=initializer.name, type=initializer.type)
+         raise KeyError(name)
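
The IR in model.py is a handful of frozen dataclasses; Graph.find_value resolves a name first against inputs, outputs, and intermediate values, then against initializers, wrapping a matching initializer in a fresh Value. A minimal sketch of how these pieces compose, with hand-picked shapes and names; ScalarType.F32 is referenced elsewhere in this diff, everything else here is illustrative:

    import numpy as np

    from emx_onnx_cgen.ir.model import Graph, Initializer, TensorType, Value
    from shared.scalar_types import ScalarType

    # A fully static 1x4 float32 tensor type; dim_params mirrors shape with
    # no symbolic dimension names.
    t = TensorType(dtype=ScalarType.F32, shape=(1, 4), dim_params=(None, None))

    g = Graph(
        inputs=(Value(name="x", type=t),),
        outputs=(Value(name="y", type=t),),
        nodes=(),
        initializers=(
            Initializer(name="w", type=t, data=np.zeros((1, 4), dtype=np.float32)),
        ),
    )

    assert g.find_value("x").type.shape == (1, 4)
    # Initializers are checked last and wrapped in a Value on the fly.
    assert g.find_value("w").type is t
    # Unknown names raise KeyError rather than returning None.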
emx_onnx_cgen/lowering/__init__.py
@@ -0,0 +1,3 @@
+ from .registry import get_lowering, register_lowering
+
+ __all__ = ["get_lowering", "register_lowering"]
emx_onnx_cgen/lowering/arg_reduce.py
@@ -0,0 +1,99 @@
+ from __future__ import annotations
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import ArgReduceOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import shape_product, value_dtype, value_shape
+ from .registry import register_lowering
+
+ ARG_REDUCE_KIND_BY_OP = {"ArgMax": "max", "ArgMin": "min"}
+
+
+ def _normalize_axis(axis: int, rank: int, node: Node) -> int:
+     if rank <= 0:
+         raise ShapeInferenceError(
+             f"{node.op_type} requires input rank >= 1, got {rank}"
+         )
+     if axis < 0:
+         axis += rank
+     if axis < 0 or axis >= rank:
+         raise ShapeInferenceError(
+             f"{node.op_type} axis {axis} is out of range for rank {rank}"
+         )
+     return axis
+
+
+ def _output_shape(
+     input_shape: tuple[int, ...], axis: int, keepdims: bool
+ ) -> tuple[int, ...]:
+     if keepdims:
+         return tuple(1 if idx == axis else dim for idx, dim in enumerate(input_shape))
+     return tuple(dim for idx, dim in enumerate(input_shape) if idx != axis)
+
+
+ def _arg_reduce_dtype_supported(dtype: ScalarType) -> bool:
+     return dtype in {
+         ScalarType.F16,
+         ScalarType.F32,
+         ScalarType.F64,
+         ScalarType.I64,
+         ScalarType.I32,
+         ScalarType.I16,
+         ScalarType.I8,
+         ScalarType.U64,
+         ScalarType.U32,
+         ScalarType.U16,
+         ScalarType.U8,
+     }
+
+
+ def lower_arg_reduce(graph: Graph, node: Node) -> ArgReduceOp:
+     if node.op_type not in ARG_REDUCE_KIND_BY_OP:
+         raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+     if len(node.inputs) != 1 or len(node.outputs) != 1:
+         raise UnsupportedOpError(
+             f"{node.op_type} must have 1 input and 1 output"
+         )
+     input_name = node.inputs[0]
+     output_name = node.outputs[0]
+     input_shape = value_shape(graph, input_name, node)
+     shape_product(input_shape)
+     rank = len(input_shape)
+     axis = int(node.attrs.get("axis", 0))
+     axis = _normalize_axis(axis, rank, node)
+     keepdims = bool(int(node.attrs.get("keepdims", 1)))
+     select_last_index = bool(int(node.attrs.get("select_last_index", 0)))
+     expected_output_shape = _output_shape(input_shape, axis, keepdims)
+     output_shape = value_shape(graph, output_name, node)
+     if output_shape != expected_output_shape:
+         raise ShapeInferenceError(
+             f"{node.op_type} output shape must be {expected_output_shape}, got {output_shape}"
+         )
+     input_dtype = value_dtype(graph, input_name, node)
+     if not _arg_reduce_dtype_supported(input_dtype):
+         raise UnsupportedOpError(
+             f"{node.op_type} does not support dtype {input_dtype.onnx_name}"
+         )
+     output_dtype = value_dtype(graph, output_name, node)
+     if output_dtype != ScalarType.I64:
+         raise UnsupportedOpError(
+             f"{node.op_type} expects output dtype int64, got {output_dtype.onnx_name}"
+         )
+     return ArgReduceOp(
+         input0=input_name,
+         output=output_name,
+         input_shape=input_shape,
+         output_shape=output_shape,
+         axis=axis,
+         keepdims=keepdims,
+         select_last_index=select_last_index,
+         reduce_kind=ARG_REDUCE_KIND_BY_OP[node.op_type],
+         input_dtype=input_dtype,
+         output_dtype=output_dtype,
+     )
+
+
+ register_lowering("ArgMax")(lower_arg_reduce)
+ register_lowering("ArgMin")(lower_arg_reduce)
emx_onnx_cgen/lowering/attention.py
@@ -0,0 +1,421 @@
+ from __future__ import annotations
+
+ import math
+ from dataclasses import dataclass
+
+ from shared.scalar_types import ScalarType
+
+ from ..codegen.c_emitter import AttentionOp
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from ..ir.model import Graph, Node
+ from .common import node_dtype as _node_dtype
+ from .common import optional_name as _optional_name
+ from .common import value_dtype as _value_dtype
+ from .common import value_shape as _value_shape
+ from .registry import register_lowering
+
+
+ @dataclass(frozen=True)
+ class AttentionSpec:
+     batch: int
+     q_heads: int
+     kv_heads: int
+     q_seq: int
+     kv_seq: int
+     total_seq: int
+     past_seq: int
+     qk_head_size: int
+     v_head_size: int
+     q_hidden_size: int | None
+     k_hidden_size: int | None
+     v_hidden_size: int | None
+     scale: float
+     is_causal: bool
+     softcap: float
+     qk_matmul_output_mode: int
+     q_rank: int
+     k_rank: int
+     v_rank: int
+     output_rank: int
+     mask_shape: tuple[int, ...] | None
+     mask_is_bool: bool
+     mask_rank: int | None
+     mask_broadcast_batch: bool
+     mask_broadcast_heads: bool
+     mask_broadcast_q_seq: bool
+     mask_q_seq: int | None
+     mask_kv_seq: int | None
+     head_group_size: int
+     has_attn_mask: bool
+     has_past: bool
+     has_present: bool
+     has_nonpad: bool
+
+
+ def resolve_attention_spec(
+     graph: Graph, node: Node, dtype: ScalarType
+ ) -> AttentionSpec:
+     if not dtype.is_float:
+         raise UnsupportedOpError("Unsupported op Attention")
+     if len(node.inputs) < 3 or len(node.outputs) < 1:
+         raise UnsupportedOpError("Unsupported op Attention")
+     supported_attrs = {
+         "scale",
+         "is_causal",
+         "q_num_heads",
+         "kv_num_heads",
+         "softmax_precision",
+         "softcap",
+         "qk_matmul_output_mode",
+     }
+     if set(node.attrs) - supported_attrs:
+         raise UnsupportedOpError("Unsupported op Attention")
+     q_shape = _value_shape(graph, node.inputs[0], node)
+     k_shape = _value_shape(graph, node.inputs[1], node)
+     v_shape = _value_shape(graph, node.inputs[2], node)
+     q_rank = len(q_shape)
+     k_rank = len(k_shape)
+     v_rank = len(v_shape)
+     if q_rank not in {3, 4} or k_rank not in {3, 4} or v_rank not in {3, 4}:
+         raise UnsupportedOpError("Unsupported op Attention")
+     if q_rank != k_rank or q_rank != v_rank:
+         raise UnsupportedOpError("Unsupported op Attention")
+     batch = q_shape[0]
+     if batch != k_shape[0] or batch != v_shape[0]:
+         raise ShapeInferenceError("Attention batch sizes must match")
+     q_hidden_size = None
+     k_hidden_size = None
+     v_hidden_size = None
+     if q_rank == 3:
+         q_heads = node.attrs.get("q_num_heads")
+         kv_heads = node.attrs.get("kv_num_heads")
+         if q_heads is None or kv_heads is None:
+             raise UnsupportedOpError("Unsupported op Attention")
+         q_heads = int(q_heads)
+         kv_heads = int(kv_heads)
+         q_seq = q_shape[1]
+         kv_seq = k_shape[1]
+         if kv_seq != v_shape[1]:
+             raise ShapeInferenceError(
+                 "Attention key/value sequence lengths must match"
+             )
+         q_hidden_size = q_shape[2]
+         k_hidden_size = k_shape[2]
+         v_hidden_size = v_shape[2]
+         if q_hidden_size % q_heads != 0:
+             raise ShapeInferenceError(
+                 "Attention query hidden size must be divisible by q_num_heads"
+             )
+         if k_hidden_size % kv_heads != 0:
+             raise ShapeInferenceError(
+                 "Attention key hidden size must be divisible by kv_num_heads"
+             )
+         if v_hidden_size % kv_heads != 0:
+             raise ShapeInferenceError(
+                 "Attention value hidden size must be divisible by kv_num_heads"
+             )
+         qk_head_size = q_hidden_size // q_heads
+         k_head_size = k_hidden_size // kv_heads
+         v_head_size = v_hidden_size // kv_heads
+         if qk_head_size != k_head_size:
+             raise ShapeInferenceError("Attention Q/K head sizes must match")
+     else:
+         q_heads = q_shape[1]
+         kv_heads = k_shape[1]
+         if kv_heads != v_shape[1]:
+             raise ShapeInferenceError("Attention key/value head counts must match")
+         q_seq = q_shape[2]
+         kv_seq = k_shape[2]
+         if kv_seq != v_shape[2]:
+             raise ShapeInferenceError(
+                 "Attention key/value sequence lengths must match"
+             )
+         qk_head_size = q_shape[3]
+         k_head_size = k_shape[3]
+         v_head_size = v_shape[3]
+         if qk_head_size != k_head_size:
+             raise ShapeInferenceError("Attention Q/K head sizes must match")
+         attr_q_heads = node.attrs.get("q_num_heads")
+         attr_kv_heads = node.attrs.get("kv_num_heads")
+         if attr_q_heads is not None and int(attr_q_heads) != q_heads:
+             raise ShapeInferenceError(
+                 "Attention q_num_heads must match query head dimension"
+             )
+         if attr_kv_heads is not None and int(attr_kv_heads) != kv_heads:
+             raise ShapeInferenceError(
+                 "Attention kv_num_heads must match key/value head dimension"
+             )
+     if q_heads < kv_heads or q_heads % kv_heads != 0:
+         raise ShapeInferenceError(
+             "Attention requires q_num_heads to be a multiple of kv_num_heads"
+         )
+     head_group_size = q_heads // kv_heads
+     past_key_name = _optional_name(node.inputs, 4)
+     past_value_name = _optional_name(node.inputs, 5)
+     has_past = past_key_name is not None or past_value_name is not None
+     if has_past and (past_key_name is None or past_value_name is None):
+         raise UnsupportedOpError(
+             "Attention expects both past_key and past_value if either is provided"
+         )
+     past_seq = 0
+     if has_past:
+         past_key_shape = _value_shape(graph, past_key_name, node)
+         past_value_shape = _value_shape(graph, past_value_name, node)
+         if len(past_key_shape) != 4 or len(past_value_shape) != 4:
+             raise ShapeInferenceError("Attention past key/value must be 4D")
+         if (
+             past_key_shape[0] != batch
+             or past_value_shape[0] != batch
+             or past_key_shape[1] != kv_heads
+             or past_value_shape[1] != kv_heads
+         ):
+             raise ShapeInferenceError(
+                 "Attention past key/value batch/head sizes must match"
+             )
+         if past_key_shape[3] != qk_head_size:
+             raise ShapeInferenceError(
+                 "Attention past key head size must match key head size"
+             )
+         if past_value_shape[3] != v_head_size:
+             raise ShapeInferenceError(
+                 "Attention past value head size must match value head size"
+             )
+         past_seq = past_key_shape[2]
+     total_seq = kv_seq + past_seq
+     output_shape = _value_shape(graph, node.outputs[0], node)
+     output_rank = len(output_shape)
+     if q_rank == 3:
+         expected_output_shape = (
+             batch,
+             q_seq,
+             q_heads * v_head_size,
+         )
+     else:
+         expected_output_shape = (batch, q_heads, q_seq, v_head_size)
+     if output_shape != expected_output_shape:
+         raise ShapeInferenceError(
+             "Attention output shape must be "
+             f"{expected_output_shape}, got {output_shape}"
+         )
+     present_key_name = _optional_name(node.outputs, 1)
+     present_value_name = _optional_name(node.outputs, 2)
+     has_present = present_key_name is not None or present_value_name is not None
+     if has_present and (present_key_name is None or present_value_name is None):
+         raise UnsupportedOpError(
+             "Attention expects both present_key and present_value if either is provided"
+         )
+     if has_present and not has_past:
+         raise UnsupportedOpError(
+             "Attention present outputs require past key/value inputs"
+         )
+     if has_present:
+         present_key_shape = _value_shape(graph, present_key_name, node)
+         present_value_shape = _value_shape(graph, present_value_name, node)
+         expected_present_key = (batch, kv_heads, total_seq, qk_head_size)
+         expected_present_value = (batch, kv_heads, total_seq, v_head_size)
+         if present_key_shape != expected_present_key:
+             raise ShapeInferenceError(
+                 "Attention present key shape must be "
+                 f"{expected_present_key}, got {present_key_shape}"
+             )
+         if present_value_shape != expected_present_value:
+             raise ShapeInferenceError(
+                 "Attention present value shape must be "
+                 f"{expected_present_value}, got {present_value_shape}"
+             )
+     qk_matmul_output_name = _optional_name(node.outputs, 3)
+     if qk_matmul_output_name is not None:
+         qk_shape = _value_shape(graph, qk_matmul_output_name, node)
+         expected_qk_shape = (batch, q_heads, q_seq, total_seq)
+         if qk_shape != expected_qk_shape:
+             raise ShapeInferenceError(
+                 "Attention qk_matmul_output shape must be "
+                 f"{expected_qk_shape}, got {qk_shape}"
+             )
+     attn_mask_name = _optional_name(node.inputs, 3)
+     mask_shape = None
+     mask_rank = None
+     mask_q_seq = None
+     mask_kv_seq = None
+     mask_is_bool = False
+     mask_broadcast_batch = False
+     mask_broadcast_heads = True
+     mask_broadcast_q_seq = False
+     has_attn_mask = attn_mask_name is not None
+     if has_attn_mask:
+         mask_shape = _value_shape(graph, attn_mask_name, node)
+         mask_rank = len(mask_shape)
+         if mask_rank not in {2, 3, 4}:
+             raise ShapeInferenceError("Attention mask must be 2D/3D/4D")
+         mask_dtype = _value_dtype(graph, attn_mask_name, node)
+         if mask_dtype == ScalarType.BOOL:
+             mask_is_bool = True
+         elif mask_dtype != dtype:
+             raise UnsupportedOpError(
+                 "Attention mask must be bool or match attention dtype"
+             )
+         if mask_rank == 2:
+             mask_q_seq, mask_kv_seq = mask_shape
+             mask_broadcast_batch = True
+             mask_broadcast_heads = True
+             mask_broadcast_q_seq = mask_q_seq == 1
+             if mask_q_seq not in {1, q_seq}:
+                 raise ShapeInferenceError(
+                     "Attention mask sequence length must match query length"
+                 )
+         elif mask_rank == 3:
+             mask_batch, mask_q_seq, mask_kv_seq = mask_shape
+             mask_broadcast_batch = mask_batch == 1
+             mask_broadcast_heads = True
+             mask_broadcast_q_seq = mask_q_seq == 1
+             if mask_batch not in {1, batch}:
+                 raise ShapeInferenceError(
+                     "Attention mask batch dimension must match batch size"
+                 )
+             if mask_q_seq not in {1, q_seq}:
+                 raise ShapeInferenceError(
+                     "Attention mask sequence length must match query length"
+                 )
+         else:
+             mask_batch, mask_heads, mask_q_seq, mask_kv_seq = mask_shape
+             mask_broadcast_batch = mask_batch == 1
+             mask_broadcast_heads = mask_heads == 1
+             mask_broadcast_q_seq = mask_q_seq == 1
+             if mask_batch not in {1, batch}:
+                 raise ShapeInferenceError(
+                     "Attention mask batch dimension must match batch size"
+                 )
+             if mask_heads not in {1, q_heads}:
+                 raise ShapeInferenceError(
+                     "Attention mask head dimension must match q_num_heads"
+                 )
+             if mask_q_seq not in {1, q_seq}:
+                 raise ShapeInferenceError(
+                     "Attention mask sequence length must match query length"
+                 )
+         if mask_kv_seq is None:
+             raise ShapeInferenceError("Attention mask must include kv sequence")
+         if mask_kv_seq > total_seq:
+             raise ShapeInferenceError(
+                 "Attention mask kv sequence length exceeds total sequence length"
+             )
+     nonpad_name = _optional_name(node.inputs, 6)
+     has_nonpad = nonpad_name is not None
+     if has_nonpad:
+         if has_past or has_present:
+             raise UnsupportedOpError(
+                 "Attention nonpad_kv_seqlen is not supported with KV cache"
+             )
+         nonpad_shape = _value_shape(graph, nonpad_name, node)
+         if nonpad_shape != (batch,):
+             raise ShapeInferenceError(
+                 "Attention nonpad_kv_seqlen must have shape (batch,)"
+             )
+         nonpad_dtype = _value_dtype(graph, nonpad_name, node)
+         if nonpad_dtype != ScalarType.I64:
+             raise UnsupportedOpError(
+                 "Attention nonpad_kv_seqlen must be int64"
+             )
+     scale = float(node.attrs.get("scale", 1.0 / math.sqrt(qk_head_size)))
+     softcap = float(node.attrs.get("softcap", 0.0))
+     is_causal = int(node.attrs.get("is_causal", 0))
+     if is_causal not in (0, 1):
+         raise UnsupportedOpError("Unsupported op Attention")
+     qk_matmul_output_mode = int(node.attrs.get("qk_matmul_output_mode", 0))
+     if qk_matmul_output_mode not in {0, 1, 2, 3}:
+         raise UnsupportedOpError("Unsupported op Attention")
+     return AttentionSpec(
+         batch=batch,
+         q_heads=q_heads,
+         kv_heads=kv_heads,
+         q_seq=q_seq,
+         kv_seq=kv_seq,
+         total_seq=total_seq,
+         past_seq=past_seq,
+         qk_head_size=qk_head_size,
+         v_head_size=v_head_size,
+         q_hidden_size=q_hidden_size,
+         k_hidden_size=k_hidden_size,
+         v_hidden_size=v_hidden_size,
+         scale=scale,
+         is_causal=bool(is_causal),
+         softcap=softcap,
+         qk_matmul_output_mode=qk_matmul_output_mode,
+         q_rank=q_rank,
+         k_rank=k_rank,
+         v_rank=v_rank,
+         output_rank=output_rank,
+         mask_shape=mask_shape,
+         mask_is_bool=mask_is_bool,
+         mask_rank=mask_rank,
+         mask_broadcast_batch=mask_broadcast_batch,
+         mask_broadcast_heads=mask_broadcast_heads,
+         mask_broadcast_q_seq=mask_broadcast_q_seq,
+         mask_q_seq=mask_q_seq,
+         mask_kv_seq=mask_kv_seq,
+         head_group_size=head_group_size,
+         has_attn_mask=has_attn_mask,
+         has_past=has_past,
+         has_present=has_present,
+         has_nonpad=has_nonpad,
+     )
+
+
+ @register_lowering("Attention")
+ def lower_attention(graph: Graph, node: Node) -> AttentionOp:
+     input_q = node.inputs[0]
+     input_k = node.inputs[1]
+     input_v = node.inputs[2]
+     output_y = node.outputs[0]
+     op_dtype = _node_dtype(graph, node, input_q, input_k, input_v, output_y)
+     spec = resolve_attention_spec(graph, node, op_dtype)
+     input_mask = _optional_name(node.inputs, 3)
+     input_past_key = _optional_name(node.inputs, 4)
+     input_past_value = _optional_name(node.inputs, 5)
+     input_nonpad = _optional_name(node.inputs, 6)
+     output_present_key = _optional_name(node.outputs, 1)
+     output_present_value = _optional_name(node.outputs, 2)
+     output_qk_matmul = _optional_name(node.outputs, 3)
+     return AttentionOp(
+         input_q=input_q,
+         input_k=input_k,
+         input_v=input_v,
+         input_attn_mask=input_mask,
+         input_past_key=input_past_key,
+         input_past_value=input_past_value,
+         input_nonpad_kv_seqlen=input_nonpad,
+         output=output_y,
+         output_present_key=output_present_key,
+         output_present_value=output_present_value,
+         output_qk_matmul=output_qk_matmul,
+         batch=spec.batch,
+         q_heads=spec.q_heads,
+         kv_heads=spec.kv_heads,
+         q_seq=spec.q_seq,
+         kv_seq=spec.kv_seq,
+         total_seq=spec.total_seq,
+         past_seq=spec.past_seq,
+         qk_head_size=spec.qk_head_size,
+         v_head_size=spec.v_head_size,
+         q_hidden_size=spec.q_hidden_size,
+         k_hidden_size=spec.k_hidden_size,
+         v_hidden_size=spec.v_hidden_size,
+         scale=spec.scale,
+         is_causal=spec.is_causal,
+         softcap=spec.softcap,
+         qk_matmul_output_mode=spec.qk_matmul_output_mode,
+         q_rank=spec.q_rank,
+         k_rank=spec.k_rank,
+         v_rank=spec.v_rank,
+         output_rank=spec.output_rank,
+         mask_shape=spec.mask_shape,
+         mask_is_bool=spec.mask_is_bool,
+         mask_rank=spec.mask_rank,
+         mask_broadcast_batch=spec.mask_broadcast_batch,
+         mask_broadcast_heads=spec.mask_broadcast_heads,
+         mask_broadcast_q_seq=spec.mask_broadcast_q_seq,
+         mask_q_seq=spec.mask_q_seq,
+         mask_kv_seq=spec.mask_kv_seq,
+         head_group_size=spec.head_group_size,
+         dtype=op_dtype,
+     )
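
resolve_attention_spec hinges on two derived quantities: the default softmax scale, 1/sqrt(qk_head_size) when no scale attribute is given, and head_group_size = q_heads // kv_heads, which is how grouped-query attention (several query heads sharing one KV head) is expressed. A minimal sketch for a 4D grouped-query node with no mask, KV cache, or optional outputs; the shapes and names are hand-built assumptions, and _optional_name returning None for absent trailing inputs mirrors how lower_attention itself uses it:

    import math

    from emx_onnx_cgen.ir.model import Graph, Node, TensorType, Value
    from emx_onnx_cgen.lowering.attention import resolve_attention_spec
    from shared.scalar_types import ScalarType


    def f32(shape: tuple[int, ...]) -> TensorType:
        # Helper for a static float32 tensor type with no symbolic dims.
        return TensorType(dtype=ScalarType.F32, shape=shape, dim_params=(None,) * len(shape))


    # batch=1, 8 query heads over 2 KV heads (group size 4), seq len 4, head size 64.
    node = Node(op_type="Attention", name=None, inputs=("q", "k", "v"), outputs=("y",), attrs={})
    graph = Graph(
        inputs=(
            Value(name="q", type=f32((1, 8, 4, 64))),
            Value(name="k", type=f32((1, 2, 4, 64))),
            Value(name="v", type=f32((1, 2, 4, 64))),
        ),
        outputs=(Value(name="y", type=f32((1, 8, 4, 64))),),  # (batch, q_heads, q_seq, v_head_size)
        nodes=(node,),
        initializers=(),
    )

    spec = resolve_attention_spec(graph, node, ScalarType.F32)
    assert spec.head_group_size == 4                  # 8 query heads / 2 KV heads
    assert spec.scale == 1.0 / math.sqrt(64)          # default scale = 1/sqrt(qk_head_size)
    assert spec.total_seq == 4 and not spec.has_past  # no KV cache in this sketch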