onnx2fx 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2fx/__init__.py +96 -0
- onnx2fx/converter.py +62 -0
- onnx2fx/exceptions.py +155 -0
- onnx2fx/graph_builder.py +634 -0
- onnx2fx/op_registry.py +345 -0
- onnx2fx/ops/__init__.py +74 -0
- onnx2fx/ops/activation.py +282 -0
- onnx2fx/ops/arithmetic.py +281 -0
- onnx2fx/ops/attention.py +1055 -0
- onnx2fx/ops/attention_msft.py +682 -0
- onnx2fx/ops/control_flow.py +947 -0
- onnx2fx/ops/convolution.py +406 -0
- onnx2fx/ops/image.py +748 -0
- onnx2fx/ops/linalg.py +33 -0
- onnx2fx/ops/loss.py +56 -0
- onnx2fx/ops/nn.py +96 -0
- onnx2fx/ops/normalization.py +289 -0
- onnx2fx/ops/pooling.py +897 -0
- onnx2fx/ops/quantization.py +524 -0
- onnx2fx/ops/random.py +102 -0
- onnx2fx/ops/recurrent.py +647 -0
- onnx2fx/ops/reduction.py +534 -0
- onnx2fx/ops/sequence.py +304 -0
- onnx2fx/ops/signal.py +444 -0
- onnx2fx/ops/string.py +126 -0
- onnx2fx/ops/tensor.py +1161 -0
- onnx2fx/ops/training.py +402 -0
- onnx2fx/py.typed +0 -0
- onnx2fx/utils/__init__.py +45 -0
- onnx2fx/utils/analyze.py +139 -0
- onnx2fx/utils/attributes.py +150 -0
- onnx2fx/utils/dtype.py +107 -0
- onnx2fx/utils/external_data.py +233 -0
- onnx2fx/utils/names.py +43 -0
- onnx2fx/utils/op_helpers.py +339 -0
- onnx2fx/utils/training.py +54 -0
- onnx2fx-0.0.0.dist-info/METADATA +395 -0
- onnx2fx-0.0.0.dist-info/RECORD +39 -0
- onnx2fx-0.0.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,682 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
"""Microsoft domain (com.microsoft) attention and transformer operators.
|
|
3
|
+
|
|
4
|
+
This module implements attention-related operators for the com.microsoft domain,
|
|
5
|
+
commonly used by ONNX Runtime optimized models.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
import onnx
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from ..op_registry import register
|
|
14
|
+
from ..utils.attributes import get_attribute
|
|
15
|
+
from ..utils.op_helpers import get_optional_input
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from ..graph_builder import GraphBuilder
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# =============================================================================
|
|
22
|
+
# Embedding and LayerNorm variants (com.microsoft domain)
|
|
23
|
+
# =============================================================================
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@register("SkipLayerNormalization", domain="com.microsoft")
def skip_layer_normalization_msft(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Residual add followed by LayerNorm over the last axis (com.microsoft).

    Emits ``layer_norm(input + skip [+ bias], weight=gamma, bias=beta)`` where
    the normalized shape is the size of the last dimension.

    Args:
        builder: Graph builder used to resolve inputs and emit the FX node.
        node: ONNX node with inputs ``(input, skip, gamma, beta?, bias?)`` and
            an optional ``epsilon`` attribute (this converter defaults it to
            1e-5).

    Returns:
        The FX node producing the normalized tensor. Only a single tensor is
        produced here; any additional ONNX outputs are not computed.
    """
    hidden_in = builder.get_value(node.input[0])
    residual = builder.get_value(node.input[1])
    weight = builder.get_value(node.input[2])
    shift = get_optional_input(builder, node, 3)
    extra_bias = get_optional_input(builder, node, 4)

    eps_value = get_attribute(node, "epsilon", 1e-5)

    def _skip_layer_norm(
        inp: torch.Tensor,
        sk: torch.Tensor,
        g: torch.Tensor,
        b: torch.Tensor | None,
        bi: torch.Tensor | None,
        eps: float,
    ) -> torch.Tensor:
        # (inp + sk) first, then the optional bias - same evaluation order as
        # a two-step accumulation.
        summed = inp + sk if bi is None else inp + sk + bi
        normalized_shape = summed.shape[-1:]
        return torch.nn.functional.layer_norm(
            summed, normalized_shape, weight=g, bias=b, eps=eps
        )

    return builder.call_function(
        _skip_layer_norm,
        args=(hidden_in, residual, weight, shift, extra_bias, eps_value),
    )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@register("EmbedLayerNormalization", domain="com.microsoft")
