onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
onnx2fx/ops/attention.py
ADDED
@@ -0,0 +1,1055 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention and Transformer related operators (standard ONNX domain).

Microsoft domain operators are in attention_msft.py.
"""

from typing import TYPE_CHECKING

import onnx
import torch

from ..op_registry import register
from ..utils.attributes import get_attribute
from ..utils.op_helpers import get_optional_input

if TYPE_CHECKING:
    from ..graph_builder import GraphBuilder


# =============================================================================
# Embedding and LayerNorm variants
# =============================================================================


@register("SkipLayerNormalization")
def skip_layer_normalization(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Skip connection + LayerNorm (common in transformers)."""
    x = builder.get_value(node.input[0])
    skip = builder.get_value(node.input[1])
    gamma = builder.get_value(node.input[2])
    beta = get_optional_input(builder, node, 3)
    bias = get_optional_input(builder, node, 4)

    epsilon = get_attribute(node, "epsilon", 1e-5)

    def _skip_layer_norm(
        inp: torch.Tensor,
        sk: torch.Tensor,
        g: torch.Tensor,
        b: torch.Tensor | None,
        bi: torch.Tensor | None,
        eps: float,
    ) -> torch.Tensor:
        hidden = inp + sk
        if bi is not None:
            hidden = hidden + bi
        return torch.nn.functional.layer_norm(
            hidden, hidden.shape[-1:], weight=g, bias=b, eps=eps
        )

    return builder.call_function(
        _skip_layer_norm, args=(x, skip, gamma, beta, bias, epsilon)
    )


@register("EmbedLayerNormalization")
def embed_layer_normalization(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Embedding + LayerNorm (common in BERT-like models)."""
    input_ids = builder.get_value(node.input[0])
    segment_ids = get_optional_input(builder, node, 1)
    word_embedding = builder.get_value(node.input[2])
    position_embedding = builder.get_value(node.input[3])
    segment_embedding = get_optional_input(builder, node, 4)
    gamma = get_optional_input(builder, node, 5)
    beta = get_optional_input(builder, node, 6)

    epsilon = get_attribute(node, "epsilon", 1e-5)

    def _embed_layer_norm(
        ids: torch.Tensor,
        seg_ids: torch.Tensor | None,
        word_emb: torch.Tensor,
        pos_emb: torch.Tensor,
        seg_emb: torch.Tensor | None,
        g: torch.Tensor | None,
        b: torch.Tensor | None,
        eps: float,
    ) -> torch.Tensor:
        # Word embedding lookup
        word_embed = torch.nn.functional.embedding(ids, word_emb)

        # Position embedding (assume sequential positions)
        seq_len = ids.shape[1]
        pos_embed = pos_emb[:seq_len].unsqueeze(0).expand(ids.shape[0], -1, -1)

        hidden = word_embed + pos_embed

        # Add segment embedding if present
        if seg_emb is not None and seg_ids is not None:
            seg_embed = torch.nn.functional.embedding(seg_ids, seg_emb)
            hidden = hidden + seg_embed

        # Layer normalization
        if g is not None:
            hidden = torch.nn.functional.layer_norm(
                hidden, hidden.shape[-1:], weight=g, bias=b, eps=eps
            )

        return hidden

    return builder.call_function(
        _embed_layer_norm,
        args=(
            input_ids,
            segment_ids,
            word_embedding,
            position_embedding,
            segment_embedding,
            gamma,
            beta,
            epsilon,
        ),
    )
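# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the packaged file): a minimal eager-mode
# check of the SkipLayerNormalization formula used above, i.e.
# LayerNorm(x + skip [+ bias]) with weight gamma and bias beta. All shapes
# below are arbitrary example values.
def _example_skip_layer_norm_check() -> None:
    import torch

    x = torch.randn(2, 4, 8)
    skip = torch.randn(2, 4, 8)
    gamma = torch.randn(8)
    beta = torch.randn(8)

    hidden = x + skip
    out = torch.nn.functional.layer_norm(
        hidden, hidden.shape[-1:], weight=gamma, bias=beta, eps=1e-5
    )

    # Equivalent explicit computation over the last axis.
    mean = hidden.mean(-1, keepdim=True)
    var = hidden.var(-1, unbiased=False, keepdim=True)
    expected = (hidden - mean) / torch.sqrt(var + 1e-5) * gamma + beta
    assert torch.allclose(out, expected, atol=1e-5)
# ---------------------------------------------------------------------------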
# =============================================================================
# Attention operator (ONNX standard domain, since opset 24)
# =============================================================================


@register("Attention")
def attention(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """ONNX Attention operator (standard domain, since opset 24).

    Inputs:
        Q: Query tensor
            - 4D: (batch_size, q_num_heads, q_sequence_length, head_size)
            - 3D: (batch_size, q_sequence_length, q_num_heads * head_size)
        K: Key tensor
            - 4D: (batch_size, kv_num_heads, kv_sequence_length, head_size)
            - 3D: (batch_size, kv_sequence_length, kv_num_heads * head_size)
        V: Value tensor (same format as K)
        attn_mask (optional): Attention mask, broadcastable to
            (batch_size, q_num_heads, q_sequence_length, total_sequence_length)
        past_key (optional): Past key cache
        past_value (optional): Past value cache
        nonpad_kv_seqlen (optional): Non-padding KV sequence lengths

    Attributes:
        is_causal: If 1, use causal (lower triangular) mask
        scale: Scaling factor for Q*K^T (default: 1/sqrt(head_size))
        softcap: Softcap value for attention weights
        q_num_heads: Number of query heads (required for 3D inputs)
        kv_num_heads: Number of key/value heads (required for 3D inputs)
        qk_matmul_output_mode: Which QK intermediate to output (0=raw, 1=after mask, 2=after softcap, 3=after softmax)

    Outputs:
        Y: Output tensor (same format as Q)
        present_key (optional): Updated key cache
        present_value (optional): Updated value cache
        qk_matmul_output (optional): QK matmul output
    """
    # Get inputs
    query = builder.get_value(node.input[0])
    key = builder.get_value(node.input[1])
    value = builder.get_value(node.input[2])

    attn_mask = get_optional_input(builder, node, 3)
    past_key = get_optional_input(builder, node, 4)
    past_value = get_optional_input(builder, node, 5)
    nonpad_kv_seqlen = get_optional_input(builder, node, 6)

    # Get attributes
    is_causal = get_attribute(node, "is_causal", 0)
    scale = get_attribute(node, "scale", None)
    softcap = get_attribute(node, "softcap", 0.0)
    q_num_heads = get_attribute(node, "q_num_heads", None)
    kv_num_heads = get_attribute(node, "kv_num_heads", None)
    qk_matmul_output_mode = get_attribute(node, "qk_matmul_output_mode", 0)

    # Determine which outputs are needed
    # Output positions: 0=Y, 1=present_key, 2=present_value, 3=qk_matmul_output
    num_outputs = len(node.output)
    has_present_key = num_outputs > 1 and node.output[1]
    has_present_value = num_outputs > 2 and node.output[2]
    has_qk_matmul_output = num_outputs > 3 and node.output[3]

    # Check if we need multiple outputs (even if some are empty)
    _needs_multiple_outputs = num_outputs > 1

    # Use manual attention computation when:
    # 1. We need qk_matmul_output
    # 2. We have softcap
    # 3. We have both is_causal and attn_mask
    needs_manual_attention = (
        has_qk_matmul_output or softcap != 0.0 or (is_causal and attn_mask is not None)
    )

    if needs_manual_attention:
        # Manual attention implementation for advanced features

        def _attention_manual(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            mask: torch.Tensor | None,
            past_k: torch.Tensor | None,
            past_v: torch.Tensor | None,
            is_causal: int,
            scale: float | None,
            softcap: float,
            q_num_heads: int | None,
            kv_num_heads: int | None,
            num_outputs: int,
            qk_matmul_output_mode: int,
        ) -> (
            torch.Tensor
            | tuple[torch.Tensor, torch.Tensor]
            | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        ):
            is_3d = q.dim() == 3

            if is_3d:
                batch_size = q.shape[0]
                q_seq_len = q.shape[1]
                kv_seq_len = k.shape[1]
                n_q_heads = q_num_heads if q_num_heads is not None else 1
                n_kv_heads = kv_num_heads if kv_num_heads is not None else n_q_heads
                q_head_size = q.shape[2] // n_q_heads
                kv_head_size = k.shape[2] // n_kv_heads
                v_head_size = v.shape[2] // n_kv_heads
                q_4d = q.view(batch_size, q_seq_len, n_q_heads, q_head_size).transpose(
                    1, 2
                )
                k_4d = k.view(
                    batch_size, kv_seq_len, n_kv_heads, kv_head_size
                ).transpose(1, 2)
                v_4d = v.view(
                    batch_size, kv_seq_len, n_kv_heads, v_head_size
                ).transpose(1, 2)
            else:
                q_4d = q
                k_4d = k
                v_4d = v
                batch_size = q.shape[0]
                n_q_heads = q.shape[1]
                n_kv_heads = k.shape[1]
                q_seq_len = q.shape[2]
                q_head_size = q.shape[3]

            # Handle past key/value (KV cache)
            past_seq_len = 0
            if past_k is not None and past_v is not None:
                past_seq_len = past_k.shape[2]
                k_4d = torch.cat([past_k, k_4d], dim=2)
                v_4d = torch.cat([past_v, v_4d], dim=2)

            # Save present_key and present_value before GQA expansion (if needed for output)
            present_k = k_4d if num_outputs > 1 else None
            present_v = v_4d if num_outputs > 2 else None

            # Handle GQA: expand KV heads to match Q heads
            if n_kv_heads != n_q_heads:
                n_rep = n_q_heads // n_kv_heads
                k_4d = k_4d.repeat_interleave(n_rep, dim=1)
                v_4d = v_4d.repeat_interleave(n_rep, dim=1)

            # Compute attention scale
            if scale is None:
                scale_val = 1.0 / (q_head_size**0.5)
            else:
                scale_val = scale

            # Q @ K^T with scaling
            # (batch, heads, q_seq, head) @ (batch, heads, head, kv_seq)
            # -> (batch, heads, q_seq, kv_seq)
            qk = torch.matmul(q_4d, k_4d.transpose(-2, -1)) * scale_val

            # Save QK matmul output before applying mask/softmax (if needed for output)
            # Mode 0: raw QK matmul output
            qk_output = None
            if num_outputs > 3 and qk_matmul_output_mode == 0:
                qk_output = qk.clone()

            # Build combined attention mask (applied BEFORE softcap per ONNX spec)
            # The ONNX approach: create causal mask first, add attn_mask to it,
            # then add combined mask to QK scores
            kv_seq = k_4d.shape[2]
            combined_mask = None

            # Create causal mask if is_causal=1
            # ONNX uses: Less(q_pos + past_len, k_pos) to determine masked positions
            # This creates a strictly lower triangular mask where q_pos + past_len >= k_pos is allowed
            if is_causal:
                # Create causal mask: (q_pos + past_seq_len) < k_pos means masked
                row = (
                    torch.arange(q_seq_len, device=q.device).view(-1, 1) + past_seq_len
                )
                col = torch.arange(kv_seq, device=q.device).view(1, -1)
                causal_bool = row < col  # True where masked
                causal_mask = (
                    torch.where(causal_bool, float("-inf"), 0.0)
                    .to(q.dtype)
                    .unsqueeze(0)
                    .unsqueeze(0)
                )
                combined_mask = causal_mask

            # Add attention mask to causal mask (or use just attn_mask if no causal)
            if mask is not None:
                if combined_mask is not None:
                    combined_mask = mask + combined_mask  # attn_mask + causal_mask
                else:
                    combined_mask = mask

            # Add combined mask to QK scores
            if combined_mask is not None:
                qk = qk + combined_mask

            # Mode 1: after attention mask addition (including causal)
            if num_outputs > 3 and qk_matmul_output_mode == 1:
                qk_output = qk.clone()

            # Apply softcap if specified (after mask, before softmax per ONNX spec)
            if softcap != 0.0:
                qk = softcap * torch.tanh(qk / softcap)

            # Mode 2: after softcap
            if num_outputs > 3 and qk_matmul_output_mode == 2:
                qk_output = qk.clone()

            # Softmax
            attn_weights = torch.nn.functional.softmax(qk, dim=-1)

            # Mode 3: after softmax
            if num_outputs > 3 and qk_matmul_output_mode == 3:
                qk_output = attn_weights.clone()

            # Attention @ V
            output = torch.matmul(attn_weights, v_4d)

            if is_3d:
                output = (
                    output.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
                )

            # Return based on num_outputs (must match exactly)
            # Output positions: 0=Y, 1=present_key, 2=present_value, 3=qk_matmul_output
            if num_outputs == 1:
                return output
            elif num_outputs == 2:
                return (output, present_k)
            elif num_outputs == 3:
                return (output, present_k, present_v)
            else:  # num_outputs == 4
                return (output, present_k, present_v, qk_output)

        return builder.call_function(
            _attention_manual,
            args=(
                query,
                key,
                value,
                attn_mask,
                past_key,
                past_value,
                is_causal,
                scale,
                softcap,
                q_num_heads,
                kv_num_heads,
                num_outputs,
                qk_matmul_output_mode,
            ),
        )
    elif has_present_key or has_present_value:
        # Use SDPA but also return present_key/present_value

        def _attention_with_cache(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            mask: torch.Tensor | None,
            past_k: torch.Tensor | None,
            past_v: torch.Tensor | None,
            is_causal: int,
            scale: float | None,
            q_num_heads: int | None,
            kv_num_heads: int | None,
            has_present_key: bool,
            has_present_value: bool,
        ) -> (
            torch.Tensor
            | tuple[torch.Tensor, torch.Tensor]
            | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        ):
            is_3d = q.dim() == 3

            if is_3d:
                batch_size = q.shape[0]
                q_seq_len = q.shape[1]
                kv_seq_len = k.shape[1]
                n_q_heads = q_num_heads if q_num_heads is not None else 1
                n_kv_heads = kv_num_heads if kv_num_heads is not None else n_q_heads
                q_head_size = q.shape[2] // n_q_heads
                kv_head_size = k.shape[2] // n_kv_heads
                v_head_size = v.shape[2] // n_kv_heads
                q_4d = q.view(batch_size, q_seq_len, n_q_heads, q_head_size).transpose(
                    1, 2
                )
                k_4d = k.view(
                    batch_size, kv_seq_len, n_kv_heads, kv_head_size
                ).transpose(1, 2)
                v_4d = v.view(
                    batch_size, kv_seq_len, n_kv_heads, v_head_size
                ).transpose(1, 2)
            else:
                q_4d = q
                k_4d = k
                v_4d = v
                batch_size = q.shape[0]
                n_q_heads = q.shape[1]
                n_kv_heads = k.shape[1]
                q_seq_len = q.shape[2]

            # Handle past key/value (KV cache)
            if past_k is not None and past_v is not None:
                k_4d = torch.cat([past_k, k_4d], dim=2)
                v_4d = torch.cat([past_v, v_4d], dim=2)

            # Save present_key and present_value before GQA expansion
            present_k = k_4d if has_present_key else None
            present_v = v_4d if has_present_value else None

            # Handle GQA: expand KV heads to match Q heads
            if n_kv_heads != n_q_heads:
                n_rep = n_q_heads // n_kv_heads
                k_4d = k_4d.repeat_interleave(n_rep, dim=1)
                v_4d = v_4d.repeat_interleave(n_rep, dim=1)

            # Call SDPA
            if mask is not None:
                output = torch.nn.functional.scaled_dot_product_attention(
                    q_4d, k_4d, v_4d, attn_mask=mask, is_causal=False, scale=scale
                )
            elif is_causal:
                output = torch.nn.functional.scaled_dot_product_attention(
                    q_4d, k_4d, v_4d, is_causal=True, scale=scale
                )
            else:
                output = torch.nn.functional.scaled_dot_product_attention(
                    q_4d, k_4d, v_4d, scale=scale
                )

            if is_3d:
                output = (
                    output.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
                )

            # Build return tuple
            results = [output]
            if has_present_key:
                results.append(present_k)
            if has_present_value:
                results.append(present_v)

            if len(results) == 1:
                return results[0]
            return tuple(results)

        return builder.call_function(
            _attention_with_cache,
            args=(
                query,
                key,
                value,
                attn_mask,
                past_key,
                past_value,
                is_causal,
                scale,
                q_num_heads,
                kv_num_heads,
                has_present_key,
                has_present_value,
            ),
        )
    else:
        # Simple case: just use SDPA

        def _attention_standard(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            mask: torch.Tensor | None,
            past_k: torch.Tensor | None,
            past_v: torch.Tensor | None,
            is_causal: int,
            scale: float | None,
            q_num_heads: int | None,
            kv_num_heads: int | None,
            nonpad_kv_seqlen: torch.Tensor | None,
        ) -> torch.Tensor:
            is_3d = q.dim() == 3

            if is_3d:
                batch_size = q.shape[0]
                q_seq_len = q.shape[1]
                kv_seq_len = k.shape[1]
                n_q_heads = q_num_heads if q_num_heads is not None else 1
                n_kv_heads = kv_num_heads if kv_num_heads is not None else n_q_heads
                q_head_size = q.shape[2] // n_q_heads
                kv_head_size = k.shape[2] // n_kv_heads
                v_head_size = v.shape[2] // n_kv_heads
                q_4d = q.view(batch_size, q_seq_len, n_q_heads, q_head_size).transpose(
                    1, 2
                )
                k_4d = k.view(
                    batch_size, kv_seq_len, n_kv_heads, kv_head_size
                ).transpose(1, 2)
                v_4d = v.view(
                    batch_size, kv_seq_len, n_kv_heads, v_head_size
                ).transpose(1, 2)
            else:
                q_4d = q
                k_4d = k
                v_4d = v
                batch_size = q.shape[0]
                n_q_heads = q.shape[1]
                n_kv_heads = k.shape[1]
                q_seq_len = q.shape[2]

            # Handle past key/value (KV cache)
            if past_k is not None and past_v is not None:
                k_4d = torch.cat([past_k, k_4d], dim=2)
                v_4d = torch.cat([past_v, v_4d], dim=2)

            # Handle GQA: expand KV heads to match Q heads
            if n_kv_heads != n_q_heads:
                n_rep = n_q_heads // n_kv_heads
                k_4d = k_4d.repeat_interleave(n_rep, dim=1)
                v_4d = v_4d.repeat_interleave(n_rep, dim=1)

            # Handle mask padding if mask is shorter than KV sequence
            # Per ONNX spec: "The last dimension can also be shorter than
            # total_sequence_length and will be padded with negative infinity"
            kv_seq_len_actual = k_4d.shape[2]

            if mask is not None:
                # Pad mask if shorter than KV sequence (BEFORE adding nonpad mask)
                if mask.shape[-1] < kv_seq_len_actual:
                    pad_size = kv_seq_len_actual - mask.shape[-1]
                    # For bool masks, pad with False; for float, pad with -inf
                    if mask.dtype == torch.bool:
                        mask = torch.nn.functional.pad(mask, (0, pad_size), value=False)
                    else:
                        mask = torch.nn.functional.pad(
                            mask, (0, pad_size), value=float("-inf")
                        )

                # Expand mask for GQA if needed (mask might have n_kv_heads dimension)
                if (
                    mask.dim() == 4
                    and mask.shape[1] == n_kv_heads
                    and n_kv_heads != n_q_heads
                ):
                    n_rep = n_q_heads // n_kv_heads
                    mask = mask.repeat_interleave(n_rep, dim=1)

            # Handle nonpad_kv_seqlen: create a padding mask for each batch
            # that masks out positions >= nonpad_kv_seqlen[batch]
            if nonpad_kv_seqlen is not None:
                # Create a position index tensor: (1, 1, 1, kv_seq_len)
                positions = torch.arange(kv_seq_len_actual, device=q.device).view(
                    1, 1, 1, -1
                )
                # Create mask: True where position < nonpad_kv_seqlen[batch]
                # nonpad_kv_seqlen: (batch_size,) -> (batch_size, 1, 1, 1)
                valid_mask = positions < nonpad_kv_seqlen.view(-1, 1, 1, 1)
                # Convert to additive mask: 0 for valid, -inf for padding
                pad_mask = torch.where(valid_mask, 0.0, float("-inf")).to(q.dtype)
                # Combine with existing mask
                if mask is not None:
                    mask = mask + pad_mask
                else:
                    mask = pad_mask

            # Call SDPA
            if mask is not None:
                output = torch.nn.functional.scaled_dot_product_attention(
                    q_4d, k_4d, v_4d, attn_mask=mask, is_causal=False, scale=scale
                )
            elif is_causal:
                output = torch.nn.functional.scaled_dot_product_attention(
                    q_4d, k_4d, v_4d, is_causal=True, scale=scale
                )
            else:
                output = torch.nn.functional.scaled_dot_product_attention(
                    q_4d, k_4d, v_4d, scale=scale
                )

            if is_3d:
                output = (
                    output.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
                )

            return output

        return builder.call_function(
            _attention_standard,
            args=(
                query,
                key,
                value,
                attn_mask,
                past_key,
                past_value,
                is_causal,
                scale,
                q_num_heads,
                kv_num_heads,
                nonpad_kv_seqlen,
            ),
        )
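# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the packaged file): with softcap disabled,
# the manual path above (scale -> additive mask -> softmax -> matmul) gives
# the same result as torch.nn.functional.scaled_dot_product_attention.
# Shapes below are arbitrary example values.
def _example_manual_attention_matches_sdpa() -> None:
    import torch

    torch.manual_seed(0)
    q = torch.randn(1, 2, 5, 4)  # (batch, heads, q_seq, head_size)
    k = torch.randn(1, 2, 5, 4)
    v = torch.randn(1, 2, 5, 4)
    scale = 1.0 / (q.shape[-1] ** 0.5)

    # Additive causal mask, built the same way as in _attention_manual.
    row = torch.arange(5).view(-1, 1)
    col = torch.arange(5).view(1, -1)
    causal = torch.zeros(5, 5).masked_fill(row < col, float("-inf"))

    qk = torch.matmul(q, k.transpose(-2, -1)) * scale + causal
    manual = torch.matmul(torch.nn.functional.softmax(qk, dim=-1), v)

    sdpa = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=True, scale=scale
    )
    assert torch.allclose(manual, sdpa, atol=1e-5)
# ---------------------------------------------------------------------------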
# =============================================================================
# Simplified LayerNormalization variants (ONNX Runtime contrib ops)
# =============================================================================


@register("SimplifiedLayerNormalization")
def simplified_layer_normalization(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Simplified Layer Normalization (RMSNorm).

    This is LayerNormalization without bias and mean subtraction.
    Formula: output = x / sqrt(mean(x^2) + epsilon) * scale
    """
    x = builder.get_value(node.input[0])
    scale = builder.get_value(node.input[1])

    axis = get_attribute(node, "axis", -1)
    epsilon = get_attribute(node, "epsilon", 1e-5)

    def _simplified_layer_norm(x, scale, axis, epsilon):
        # Simplified LayerNorm (RMSNorm)
        # output = x * rsqrt(mean(x^2) + epsilon) * scale
        if axis < 0:
            axis_pos = x.dim() + axis
        else:
            axis_pos = axis

        # Keep dims for broadcasting
        dims = list(range(axis_pos, x.dim()))

        # Compute RMS: sqrt(mean(x^2))
        variance = x.pow(2).mean(dim=dims, keepdim=True)
        x_normalized = x * torch.rsqrt(variance + epsilon)

        return x_normalized * scale

    return builder.call_function(_simplified_layer_norm, args=(x, scale, axis, epsilon))
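# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the packaged file): the RMSNorm formula
# used above, x * rsqrt(mean(x^2) + eps) * scale, gives the normalized part
# unit root-mean-square up to eps. Shapes below are arbitrary example values.
def _example_rms_norm_property() -> None:
    import torch

    x = torch.randn(3, 7, 16)
    eps = 1e-5

    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x_normalized = x * torch.rsqrt(variance + eps)

    rms = x_normalized.pow(2).mean(dim=-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)

    # On newer PyTorch builds that ship torch.nn.functional.rms_norm
    # (2.4 and later), the same formula matches the built-in.
    if hasattr(torch.nn.functional, "rms_norm"):
        ref = torch.nn.functional.rms_norm(x, x.shape[-1:], eps=eps)
        assert torch.allclose(x_normalized, ref, atol=1e-5)
# ---------------------------------------------------------------------------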
@register("SkipSimplifiedLayerNormalization")
def skip_simplified_layer_normalization(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Skip connection + Simplified Layer Normalization (RMSNorm)."""
    x = builder.get_value(node.input[0])
    skip = builder.get_value(node.input[1])
    scale = builder.get_value(node.input[2])
    bias = get_optional_input(builder, node, 3)

    epsilon = get_attribute(node, "epsilon", 1e-5)

    def _skip_simplified_layer_norm(x, skip, scale, bias, epsilon):
        # Add skip connection
        hidden = x + skip
        if bias is not None:
            hidden = hidden + bias

        # Simplified LayerNorm (RMSNorm)
        variance = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden_normalized = hidden * torch.rsqrt(variance + epsilon)

        return hidden_normalized * scale

    return builder.call_function(
        _skip_simplified_layer_norm, args=(x, skip, scale, bias, epsilon)
    )


@register("GroupQueryAttention")
def group_query_attention(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Group Query Attention (GQA) - used in LLaMA, Mistral, etc.

    Inputs:
        - query: [batch, seq_len, num_heads * head_size]
        - key: [batch, kv_seq_len, num_kv_heads * head_size]
        - value: [batch, kv_seq_len, num_kv_heads * head_size]
        - past_key (optional): [batch, num_kv_heads, past_seq_len, head_size]
        - past_value (optional): [batch, num_kv_heads, past_seq_len, head_size]
        - seqlens_k (optional): cumulative sequence lengths for keys
        - total_sequence_length (optional): total sequence length
        - cos_cache (optional): [max_seq_len, head_size / 2] or [max_seq_len, head_size]
        - sin_cache (optional): [max_seq_len, head_size / 2] or [max_seq_len, head_size]

    Attributes:
        - num_heads: number of attention heads
        - kv_num_heads: number of key-value heads (for GQA)
        - scale: scaling factor (default: 1/sqrt(head_size))
        - local_window_size: for sliding window attention
        - do_rotary: whether to apply rotary position embeddings
        - rotary_interleaved: whether rotary is interleaved (GPT-NeoX style vs LLaMA)

    Outputs:
        - output: [batch, seq_len, num_heads * head_size]
        - present_key: [batch, num_kv_heads, total_seq_len, head_size]
        - present_value: [batch, num_kv_heads, total_seq_len, head_size]
    """
    # Get required inputs
    query = builder.get_value(node.input[0])
    key = builder.get_value(node.input[1])
    value = builder.get_value(node.input[2])

    # Get optional inputs
    past_key = get_optional_input(builder, node, 3)
    past_value = get_optional_input(builder, node, 4)
    seqlens_k = get_optional_input(builder, node, 5)
    total_seq_len = get_optional_input(builder, node, 6)
    cos_cache = get_optional_input(builder, node, 7)
    sin_cache = get_optional_input(builder, node, 8)

    # Get attributes
    num_heads = get_attribute(node, "num_heads", 1)
    kv_num_heads = get_attribute(node, "kv_num_heads", num_heads)
    scale = get_attribute(node, "scale", None)
    local_window_size = get_attribute(node, "local_window_size", -1)
    do_rotary = get_attribute(node, "do_rotary", 0)
    rotary_interleaved = get_attribute(node, "rotary_interleaved", 0)

    def _group_query_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        past_k: torch.Tensor | None,
        past_v: torch.Tensor | None,
        seqlens_k: torch.Tensor | None,
        total_seq_len: torch.Tensor | None,
        cos_cache: torch.Tensor | None,
        sin_cache: torch.Tensor | None,
        n_heads: int,
        n_kv_heads: int,
        attn_scale: float | None,
        window_size: int,
        do_rotary: int,
        rotary_interleaved: int,
    ):
        batch_size, seq_len, _ = q.shape
        head_size = q.shape[-1] // n_heads
        kv_head_size = k.shape[-1] // n_kv_heads

        # Reshape Q, K, V to [batch, num_heads, seq_len, head_size]
        q = q.view(batch_size, seq_len, n_heads, head_size).transpose(1, 2)
        k = k.view(batch_size, -1, n_kv_heads, kv_head_size).transpose(1, 2)
        v = v.view(batch_size, -1, n_kv_heads, kv_head_size).transpose(1, 2)

        # Calculate position offset from past cache
        past_seq_len = 0
        if past_k is not None and past_k.numel() > 0:
            past_seq_len = past_k.shape[2]

        # Apply rotary position embeddings if enabled
        if do_rotary and cos_cache is not None and sin_cache is not None:
            # Get the position indices
            positions = torch.arange(
                past_seq_len, past_seq_len + seq_len, device=q.device
            )

            # Get cos/sin values for current positions
            cos = cos_cache[positions]  # [seq_len, rotary_dim]
            sin = sin_cache[positions]  # [seq_len, rotary_dim]

            # Expand for batch and heads: [1, 1, seq_len, rotary_dim]
            cos = cos.unsqueeze(0).unsqueeze(0)
            sin = sin.unsqueeze(0).unsqueeze(0)

            rotary_dim = cos.shape[-1]

            if rotary_interleaved:
                # GPT-NeoX style: [x0, x1, x2, x3, ...] -> rotate pairs
                q_rot = q[..., :rotary_dim]
                q_pass = q[..., rotary_dim:]
                k_rot = k[..., :rotary_dim]
                k_pass = k[..., rotary_dim:]

                # Apply rotation
                q1, q2 = q_rot[..., ::2], q_rot[..., 1::2]
                k1, k2 = k_rot[..., ::2], k_rot[..., 1::2]

                cos_half = cos[..., ::2]
                sin_half = sin[..., ::2]

                q_rot_new = torch.stack(
                    [q1 * cos_half - q2 * sin_half, q1 * sin_half + q2 * cos_half],
                    dim=-1,
                ).flatten(-2)
                k_rot_new = torch.stack(
                    [k1 * cos_half - k2 * sin_half, k1 * sin_half + k2 * cos_half],
                    dim=-1,
                ).flatten(-2)

                q = torch.cat([q_rot_new, q_pass], dim=-1)
                k = torch.cat([k_rot_new, k_pass], dim=-1)
            else:
                # LLaMA style: cos/sin are [seq, rotary_dim]
                # rotary_dim is half the head_size in this format
                # q/k first rotary_dim*2 elements are rotated:
                # q1 = q[..., :rotary_dim], q2 = q[..., rotary_dim:rotary_dim*2]
                # result = (q1*cos - q2*sin, q1*sin + q2*cos)

                rotary_full = rotary_dim * 2  # total dims that get rotated
                q_rot = q[..., :rotary_full]
                q_pass = q[..., rotary_full:]
                k_rot = k[..., :rotary_full]
                k_pass = k[..., rotary_full:]

                # Split into first half and second half
                q1, q2 = q_rot[..., :rotary_dim], q_rot[..., rotary_dim:rotary_full]
                k1, k2 = k_rot[..., :rotary_dim], k_rot[..., rotary_dim:rotary_full]

                # cos/sin are already in the right shape [1, 1, seq_len, rotary_dim]
                q_rot_new = torch.cat(
                    [q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1
                )
                k_rot_new = torch.cat(
                    [k1 * cos - k2 * sin, k1 * sin + k2 * cos], dim=-1
                )

                q = torch.cat([q_rot_new, q_pass], dim=-1)
                k = torch.cat([k_rot_new, k_pass], dim=-1)

        # Handle past key-value cache
        if past_k is not None and past_k.numel() > 0:
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        # Present key-value for caching
        present_k = k
        present_v = v

        # Expand K, V for GQA (repeat for each head group)
        if n_kv_heads < n_heads:
            n_rep = n_heads // n_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        # Compute attention scale
        if attn_scale is None:
            attn_scale = 1.0 / (head_size**0.5)

        # Use scaled_dot_product_attention
        # For autoregressive with past cache, don't use causal mask for new tokens
        # since past_k/v already handled the causality
        is_causal = seq_len > 1 and past_seq_len == 0
        output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, scale=attn_scale, is_causal=is_causal
        )

        # Reshape output: [batch, num_heads, seq_len, head_size] -> [batch, seq_len, hidden]
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return output, present_k, present_v

    # Build the call
    result = builder.call_function(
        _group_query_attention,
        args=(
            query,
            key,
            value,
            past_key,
            past_value,
            seqlens_k,
            total_seq_len,
            cos_cache,
            sin_cache,
            num_heads,
            kv_num_heads,
            scale,
            local_window_size,
            do_rotary,
            rotary_interleaved,
        ),
    )

    # Return tuple output
    return result
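# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the packaged file): how the GQA handler
# above expands the KV heads with repeat_interleave so that SDPA sees one
# K/V head per query head. Eight query heads over two KV heads is an
# arbitrary example configuration.
def _example_gqa_head_expansion() -> None:
    import torch

    batch, n_heads, n_kv_heads, seq, head = 1, 8, 2, 6, 16
    q = torch.randn(batch, n_heads, seq, head)
    k = torch.randn(batch, n_kv_heads, seq, head)
    v = torch.randn(batch, n_kv_heads, seq, head)

    n_rep = n_heads // n_kv_heads  # 4 query heads share each KV head
    k_expanded = k.repeat_interleave(n_rep, dim=1)
    v_expanded = v.repeat_interleave(n_rep, dim=1)
    assert k_expanded.shape == (batch, n_heads, seq, head)

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k_expanded, v_expanded, is_causal=True
    )
    assert out.shape == (batch, n_heads, seq, head)

    # Query heads 0-3 all attend to KV head 0, heads 4-7 to KV head 1.
    assert torch.equal(k_expanded[:, 0], k_expanded[:, 3])
    assert torch.equal(k_expanded[:, 4], k_expanded[:, 7])
# ---------------------------------------------------------------------------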
# =============================================================================
# Rotary Embedding (ONNX standard domain, since opset 23)
# =============================================================================


@register("RotaryEmbedding", since_version=23)
def rotary_embedding_onnx(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """ONNX RotaryEmbedding operator (standard domain, since opset 23).

    Applies rotary position embeddings (RoPE) to the input tensor based on
    https://arxiv.org/pdf/2104.09864

    Inputs:
        - X: 4D tensor (batch_size, num_heads, sequence_length, head_size) or
             3D tensor (batch_size, sequence_length, hidden_size)
        - cos_cache: 2D tensor (max_position_id+1, head_size/2) when position_ids provided,
             or 3D tensor (batch_size, sequence_length, head_size/2) otherwise
        - sin_cache: Same shape as cos_cache
        - position_ids (optional): 2D tensor (batch_size, sequence_length)

    Attributes:
        - interleaved: Whether to use interleaved pattern. Default is 0 (False).
        - num_heads: Number of attention heads (required for 3D input).
        - rotary_embedding_dim: Partial rotary dimension. Default is 0 (full rotation).

    Outputs:
        - Y: Tensor with same shape as input.
    """
    # Get inputs
    input_tensor = builder.get_value(node.input[0])
    cos_cache = builder.get_value(node.input[1])
    sin_cache = builder.get_value(node.input[2])
    position_ids = get_optional_input(builder, node, 3)

    # Get attributes
    interleaved = get_attribute(node, "interleaved", 0)
    num_heads = get_attribute(node, "num_heads", 0)
    rotary_embedding_dim = get_attribute(node, "rotary_embedding_dim", 0)

    def _rotary_embedding_onnx(
        x: torch.Tensor,
        cos_cache: torch.Tensor,
        sin_cache: torch.Tensor,
        position_ids: torch.Tensor | None,
        interleaved: int,
        num_heads: int,
        rotary_dim: int,
    ) -> torch.Tensor:
        """Apply ONNX-standard rotary position embeddings."""
        original_shape = x.shape
        is_3d = x.dim() == 3

        # First ensure input has shape [batch_size, seq_len, num_heads, head_size]
        if x.dim() == 4:
            # Input is (batch_size, num_heads, seq_len, head_size)
            # Transpose to (batch_size, seq_len, num_heads, head_size)
            x = x.transpose(1, 2)
            batch_size, seq_len, n_heads, head_size = x.shape
        else:
            # Input is (batch_size, seq_len, hidden_size)
            batch_size, seq_len, hidden_size = x.shape
            assert num_heads != 0, "num_heads must be provided for 3D input"
            head_size = hidden_size // num_heads
            x = x.view(batch_size, seq_len, num_heads, head_size)
            _n_heads = num_heads

        # Determine rotary_embedding_dim
        if rotary_dim == 0:
            rot_dim = head_size
        else:
            rot_dim = rotary_dim

        rotary_dim_half = rot_dim // 2

        # Split into rotary and pass-through parts
        x_rotate = x[..., :rot_dim]
        x_not_rotate = x[..., rot_dim:] if rot_dim < head_size else None

        # Retrieve sin and cos caches using position ids
        if position_ids is not None:
            # cos_cache/sin_cache shape: (max_pos+1, rotary_dim/2)
            # position_ids shape: (batch_size, seq_len)
            # Result shape: (batch_size, seq_len, rotary_dim/2)
            cos = cos_cache[position_ids]
            sin = sin_cache[position_ids]
        else:
            # cos_cache/sin_cache already have shape (batch_size, seq_len, rotary_dim/2)
            cos = cos_cache
            sin = sin_cache

        # Validate cache dimensions
        if cos.shape[-1] != rotary_dim_half:
            raise ValueError(
                f"Last dimension of cos cache ({cos.shape[-1]}) does not match "
                f"rotary_embedding_dim/2 ({rotary_dim_half})."
            )

        # Add head dimension: (batch_size, seq_len, 1, rotary_dim/2)
        cos = cos.unsqueeze(2)
        sin = sin.unsqueeze(2)

        # Apply rotation based on interleaved pattern
        if interleaved:
            # Interleaved: x_rotate[..., 0::2] and x_rotate[..., 1::2]
            x1 = x_rotate[..., 0::2]
            x2 = x_rotate[..., 1::2]

            # Calculate real and imaginary values
            real = (cos * x1) - (sin * x2)
            imag = (sin * x1) + (cos * x2)

            # Interleave back
            real = real.unsqueeze(-1)
            imag = imag.unsqueeze(-1)
            x_rotate_result = torch.cat((real, imag), dim=-1).flatten(-2)
        else:
            # Non-interleaved: split in halves
            x1 = x_rotate[..., :rotary_dim_half]
            x2 = x_rotate[..., rotary_dim_half:rot_dim]

            # Calculate real and imaginary values
            real = (cos * x1) - (sin * x2)
            imag = (sin * x1) + (cos * x2)

            x_rotate_result = torch.cat((real, imag), dim=-1)

        # Concatenate with non-rotated part
        if x_not_rotate is not None:
            output = torch.cat((x_rotate_result, x_not_rotate), dim=-1)
        else:
            output = x_rotate_result

        # Reshape back to original shape
        if is_3d:
            output = output.view(original_shape)
        else:
            # Transpose back to (batch_size, num_heads, seq_len, head_size)
            output = output.transpose(1, 2)

        return output

    return builder.call_function(
        _rotary_embedding_onnx,
        args=(
            input_tensor,
            cos_cache,
            sin_cache,
            position_ids,
            interleaved,
            num_heads,
            rotary_embedding_dim,
        ),
    )