def embed_layer_normalization_msft(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Embedding lookups + sum + LayerNorm (common in BERT-like models).

    Microsoft domain version (com.microsoft).

    Inputs (by index):
        0 input_ids: integer token ids, indexed here as (batch, seq).
        1 segment_ids (optional): token type ids, same shape as input_ids.
        2 word_embedding: lookup table for input_ids.
        3 position_embedding: lookup table for positions.
        4 segment_embedding (optional): lookup table for segment_ids.
        5 gamma (optional here): LayerNorm weight. When absent, the summed
          embeddings are returned without normalization.
        6 beta (optional): LayerNorm bias.
        7 mask (optional): not used by this converter.
        8 position_ids (optional): explicit positions, e.g. (batch, seq) or
          (1, seq). When absent, positions 0..seq-1 are used for every batch.

    Attributes:
        epsilon: LayerNorm epsilon (this converter defaults it to 1e-5).

    Returns:
        The FX node producing the (normalized) embedding sum. Only a single
        tensor is produced; additional ONNX outputs are not computed here.
    """
    input_ids = builder.get_value(node.input[0])
    segment_ids = get_optional_input(builder, node, 1)
    word_embedding = builder.get_value(node.input[2])
    position_embedding = builder.get_value(node.input[3])
    segment_embedding = get_optional_input(builder, node, 4)
    gamma = get_optional_input(builder, node, 5)
    beta = get_optional_input(builder, node, 6)
    # Input 7 (mask) is not needed to compute the embedding output.
    position_ids = get_optional_input(builder, node, 8)

    epsilon = get_attribute(node, "epsilon", 1e-5)

    def _embed_layer_norm(
        ids: torch.Tensor,
        seg_ids: torch.Tensor | None,
        word_emb: torch.Tensor,
        pos_emb: torch.Tensor,
        seg_emb: torch.Tensor | None,
        g: torch.Tensor | None,
        b: torch.Tensor | None,
        eps: float,
        pos_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Word embedding lookup
        word_embed = torch.nn.functional.embedding(ids, word_emb)

        if pos_ids is not None:
            # Explicit positions: a (1, seq) lookup broadcasts over the batch,
            # a (batch, seq) lookup gives per-sample positions.
            pos_embed = torch.nn.functional.embedding(pos_ids, pos_emb)
        else:
            # Default: sequential positions 0..seq-1 shared by every batch row.
            seq_len = ids.shape[1]
            pos_embed = pos_emb[:seq_len].unsqueeze(0).expand(ids.shape[0], -1, -1)

        hidden = word_embed + pos_embed

        # Segment embedding only contributes when both table and ids exist.
        if seg_emb is not None and seg_ids is not None:
            seg_embed = torch.nn.functional.embedding(seg_ids, seg_emb)
            hidden = hidden + seg_embed

        # Layer normalization over the last (hidden) axis.
        if g is not None:
            hidden = torch.nn.functional.layer_norm(
                hidden, hidden.shape[-1:], weight=g, bias=b, eps=eps
            )

        return hidden

    return builder.call_function(
        _embed_layer_norm,
        args=(
            input_ids,
            segment_ids,
            word_embedding,
            position_embedding,
            segment_embedding,
            gamma,
            beta,
            epsilon,
            position_ids,
        ),
    )
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# =============================================================================
|
|
128
|
+
# Microsoft Attention operator (com.microsoft domain)
|
|
129
|
+
# =============================================================================
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _msft_attention_keep_mask(
    mask: torch.Tensor,
    batch_size: int,
    seq_len: int,
    past_len: int,
    total_len: int,
) -> torch.Tensor:
    """Convert a com.microsoft Attention ``mask_index`` into a boolean keep-mask.

    The result broadcasts to ``(batch, num_heads, seq_len, total_len)``;
    ``True`` marks key positions a query may attend to.

    Supported layouts:
        * 1D ``(batch,)``: exclusive end position of valid keys per batch.
        * 1D ``(2 * batch,)``: exclusive end positions followed by inclusive
          start positions.
        * 2D ``(batch, total_len)``: raw mask, values > 0 are kept.
        * 3D ``(batch, seq_len, total_len)``: raw mask, values > 0 are kept.
        * 4D ``(batch, 1, max_seq, max_seq)``: raw mask sliced to the current
          query/key window, values > 0 are kept.

    Raises:
        ValueError: for any other mask rank.
    """
    if mask.dim() == 1:
        key_pos = torch.arange(total_len, device=mask.device).unsqueeze(0)
        ends = mask[:batch_size].to(key_pos.dtype).unsqueeze(-1)
        keep = key_pos < ends
        if mask.shape[0] == 2 * batch_size:
            starts = mask[batch_size:].to(key_pos.dtype).unsqueeze(-1)
            keep = keep & (key_pos >= starts)
        return keep[:, None, None, :]
    if mask.dim() == 2:
        return (mask > 0)[:, None, None, :]
    if mask.dim() == 3:
        return (mask > 0)[:, None, :, :]
    if mask.dim() == 4:
        return mask[:, :, past_len : past_len + seq_len, :total_len] > 0
    raise ValueError(
        f"Unsupported mask_index rank {mask.dim()} for Microsoft Attention"
    )


@register("Attention", domain="com.microsoft")
def microsoft_attention(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
    """Microsoft Attention operator (com.microsoft domain).

    Multi-Head Attention that can be either unidirectional (like GPT-2) or
    bidirectional (like BERT). The weights for input projection of Q, K and V
    are merged.

    Inputs:
        input: Input tensor with shape (batch_size, sequence_length, input_hidden_size)
        weights: Merged Q/K/V weights with shape (input_hidden_size, hidden_size + hidden_size + v_hidden_size)
        bias (optional): Bias tensor with shape (hidden_size + hidden_size + v_hidden_size)
        mask_index (optional): Attention mask (see ``_msft_attention_keep_mask``
            for the layouts handled here)
        past (optional): Past key/value, indexed here as past[0] = keys and
            past[1] = values, each (batch, num_heads, past_seq, head_size)
        attention_bias (optional): Additive bias broadcastable to
            (batch, num_heads, sequence_length, total_sequence_length)
        past_sequence_length (optional): not used by this converter

    Attributes:
        num_heads (required): Number of attention heads
        unidirectional: Whether every token can only attend to previous tokens (default 0)
        scale: Custom scale factor (absent or 0 means 1/sqrt(head_size))
        mask_filter_value: Additive value for masked positions (default -10000.0)
        qkv_hidden_sizes: Optional [q, k, v] widths of the merged projection;
            when absent the projection is split into three equal parts.

    Outputs:
        output: 3D output tensor with shape (batch_size, sequence_length, v_hidden_size)
        The ``present`` output is not produced by this converter.

    Raises:
        ValueError: if the ``num_heads`` attribute is missing.
    """
    # Get inputs
    input_tensor = builder.get_value(node.input[0])
    weights = builder.get_value(node.input[1])
    bias = get_optional_input(builder, node, 2)
    mask_index = get_optional_input(builder, node, 3)
    past = get_optional_input(builder, node, 4)
    attention_bias = get_optional_input(builder, node, 5)

    # Get attributes
    num_heads = get_attribute(node, "num_heads", None)
    if num_heads is None:
        raise ValueError("num_heads attribute is required for Microsoft Attention")
    unidirectional = get_attribute(node, "unidirectional", 0)
    scale = get_attribute(node, "scale", None)
    # A zero scale means "use the default 1/sqrt(head_size)".
    scale = scale if scale else None
    mask_filter_value = get_attribute(node, "mask_filter_value", -10000.0)
    qkv_hidden_sizes = get_attribute(node, "qkv_hidden_sizes", None)
    qkv_sizes = (
        tuple(int(size) for size in qkv_hidden_sizes) if qkv_hidden_sizes else None
    )

    def _microsoft_attention(
        inp: torch.Tensor,
        w: torch.Tensor,
        b: torch.Tensor | None,
        mask: torch.Tensor | None,
        past_kv: torch.Tensor | None,
        attn_bias: torch.Tensor | None,
        n_heads: int,
        is_causal: bool,
        attn_scale: float | None,
        qkv_split: tuple[int, ...] | None = None,
        filter_value: float = -10000.0,
    ) -> torch.Tensor:
        batch_size, seq_len, _ = inp.shape

        # Project input to Q, K, V using the merged weights.
        qkv = torch.matmul(inp, w)
        if b is not None:
            qkv = qkv + b
        if qkv_split:
            q, k, v = torch.split(qkv, list(qkv_split), dim=-1)
        else:
            q, k, v = qkv.chunk(3, dim=-1)

        def _split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, n_heads * head) -> (batch, n_heads, seq, head)
            head = t.shape[-1] // n_heads
            return t.reshape(batch_size, seq_len, n_heads, head).transpose(1, 2)

        q, k, v = _split_heads(q), _split_heads(k), _split_heads(v)

        # Prepend cached keys/values so queries can attend to past tokens.
        if past_kv is not None and past_kv.numel() > 0:
            k = torch.cat([past_kv[0], k], dim=2)
            v = torch.cat([past_kv[1], v], dim=2)
        total_len = k.shape[2]
        past_len = total_len - seq_len

        keep = None
        if mask is not None:
            keep = _msft_attention_keep_mask(
                mask, batch_size, seq_len, past_len, total_len
            )
        if is_causal:
            # Query i sits at absolute position past_len + i and may only see
            # keys at or before that position. Built explicitly because SDPA
            # rejects is_causal together with an attn_mask.
            q_pos = torch.arange(
                past_len, past_len + seq_len, device=q.device
            ).unsqueeze(-1)
            k_pos = torch.arange(total_len, device=q.device).unsqueeze(0)
            causal_keep = k_pos <= q_pos
            keep = causal_keep if keep is None else keep & causal_keep

        # Masked positions receive mask_filter_value as an additive bias
        # (instead of -inf) so fully masked rows stay finite.
        additive = None
        if keep is not None:
            additive = torch.zeros(
                keep.shape, dtype=q.dtype, device=q.device
            ).masked_fill(~keep.to(q.device), filter_value)
        if attn_bias is not None:
            bias_term = attn_bias.to(q.dtype)
            additive = bias_term if additive is None else additive + bias_term

        output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=additive, scale=attn_scale
        )

        # (batch, n_heads, seq, v_head) -> (batch, seq, v_hidden)
        return output.transpose(1, 2).reshape(batch_size, seq_len, -1)

    is_causal = unidirectional == 1

    return builder.call_function(
        _microsoft_attention,
        args=(
            input_tensor,
            weights,
            bias,
            mask_index,
            past,
            attention_bias,
            num_heads,
            is_causal,
            scale,
            qkv_sizes,
            mask_filter_value,
        ),
    )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# =============================================================================
|
|
222
|
+
# Simplified LayerNormalization variants (com.microsoft domain)
|
|
223
|
+
# =============================================================================
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
@register("SimplifiedLayerNormalization", domain="com.microsoft")
def simplified_layer_normalization_msft(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """RMS normalization: LayerNorm without mean subtraction or bias.

    Computes ``x * rsqrt(mean(x**2) + epsilon) * scale`` where the mean runs
    over every axis from ``axis`` through the last one.

    Args:
        builder: Graph builder used to resolve inputs and emit the FX node.
        node: ONNX node with inputs ``(X, scale)`` and optional ``axis``
            (default -1) and ``epsilon`` (default 1e-5) attributes.

    Returns:
        The FX node producing the normalized tensor.

    Microsoft domain version (com.microsoft).
    """
    data = builder.get_value(node.input[0])
    weight = builder.get_value(node.input[1])

    norm_axis = get_attribute(node, "axis", -1)
    eps = get_attribute(node, "epsilon", 1e-5)

    def _simplified_layer_norm(x, scale, axis, epsilon):
        rank = x.dim()
        # Negative axes count from the end.
        first = axis + rank if axis < 0 else axis
        # keepdim so the statistic broadcasts back against x.
        mean_square = x.pow(2).mean(dim=list(range(first, rank)), keepdim=True)
        return (x * torch.rsqrt(mean_square + epsilon)) * scale

    return builder.call_function(
        _simplified_layer_norm, args=(data, weight, norm_axis, eps)
    )
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
@register("SkipSimplifiedLayerNormalization", domain="com.microsoft")
def skip_simplified_layer_normalization_msft(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Residual add followed by RMS normalization over the last axis.

    Computes ``h = input + skip [+ bias]`` and returns
    ``h * rsqrt(mean(h**2, dim=-1) + epsilon) * scale``.

    Args:
        builder: Graph builder used to resolve inputs and emit the FX node.
        node: ONNX node with inputs ``(input, skip, scale, bias?)`` and an
            optional ``epsilon`` attribute (this converter defaults it to 1e-5).

    Returns:
        The FX node producing the normalized tensor.

    Microsoft domain version (com.microsoft).
    """
    inp_value = builder.get_value(node.input[0])
    skip_value = builder.get_value(node.input[1])
    weight = builder.get_value(node.input[2])
    extra_bias = get_optional_input(builder, node, 3)

    eps = get_attribute(node, "epsilon", 1e-5)

    def _skip_simplified_layer_norm(x, skip, scale, bias, epsilon):
        residual_sum = x + skip
        if bias is not None:
            residual_sum = residual_sum + bias
        # Inverse RMS along the last axis; keepdim for broadcasting.
        inv_rms = torch.rsqrt(residual_sum.pow(2).mean(dim=-1, keepdim=True) + epsilon)
        return residual_sum * inv_rms * scale

    return builder.call_function(
        _skip_simplified_layer_norm,
        args=(inp_value, skip_value, weight, extra_bias, eps),
    )
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
# =============================================================================
|
|
296
|
+
# GroupQueryAttention (com.microsoft domain)
|
|
297
|
+
# =============================================================================
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _gqa_apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    interleaved: int,
) -> torch.Tensor:
    """Rotate the leading ``2 * cos.shape[-1]`` channels of ``x``.

    ``x`` is (batch, heads, seq, head_size); ``cos``/``sin`` hold one value
    per rotated pair and broadcast against ``x[..., :pairs]``. Channels past
    the rotary span pass through unchanged.

    interleaved=1 rotates adjacent pairs ``(x[2i], x[2i+1])`` (GPT-J style);
    interleaved=0 rotates ``(x[i], x[i + pairs])`` (GPT-NeoX / LLaMA style).
    """
    pairs = cos.shape[-1]
    rot_width = 2 * pairs
    x_rot, x_pass = x[..., :rot_width], x[..., rot_width:]
    if interleaved:
        x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]
        rotated = torch.stack(
            [x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1
        ).flatten(-2)
    else:
        x1, x2 = x_rot[..., :pairs], x_rot[..., pairs:rot_width]
        rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)


@register("GroupQueryAttention", domain="com.microsoft")
def group_query_attention_msft(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Group Query Attention (GQA) - used in LLaMA, Mistral, etc.

    Microsoft domain version (com.microsoft). Attention is always causal:
    a query at absolute position p sees keys at positions <= p.

    Inputs:
        - query: [batch, seq_len, num_heads * head_size], or packed QKV
          [batch, seq_len, (num_heads + 2 * kv_num_heads) * head_size] when
          key and value are omitted
        - key (optional): [batch, kv_seq_len, num_kv_heads * head_size]
        - value (optional): [batch, kv_seq_len, num_kv_heads * head_size]
        - past_key (optional): [batch, num_kv_heads, past_seq_len, head_size]
        - past_value (optional): [batch, num_kv_heads, past_seq_len, head_size]
        - seqlens_k (optional): accepted but not used by this converter
        - total_sequence_length (optional): accepted but not used
        - cos_cache (optional): [max_seq_len, rotary_dim / 2]
        - sin_cache (optional): [max_seq_len, rotary_dim / 2]

    The past length is taken from ``past_key.shape[2]``. A shared-buffer
    cache where ``seqlens_k`` marks the valid length is not handled here.

    Attributes:
        - num_heads: number of attention heads
        - kv_num_heads: number of key-value heads (for GQA)
        - scale: scaling factor (absent or 0 means 1/sqrt(head_size))
        - local_window_size: if >= 0, a query at position p only sees keys
          in [p - local_window_size, p]
        - do_rotary: whether to apply rotary position embeddings
        - rotary_interleaved: 1 rotates adjacent channel pairs (GPT-J style),
          0 rotates the two halves of the rotary span (GPT-NeoX / LLaMA style)

    Outputs:
        - output: [batch, seq_len, num_heads * head_size]
        - present_key: [batch, num_kv_heads, total_seq_len, head_size]
        - present_value: [batch, num_kv_heads, total_seq_len, head_size]
    """
    # Required input
    query = builder.get_value(node.input[0])

    # key/value may be omitted when query carries packed QKV.
    key = get_optional_input(builder, node, 1)
    value = get_optional_input(builder, node, 2)
    past_key = get_optional_input(builder, node, 3)
    past_value = get_optional_input(builder, node, 4)
    seqlens_k = get_optional_input(builder, node, 5)
    total_seq_len = get_optional_input(builder, node, 6)
    cos_cache = get_optional_input(builder, node, 7)
    sin_cache = get_optional_input(builder, node, 8)

    # Get attributes
    num_heads = get_attribute(node, "num_heads", 1)
    kv_num_heads = get_attribute(node, "kv_num_heads", num_heads)
    scale = get_attribute(node, "scale", None)
    local_window_size = get_attribute(node, "local_window_size", -1)
    do_rotary = get_attribute(node, "do_rotary", 0)
    rotary_interleaved = get_attribute(node, "rotary_interleaved", 0)

    def _group_query_attention(
        q: torch.Tensor,
        k: torch.Tensor | None,
        v: torch.Tensor | None,
        past_k: torch.Tensor | None,
        past_v: torch.Tensor | None,
        seqlens_k: torch.Tensor | None,
        total_seq_len: torch.Tensor | None,
        cos_cache: torch.Tensor | None,
        sin_cache: torch.Tensor | None,
        n_heads: int,
        n_kv_heads: int,
        attn_scale: float | None,
        window_size: int,
        do_rotary: int,
        rotary_interleaved: int,
    ):
        batch_size, seq_len, _ = q.shape

        if k is None or v is None:
            # Packed QKV laid out as [Q | K | V] along the last axis.
            packed_head = q.shape[-1] // (n_heads + 2 * n_kv_heads)
            q, k, v = torch.split(
                q,
                [
                    n_heads * packed_head,
                    n_kv_heads * packed_head,
                    n_kv_heads * packed_head,
                ],
                dim=-1,
            )

        head_size = q.shape[-1] // n_heads
        kv_head_size = k.shape[-1] // n_kv_heads

        # Reshape Q, K, V to [batch, heads, seq_len, head_size]
        q = q.reshape(batch_size, seq_len, n_heads, head_size).transpose(1, 2)
        k = k.reshape(batch_size, -1, n_kv_heads, kv_head_size).transpose(1, 2)
        v = v.reshape(batch_size, -1, n_kv_heads, kv_head_size).transpose(1, 2)

        # Position offset of the new tokens = length of the past cache.
        has_past = past_k is not None and past_k.numel() > 0
        past_seq_len = past_k.shape[2] if has_past else 0

        if do_rotary and cos_cache is not None and sin_cache is not None:
            positions = torch.arange(
                past_seq_len, past_seq_len + seq_len, device=q.device
            )
            # [1, 1, seq_len, rotary_dim / 2] broadcasts over batch and heads.
            cos = cos_cache[positions].unsqueeze(0).unsqueeze(0)
            sin = sin_cache[positions].unsqueeze(0).unsqueeze(0)
            q = _gqa_apply_rotary(q, cos, sin, rotary_interleaved)
            k = _gqa_apply_rotary(k, cos, sin, rotary_interleaved)

        # Append new keys/values after the cached ones.
        if has_past:
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_k = k
        present_v = v

        # Expand K, V for GQA (each KV head serves n_rep query heads).
        if n_kv_heads < n_heads:
            n_rep = n_heads // n_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        # A missing or zero scale falls back to 1/sqrt(head_size).
        if not attn_scale:
            attn_scale = 1.0 / (head_size**0.5)

        total_len = k.shape[2]
        if window_size < 0 and past_seq_len == 0:
            # Square causal attention: SDPA's built-in mask is exact here.
            attn_mask = None
            is_causal = seq_len > 1
        else:
            # With a past cache (or a sliding window) the causal diagonal is
            # offset by past_seq_len, so build the boolean mask explicitly.
            q_pos = torch.arange(
                past_seq_len, past_seq_len + seq_len, device=q.device
            ).unsqueeze(-1)
            k_pos = torch.arange(total_len, device=q.device).unsqueeze(0)
            attn_mask = k_pos <= q_pos
            if window_size >= 0:
                attn_mask = attn_mask & (k_pos >= q_pos - window_size)
            is_causal = False

        output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, scale=attn_scale, is_causal=is_causal
        )

        # [batch, num_heads, seq_len, head_size] -> [batch, seq_len, hidden]
        output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)

        return output, present_k, present_v

    # The call yields a (output, present_key, present_value) tuple.
    result = builder.call_function(
        _group_query_attention,
        args=(
            query,
            key,
            value,
            past_key,
            past_value,
            seqlens_k,
            total_seq_len,
            cos_cache,
            sin_cache,
            num_heads,
            kv_num_heads,
            scale,
            local_window_size,
            do_rotary,
            rotary_interleaved,
        ),
    )

    return result
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
# =============================================================================
|
|
513
|
+
# Rotary Embedding (com.microsoft domain)
|
|
514
|
+
# =============================================================================
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
@register("RotaryEmbedding", domain="com.microsoft")
def rotary_embedding_msft(
    builder: "GraphBuilder", node: onnx.NodeProto
) -> torch.fx.Node:
    """Rotary Position Embedding (RoPE) operator.

    Applies rotary position embeddings to the input tensor. The positions are
    represented as rotation matrices that are multiplied to query and key
    before the inner product of query and key is taken.

    Microsoft domain version (com.microsoft).

    Inputs:
        - input: 3D tensor with shape (batch_size, sequence_length, hidden_size)
          or 4D with shape (batch_size, num_heads, sequence_length, head_size)
        - position_ids: 1D tensor with shape (1) (start offset), a 1D tensor of
          explicit positions, or 2D tensor with shape (batch_size or 1,
          sequence_length); 2D ids are applied per batch row
        - cos_cache: 2D tensor with shape (max_sequence_length, head_size / 2)
          or (max_sequence_length, rotary_embedding_dim / 2)
        - sin_cache: 2D tensor with shape (max_sequence_length, head_size / 2)
          or (max_sequence_length, rotary_embedding_dim / 2)

    Attributes:
        - interleaved: Indicates whether the input has real and imaginary parts
          interleaved. Default is 0 (False).
        - num_heads: Number of attention heads. Default is 0 (inferred from
          the cos_cache width for 3D inputs).
        - rotary_embedding_dim: Rotary embedding dimension. Default is 0
          (use 2 * cos_cache width).
        - scale: Custom scale. Default is 1.0.

    Outputs:
        - output: tensor with same shape as input.
    """
    # Get inputs
    input_tensor = builder.get_value(node.input[0])
    position_ids = builder.get_value(node.input[1])
    cos_cache = builder.get_value(node.input[2])
    sin_cache = builder.get_value(node.input[3])

    # Get attributes
    interleaved = get_attribute(node, "interleaved", 0)
    num_heads = get_attribute(node, "num_heads", 0)
    rotary_embedding_dim = get_attribute(node, "rotary_embedding_dim", 0)
    scale = get_attribute(node, "scale", 1.0)

    def _rotary_embedding(
        x: torch.Tensor,
        pos_ids: torch.Tensor,
        cos_cache: torch.Tensor,
        sin_cache: torch.Tensor,
        interleaved: int,
        num_heads: int,
        rotary_dim: int,
        scale: float,
    ) -> torch.Tensor:
        """Apply rotary position embeddings."""
        original_shape = x.shape
        is_3d = x.dim() == 3

        if is_3d:
            # Input is (batch_size, seq_len, hidden_size)
            batch_size, seq_len, hidden_size = x.shape
            if num_heads > 0:
                head_size = hidden_size // num_heads
                actual_num_heads = num_heads
            else:
                # Infer head_size from the cache width (one entry per pair).
                head_size = cos_cache.shape[-1] * 2
                actual_num_heads = hidden_size // head_size
            # Reshape to (batch, num_heads, seq, head_size)
            x = x.reshape(batch_size, seq_len, actual_num_heads, head_size).transpose(
                1, 2
            )
        else:
            # Input is (batch_size, num_heads, seq_len, head_size)
            seq_len = x.shape[2]

        if pos_ids.dim() == 1 and pos_ids.numel() == 1:
            # Single start offset - generate consecutive positions.
            start_pos = int(pos_ids.item())
            positions = torch.arange(
                start_pos, start_pos + seq_len, device=x.device, dtype=torch.long
            )
        else:
            positions = pos_ids.to(torch.long)

        # (seq, pairs) for 1D positions, (batch|1, seq, pairs) for 2D.
        cos = cos_cache[positions]
        sin = sin_cache[positions]
        if cos.dim() == 2:
            cos = cos.unsqueeze(0)
            sin = sin.unsqueeze(0)
        # (batch|1, 1, seq, pairs): broadcast over heads, per-row positions.
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        # Determine rotary dimension and trim cos/sin to the rotated pairs.
        rot_dim = rotary_dim if rotary_dim > 0 else cos.shape[-1] * 2
        half_dim = rot_dim // 2
        cos = cos[..., :half_dim]
        sin = sin[..., :half_dim]

        # NOTE(review): the input is multiplied by `scale` before rotation;
        # confirm this matches the reference kernel's use of `scale`.
        if scale != 1.0:
            x = x * scale

        # Split into rotary and pass-through parts
        x_rot = x[..., :rot_dim]
        x_pass = x[..., rot_dim:] if rot_dim < x.shape[-1] else None

        if interleaved:
            # Interleaved pairs (x[2i], x[2i+1]) rotated by angle i.
            x1 = x_rot[..., ::2]
            x2 = x_rot[..., 1::2]
            x_rot_new = torch.stack(
                [x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1
            ).flatten(-2)
        else:
            # Half-split pairs (x[i], x[i + half_dim]) rotated by angle i.
            x1 = x_rot[..., :half_dim]
            x2 = x_rot[..., half_dim:rot_dim]
            x_rot_new = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

        x_out = torch.cat([x_rot_new, x_pass], dim=-1) if x_pass is not None else x_rot_new

        if is_3d:
            # Back from (batch, num_heads, seq, head_size) to (batch, seq, hidden)
            x_out = x_out.transpose(1, 2).reshape(original_shape)

        return x_out

    return builder.call_function(
        _rotary_embedding,
        args=(
            input_tensor,
            position_ids,
            cos_cache,
            sin_cache,
            interleaved,
            num_heads,
            rotary_embedding_dim,
            scale,
        ),
    )
|