onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1055 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Attention and Transformer related operators (standard ONNX domain).
3
+
4
+ Microsoft domain operators are in attention_msft.py.
5
+ """
6
+
7
+ from typing import TYPE_CHECKING
8
+
9
+ import onnx
10
+ import torch
11
+
12
+ from ..op_registry import register
13
+ from ..utils.attributes import get_attribute
14
+ from ..utils.op_helpers import get_optional_input
15
+
16
+ if TYPE_CHECKING:
17
+ from ..graph_builder import GraphBuilder
18
+
19
+
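+ # Every handler below follows the same pattern: read node inputs with
+ # builder.get_value / get_optional_input, read attributes with get_attribute,
+ # and return builder.call_function(<plain Python callable>, args=...).
+ # A minimal sketch of that pattern ("Gelu" is only an illustrative stand-in
+ # and is not registered by this module):
+ #
+ #     @register("Gelu")
+ #     def gelu(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+ #         x = builder.get_value(node.input[0])
+ #         return builder.call_function(torch.nn.functional.gelu, args=(x,))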
20
+ # =============================================================================
21
+ # Embedding and LayerNorm variants
22
+ # =============================================================================
23
+
24
+
25
+ @register("SkipLayerNormalization")
26
+ def skip_layer_normalization(
27
+ builder: "GraphBuilder", node: onnx.NodeProto
28
+ ) -> torch.fx.Node:
29
+ """Skip connection + LayerNorm (common in transformers)."""
30
+ x = builder.get_value(node.input[0])
31
+ skip = builder.get_value(node.input[1])
32
+ gamma = builder.get_value(node.input[2])
33
+ beta = get_optional_input(builder, node, 3)
34
+ bias = get_optional_input(builder, node, 4)
35
+
36
+ epsilon = get_attribute(node, "epsilon", 1e-5)
37
+
38
+ def _skip_layer_norm(
39
+ inp: torch.Tensor,
40
+ sk: torch.Tensor,
41
+ g: torch.Tensor,
42
+ b: torch.Tensor | None,
43
+ bi: torch.Tensor | None,
44
+ eps: float,
45
+ ) -> torch.Tensor:
46
+ hidden = inp + sk
47
+ if bi is not None:
48
+ hidden = hidden + bi
49
+ return torch.nn.functional.layer_norm(
50
+ hidden, hidden.shape[-1:], weight=g, bias=b, eps=eps
51
+ )
52
+
53
+ return builder.call_function(
54
+ _skip_layer_norm, args=(x, skip, gamma, beta, bias, epsilon)
55
+ )
56
+
57
+
58
+ @register("EmbedLayerNormalization")
59
+ def embed_layer_normalization(
60
+ builder: "GraphBuilder", node: onnx.NodeProto
61
+ ) -> torch.fx.Node:
62
+ """Embedding + LayerNorm (common in BERT-like models)."""
63
+ input_ids = builder.get_value(node.input[0])
64
+ segment_ids = get_optional_input(builder, node, 1)
65
+ word_embedding = builder.get_value(node.input[2])
66
+ position_embedding = builder.get_value(node.input[3])
67
+ segment_embedding = get_optional_input(builder, node, 4)
68
+ gamma = get_optional_input(builder, node, 5)
69
+ beta = get_optional_input(builder, node, 6)
70
+
71
+ epsilon = get_attribute(node, "epsilon", 1e-5)
72
+
73
+ def _embed_layer_norm(
74
+ ids: torch.Tensor,
75
+ seg_ids: torch.Tensor | None,
76
+ word_emb: torch.Tensor,
77
+ pos_emb: torch.Tensor,
78
+ seg_emb: torch.Tensor | None,
79
+ g: torch.Tensor | None,
80
+ b: torch.Tensor | None,
81
+ eps: float,
82
+ ) -> torch.Tensor:
83
+ # Word embedding lookup
84
+ word_embed = torch.nn.functional.embedding(ids, word_emb)
85
+
86
+ # Position embedding (assume sequential positions)
87
+ seq_len = ids.shape[1]
88
+ pos_embed = pos_emb[:seq_len].unsqueeze(0).expand(ids.shape[0], -1, -1)
89
+
90
+ hidden = word_embed + pos_embed
91
+
92
+ # Add segment embedding if present
93
+ if seg_emb is not None and seg_ids is not None:
94
+ seg_embed = torch.nn.functional.embedding(seg_ids, seg_emb)
95
+ hidden = hidden + seg_embed
96
+
97
+ # Layer normalization
98
+ if g is not None:
99
+ hidden = torch.nn.functional.layer_norm(
100
+ hidden, hidden.shape[-1:], weight=g, bias=b, eps=eps
101
+ )
102
+
103
+ return hidden
104
+
105
+ return builder.call_function(
106
+ _embed_layer_norm,
107
+ args=(
108
+ input_ids,
109
+ segment_ids,
110
+ word_embedding,
111
+ position_embedding,
112
+ segment_embedding,
113
+ gamma,
114
+ beta,
115
+ epsilon,
116
+ ),
117
+ )
118
+
119
+
120
+ # =============================================================================
121
+ # Attention operator (ONNX standard domain, since opset 24)
122
+ # =============================================================================
123
+
124
+
125
+ @register("Attention")
126
+ def attention(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
127
+ """ONNX Attention operator (standard domain, since opset 24).
128
+
129
+ Inputs:
130
+ Q: Query tensor
131
+ - 4D: (batch_size, q_num_heads, q_sequence_length, head_size)
132
+ - 3D: (batch_size, q_sequence_length, q_num_heads * head_size)
133
+ K: Key tensor
134
+ - 4D: (batch_size, kv_num_heads, kv_sequence_length, head_size)
135
+ - 3D: (batch_size, kv_sequence_length, kv_num_heads * head_size)
136
+ V: Value tensor (same format as K)
137
+ attn_mask (optional): Attention mask, broadcastable to
138
+ (batch_size, q_num_heads, q_sequence_length, total_sequence_length)
139
+ past_key (optional): Past key cache
140
+ past_value (optional): Past value cache
141
+ nonpad_kv_seqlen (optional): Non-padding KV sequence lengths
142
+
143
+ Attributes:
144
+ is_causal: If 1, use causal (lower triangular) mask
145
+ scale: Scaling factor for Q*K^T (default: 1/sqrt(head_size))
146
+ softcap: Softcap value applied to the attention scores (0.0 disables it)
147
+ q_num_heads: Number of query heads (required for 3D inputs)
148
+ kv_num_heads: Number of key/value heads (required for 3D inputs)
149
+ qk_matmul_output_mode: Which intermediate to return as qk_matmul_output
+ (0=raw QK scores, 1=after mask, 2=after softcap, 3=after softmax)
150
+
151
+ Outputs:
152
+ Y: Output tensor (same format as Q)
153
+ present_key (optional): Updated key cache
154
+ present_value (optional): Updated value cache
155
+ qk_matmul_output (optional): QK matmul output
156
+ """
157
+ # Get inputs
158
+ query = builder.get_value(node.input[0])
159
+ key = builder.get_value(node.input[1])
160
+ value = builder.get_value(node.input[2])
161
+
162
+ attn_mask = get_optional_input(builder, node, 3)
163
+ past_key = get_optional_input(builder, node, 4)
164
+ past_value = get_optional_input(builder, node, 5)
165
+ nonpad_kv_seqlen = get_optional_input(builder, node, 6)
166
+
167
+ # Get attributes
168
+ is_causal = get_attribute(node, "is_causal", 0)
169
+ scale = get_attribute(node, "scale", None)
170
+ softcap = get_attribute(node, "softcap", 0.0)
171
+ q_num_heads = get_attribute(node, "q_num_heads", None)
172
+ kv_num_heads = get_attribute(node, "kv_num_heads", None)
173
+ qk_matmul_output_mode = get_attribute(node, "qk_matmul_output_mode", 0)
174
+
175
+ # Determine which outputs are needed
176
+ # Output positions: 0=Y, 1=present_key, 2=present_value, 3=qk_matmul_output
177
+ num_outputs = len(node.output)
178
+ has_present_key = num_outputs > 1 and bool(node.output[1])
179
+ has_present_value = num_outputs > 2 and bool(node.output[2])
180
+ has_qk_matmul_output = num_outputs > 3 and bool(node.output[3])
181
+
182
+ # Check if we need multiple outputs (even if some are empty)
183
+ _needs_multiple_outputs = num_outputs > 1
184
+
185
+ # Use manual attention computation when:
186
+ # 1. We need qk_matmul_output
187
+ # 2. We have softcap
188
+ # 3. We have both is_causal and attn_mask
189
+ needs_manual_attention = (
190
+ has_qk_matmul_output or softcap != 0.0 or (is_causal and attn_mask is not None)
191
+ )
192
+
193
+ if needs_manual_attention:
194
+ # Manual attention implementation for advanced features
195
+
196
+ def _attention_manual(
197
+ q: torch.Tensor,
198
+ k: torch.Tensor,
199
+ v: torch.Tensor,
200
+ mask: torch.Tensor | None,
201
+ past_k: torch.Tensor | None,
202
+ past_v: torch.Tensor | None,
203
+ is_causal: int,
204
+ scale: float | None,
205
+ softcap: float,
206
+ q_num_heads: int | None,
207
+ kv_num_heads: int | None,
208
+ num_outputs: int,
209
+ qk_matmul_output_mode: int,
210
+ ) -> (
211
+ torch.Tensor
212
+ | tuple[torch.Tensor, torch.Tensor]
213
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
214
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
215
+ ):
216
+ is_3d = q.dim() == 3
217
+
218
+ if is_3d:
219
+ batch_size = q.shape[0]
220
+ q_seq_len = q.shape[1]
221
+ kv_seq_len = k.shape[1]
222
+ n_q_heads = q_num_heads if q_num_heads is not None else 1
223
+ n_kv_heads = kv_num_heads if kv_num_heads is not None else n_q_heads
224
+ q_head_size = q.shape[2] // n_q_heads
225
+ kv_head_size = k.shape[2] // n_kv_heads
226
+ v_head_size = v.shape[2] // n_kv_heads
227
+ q_4d = q.view(batch_size, q_seq_len, n_q_heads, q_head_size).transpose(
228
+ 1, 2
229
+ )
230
+ k_4d = k.view(
231
+ batch_size, kv_seq_len, n_kv_heads, kv_head_size
232
+ ).transpose(1, 2)
233
+ v_4d = v.view(
234
+ batch_size, kv_seq_len, n_kv_heads, v_head_size
235
+ ).transpose(1, 2)
236
+ else:
237
+ q_4d = q
238
+ k_4d = k
239
+ v_4d = v
240
+ batch_size = q.shape[0]
241
+ n_q_heads = q.shape[1]
242
+ n_kv_heads = k.shape[1]
243
+ q_seq_len = q.shape[2]
244
+ q_head_size = q.shape[3]
245
+
246
+ # Handle past key/value (KV cache)
247
+ past_seq_len = 0
248
+ if past_k is not None and past_v is not None:
249
+ past_seq_len = past_k.shape[2]
250
+ k_4d = torch.cat([past_k, k_4d], dim=2)
251
+ v_4d = torch.cat([past_v, v_4d], dim=2)
252
+
253
+ # Save present_key and present_value before GQA expansion (if needed for output)
254
+ present_k = k_4d if num_outputs > 1 else None
255
+ present_v = v_4d if num_outputs > 2 else None
256
+
257
+ # Handle GQA: expand KV heads to match Q heads
258
+ if n_kv_heads != n_q_heads:
259
+ n_rep = n_q_heads // n_kv_heads
260
+ k_4d = k_4d.repeat_interleave(n_rep, dim=1)
261
+ v_4d = v_4d.repeat_interleave(n_rep, dim=1)
262
+
263
+ # Compute attention scale
264
+ if scale is None:
265
+ scale_val = 1.0 / (q_head_size**0.5)
266
+ else:
267
+ scale_val = scale
268
+
269
+ # Q @ K^T with scaling
270
+ # (batch, heads, q_seq, head) @ (batch, heads, head, kv_seq)
271
+ # -> (batch, heads, q_seq, kv_seq)
272
+ qk = torch.matmul(q_4d, k_4d.transpose(-2, -1)) * scale_val
273
+
274
+ # Save QK matmul output before applying mask/softmax (if needed for output)
275
+ # Mode 0: raw QK matmul output
276
+ qk_output = None
277
+ if num_outputs > 3 and qk_matmul_output_mode == 0:
278
+ qk_output = qk.clone()
279
+
280
+ # Build combined attention mask (applied BEFORE softcap per ONNX spec)
281
+ # The ONNX approach: create causal mask first, add attn_mask to it,
282
+ # then add combined mask to QK scores
283
+ kv_seq = k_4d.shape[2]
284
+ combined_mask = None
285
+
286
+ # Create causal mask if is_causal=1
287
+ # ONNX uses: Less(q_pos + past_len, k_pos) to determine masked positions
288
+ # Positions with k_pos <= q_pos + past_len may attend; later key positions are masked
289
+ if is_causal:
290
+ # Create causal mask: (q_pos + past_seq_len) < k_pos means masked
291
+ row = (
292
+ torch.arange(q_seq_len, device=q.device).view(-1, 1) + past_seq_len
293
+ )
294
+ col = torch.arange(kv_seq, device=q.device).view(1, -1)
295
+ causal_bool = row < col # True where masked
296
+ causal_mask = (
297
+ torch.where(causal_bool, float("-inf"), 0.0)
298
+ .to(q.dtype)
299
+ .unsqueeze(0)
300
+ .unsqueeze(0)
301
+ )
302
+ combined_mask = causal_mask
303
+
304
+ # Add attention mask to causal mask (or use just attn_mask if no causal)
306
+ if mask is not None:
+ # A boolean mask marks attendable positions with True; convert it to an
+ # additive float mask so it can be summed with the causal mask and scores
+ if mask.dtype == torch.bool:
+ mask = torch.where(mask, 0.0, float("-inf")).to(q.dtype)
307
+ if combined_mask is not None:
308
+ combined_mask = mask + combined_mask # attn_mask + causal_mask
309
+ else:
310
+ combined_mask = mask
310
+
311
+ # Add combined mask to QK scores
312
+ if combined_mask is not None:
313
+ qk = qk + combined_mask
314
+
315
+ # Mode 1: after attention mask addition (including causal)
316
+ if num_outputs > 3 and qk_matmul_output_mode == 1:
317
+ qk_output = qk.clone()
318
+
319
+ # Apply softcap if specified (after mask, before softmax per ONNX spec)
320
+ if softcap != 0.0:
321
+ qk = softcap * torch.tanh(qk / softcap)
322
+
323
+ # Mode 2: after softcap
324
+ if num_outputs > 3 and qk_matmul_output_mode == 2:
325
+ qk_output = qk.clone()
326
+
327
+ # Softmax
328
+ attn_weights = torch.nn.functional.softmax(qk, dim=-1)
329
+
330
+ # Mode 3: after softmax
331
+ if num_outputs > 3 and qk_matmul_output_mode == 3:
332
+ qk_output = attn_weights.clone()
333
+
334
+ # Attention @ V
335
+ output = torch.matmul(attn_weights, v_4d)
336
+
337
+ if is_3d:
338
+ output = (
339
+ output.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
340
+ )
341
+
342
+ # Return based on num_outputs (must match exactly)
343
+ # Output positions: 0=Y, 1=present_key, 2=present_value, 3=qk_matmul_output
344
+ if num_outputs == 1:
345
+ return output
346
+ elif num_outputs == 2:
347
+ return (output, present_k)
348
+ elif num_outputs == 3:
349
+ return (output, present_k, present_v)
350
+ else: # num_outputs == 4
351
+ return (output, present_k, present_v, qk_output)
352
+
353
+ return builder.call_function(
354
+ _attention_manual,
355
+ args=(
356
+ query,
357
+ key,
358
+ value,
359
+ attn_mask,
360
+ past_key,
361
+ past_value,
362
+ is_causal,
363
+ scale,
364
+ softcap,
365
+ q_num_heads,
366
+ kv_num_heads,
367
+ num_outputs,
368
+ qk_matmul_output_mode,
369
+ ),
370
+ )
371
+ elif has_present_key or has_present_value:
372
+ # Use SDPA but also return present_key/present_value
373
+
374
+ def _attention_with_cache(
375
+ q: torch.Tensor,
376
+ k: torch.Tensor,
377
+ v: torch.Tensor,
378
+ mask: torch.Tensor | None,
379
+ past_k: torch.Tensor | None,
380
+ past_v: torch.Tensor | None,
381
+ is_causal: int,
382
+ scale: float | None,
383
+ q_num_heads: int | None,
384
+ kv_num_heads: int | None,
385
+ has_present_key: bool,
386
+ has_present_value: bool,
387
+ ) -> (
388
+ torch.Tensor
389
+ | tuple[torch.Tensor, torch.Tensor]
390
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
391
+ ):
392
+ is_3d = q.dim() == 3
393
+
394
+ if is_3d:
395
+ batch_size = q.shape[0]
396
+ q_seq_len = q.shape[1]
397
+ kv_seq_len = k.shape[1]
398
+ n_q_heads = q_num_heads if q_num_heads is not None else 1
399
+ n_kv_heads = kv_num_heads if kv_num_heads is not None else n_q_heads
400
+ q_head_size = q.shape[2] // n_q_heads
401
+ kv_head_size = k.shape[2] // n_kv_heads
402
+ v_head_size = v.shape[2] // n_kv_heads
403
+ q_4d = q.view(batch_size, q_seq_len, n_q_heads, q_head_size).transpose(
404
+ 1, 2
405
+ )
406
+ k_4d = k.view(
407
+ batch_size, kv_seq_len, n_kv_heads, kv_head_size
408
+ ).transpose(1, 2)
409
+ v_4d = v.view(
410
+ batch_size, kv_seq_len, n_kv_heads, v_head_size
411
+ ).transpose(1, 2)
412
+ else:
413
+ q_4d = q
414
+ k_4d = k
415
+ v_4d = v
416
+ batch_size = q.shape[0]
417
+ n_q_heads = q.shape[1]
418
+ n_kv_heads = k.shape[1]
419
+ q_seq_len = q.shape[2]
420
+
421
+ # Handle past key/value (KV cache)
422
+ if past_k is not None and past_v is not None:
423
+ k_4d = torch.cat([past_k, k_4d], dim=2)
424
+ v_4d = torch.cat([past_v, v_4d], dim=2)
425
+
426
+ # Save present_key and present_value before GQA expansion
427
+ present_k = k_4d if has_present_key else None
428
+ present_v = v_4d if has_present_value else None
429
+
430
+ # Handle GQA: expand KV heads to match Q heads
431
+ if n_kv_heads != n_q_heads:
432
+ n_rep = n_q_heads // n_kv_heads
433
+ k_4d = k_4d.repeat_interleave(n_rep, dim=1)
434
+ v_4d = v_4d.repeat_interleave(n_rep, dim=1)
435
+
436
+ # Call SDPA
437
+ if mask is not None:
438
+ output = torch.nn.functional.scaled_dot_product_attention(
439
+ q_4d, k_4d, v_4d, attn_mask=mask, is_causal=False, scale=scale
440
+ )
441
+ elif is_causal:
442
+ output = torch.nn.functional.scaled_dot_product_attention(
443
+ q_4d, k_4d, v_4d, is_causal=True, scale=scale
444
+ )
445
+ else:
446
+ output = torch.nn.functional.scaled_dot_product_attention(
447
+ q_4d, k_4d, v_4d, scale=scale
448
+ )
449
+
450
+ if is_3d:
451
+ output = (
452
+ output.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
453
+ )
454
+
455
+ # Build return tuple
456
+ results = [output]
457
+ if has_present_key:
458
+ results.append(present_k)
459
+ if has_present_value:
460
+ results.append(present_v)
461
+
462
+ if len(results) == 1:
463
+ return results[0]
464
+ return tuple(results)
465
+
466
+ return builder.call_function(
467
+ _attention_with_cache,
468
+ args=(
469
+ query,
470
+ key,
471
+ value,
472
+ attn_mask,
473
+ past_key,
474
+ past_value,
475
+ is_causal,
476
+ scale,
477
+ q_num_heads,
478
+ kv_num_heads,
479
+ has_present_key,
480
+ has_present_value,
481
+ ),
482
+ )
483
+ else:
484
+ # Simple case: just use SDPA
485
+
486
+ def _attention_standard(
487
+ q: torch.Tensor,
488
+ k: torch.Tensor,
489
+ v: torch.Tensor,
490
+ mask: torch.Tensor | None,
491
+ past_k: torch.Tensor | None,
492
+ past_v: torch.Tensor | None,
493
+ is_causal: int,
494
+ scale: float | None,
495
+ q_num_heads: int | None,
496
+ kv_num_heads: int | None,
497
+ nonpad_kv_seqlen: torch.Tensor | None,
498
+ ) -> torch.Tensor:
499
+ is_3d = q.dim() == 3
500
+
501
+ if is_3d:
502
+ batch_size = q.shape[0]
503
+ q_seq_len = q.shape[1]
504
+ kv_seq_len = k.shape[1]
505
+ n_q_heads = q_num_heads if q_num_heads is not None else 1
506
+ n_kv_heads = kv_num_heads if kv_num_heads is not None else n_q_heads
507
+ q_head_size = q.shape[2] // n_q_heads
508
+ kv_head_size = k.shape[2] // n_kv_heads
509
+ v_head_size = v.shape[2] // n_kv_heads
510
+ q_4d = q.view(batch_size, q_seq_len, n_q_heads, q_head_size).transpose(
511
+ 1, 2
512
+ )
513
+ k_4d = k.view(
514
+ batch_size, kv_seq_len, n_kv_heads, kv_head_size
515
+ ).transpose(1, 2)
516
+ v_4d = v.view(
517
+ batch_size, kv_seq_len, n_kv_heads, v_head_size
518
+ ).transpose(1, 2)
519
+ else:
520
+ q_4d = q
521
+ k_4d = k
522
+ v_4d = v
523
+ batch_size = q.shape[0]
524
+ n_q_heads = q.shape[1]
525
+ n_kv_heads = k.shape[1]
526
+ q_seq_len = q.shape[2]
527
+
528
+ # Handle past key/value (KV cache)
529
+ if past_k is not None and past_v is not None:
530
+ k_4d = torch.cat([past_k, k_4d], dim=2)
531
+ v_4d = torch.cat([past_v, v_4d], dim=2)
532
+
533
+ # Handle GQA: expand KV heads to match Q heads
534
+ if n_kv_heads != n_q_heads:
535
+ n_rep = n_q_heads // n_kv_heads
536
+ k_4d = k_4d.repeat_interleave(n_rep, dim=1)
537
+ v_4d = v_4d.repeat_interleave(n_rep, dim=1)
538
+
539
+ # Handle mask padding if mask is shorter than KV sequence
540
+ # Per ONNX spec: "The last dimension can also be shorter than
541
+ # total_sequence_length and will be padded with negative infinity"
542
+ kv_seq_len_actual = k_4d.shape[2]
543
+
544
+ if mask is not None:
545
+ # Pad mask if shorter than KV sequence (BEFORE adding nonpad mask)
546
+ if mask.shape[-1] < kv_seq_len_actual:
547
+ pad_size = kv_seq_len_actual - mask.shape[-1]
548
+ # For bool masks, pad with False; for float, pad with -inf
549
+ if mask.dtype == torch.bool:
550
+ mask = torch.nn.functional.pad(mask, (0, pad_size), value=False)
551
+ else:
552
+ mask = torch.nn.functional.pad(
553
+ mask, (0, pad_size), value=float("-inf")
554
+ )
555
+
556
+ # Expand mask for GQA if needed (mask might have n_kv_heads dimension)
557
+ if (
558
+ mask.dim() == 4
559
+ and mask.shape[1] == n_kv_heads
560
+ and n_kv_heads != n_q_heads
561
+ ):
562
+ n_rep = n_q_heads // n_kv_heads
563
+ mask = mask.repeat_interleave(n_rep, dim=1)
564
+
565
+ # Handle nonpad_kv_seqlen: create a padding mask for each batch
566
+ # that masks out positions >= nonpad_kv_seqlen[batch]
567
+ if nonpad_kv_seqlen is not None:
568
+ # Create a position index tensor: (1, 1, 1, kv_seq_len)
569
+ positions = torch.arange(kv_seq_len_actual, device=q.device).view(
570
+ 1, 1, 1, -1
571
+ )
572
+ # Create mask: True where position < nonpad_kv_seqlen[batch]
573
+ # nonpad_kv_seqlen: (batch_size,) -> (batch_size, 1, 1, 1)
574
+ valid_mask = positions < nonpad_kv_seqlen.view(-1, 1, 1, 1)
575
+ # Convert to additive mask: 0 for valid, -inf for padding
576
+ pad_mask = torch.where(valid_mask, 0.0, float("-inf")).to(q.dtype)
577
+ # Combine with existing mask (converting a boolean mask to additive form)
578
+ if mask is not None:
+ if mask.dtype == torch.bool:
+ mask = torch.where(mask, 0.0, float("-inf")).to(q.dtype)
579
+ mask = mask + pad_mask
580
+ else:
581
+ mask = pad_mask
582
+
583
+ # Call SDPA
584
+ if mask is not None:
585
+ output = torch.nn.functional.scaled_dot_product_attention(
586
+ q_4d, k_4d, v_4d, attn_mask=mask, is_causal=False, scale=scale
587
+ )
588
+ elif is_causal:
589
+ output = torch.nn.functional.scaled_dot_product_attention(
590
+ q_4d, k_4d, v_4d, is_causal=True, scale=scale
591
+ )
592
+ else:
593
+ output = torch.nn.functional.scaled_dot_product_attention(
594
+ q_4d, k_4d, v_4d, scale=scale
595
+ )
596
+
597
+ if is_3d:
598
+ output = (
599
+ output.transpose(1, 2).contiguous().view(batch_size, q_seq_len, -1)
600
+ )
601
+
602
+ return output
603
+
604
+ return builder.call_function(
605
+ _attention_standard,
606
+ args=(
607
+ query,
608
+ key,
609
+ value,
610
+ attn_mask,
611
+ past_key,
612
+ past_value,
613
+ is_causal,
614
+ scale,
615
+ q_num_heads,
616
+ kv_num_heads,
617
+ nonpad_kv_seqlen,
618
+ ),
619
+ )
620
+
621
+
622
+ # =============================================================================
623
+ # Simplified LayerNormalization variants (ONNX Runtime contrib ops)
624
+ # =============================================================================
625
+
626
+
627
+ @register("SimplifiedLayerNormalization")
628
+ def simplified_layer_normalization(
629
+ builder: "GraphBuilder", node: onnx.NodeProto
630
+ ) -> torch.fx.Node:
631
+ """Simplified Layer Normalization (RMSNorm).
632
+
633
+ This is LayerNormalization without bias and mean subtraction.
634
+ Formula: output = x / sqrt(mean(x^2) + epsilon) * scale
635
+ """
636
+ x = builder.get_value(node.input[0])
637
+ scale = builder.get_value(node.input[1])
638
+
639
+ axis = get_attribute(node, "axis", -1)
640
+ epsilon = get_attribute(node, "epsilon", 1e-5)
641
+
642
+ def _simplified_layer_norm(x, scale, axis, epsilon):
643
+ # Simplified LayerNorm (RMSNorm)
644
+ # output = x * rsqrt(mean(x^2) + epsilon) * scale
645
+ if axis < 0:
646
+ axis_pos = x.dim() + axis
647
+ else:
648
+ axis_pos = axis
649
+
650
+ # Keep dims for broadcasting
651
+ dims = list(range(axis_pos, x.dim()))
652
+
653
+ # Compute RMS: sqrt(mean(x^2))
654
+ variance = x.pow(2).mean(dim=dims, keepdim=True)
655
+ x_normalized = x * torch.rsqrt(variance + epsilon)
656
+
657
+ return x_normalized * scale
658
+
659
+ return builder.call_function(_simplified_layer_norm, args=(x, scale, axis, epsilon))
660
+
661
+
662
+ @register("SkipSimplifiedLayerNormalization")
663
+ def skip_simplified_layer_normalization(
664
+ builder: "GraphBuilder", node: onnx.NodeProto
665
+ ) -> torch.fx.Node:
666
+ """Skip connection + Simplified Layer Normalization (RMSNorm)."""
667
+ x = builder.get_value(node.input[0])
668
+ skip = builder.get_value(node.input[1])
669
+ scale = builder.get_value(node.input[2])
670
+ bias = get_optional_input(builder, node, 3)
671
+
672
+ epsilon = get_attribute(node, "epsilon", 1e-5)
673
+
674
+ def _skip_simplified_layer_norm(x, skip, scale, bias, epsilon):
675
+ # Add skip connection
676
+ hidden = x + skip
677
+ if bias is not None:
678
+ hidden = hidden + bias
679
+
680
+ # Simplified LayerNorm (RMSNorm)
681
+ variance = hidden.pow(2).mean(dim=-1, keepdim=True)
682
+ hidden_normalized = hidden * torch.rsqrt(variance + epsilon)
683
+
684
+ return hidden_normalized * scale
685
+
686
+ return builder.call_function(
687
+ _skip_simplified_layer_norm, args=(x, skip, scale, bias, epsilon)
688
+ )
689
+
690
+
691
+ @register("GroupQueryAttention")
692
+ def group_query_attention(
693
+ builder: "GraphBuilder", node: onnx.NodeProto
694
+ ) -> torch.fx.Node:
695
+ """Group Query Attention (GQA) - used in LLaMA, Mistral, etc.
696
+
697
+ Inputs:
698
+ - query: [batch, seq_len, num_heads * head_size]
699
+ - key: [batch, kv_seq_len, num_kv_heads * head_size]
700
+ - value: [batch, kv_seq_len, num_kv_heads * head_size]
701
+ - past_key (optional): [batch, num_kv_heads, past_seq_len, head_size]
702
+ - past_value (optional): [batch, num_kv_heads, past_seq_len, head_size]
703
+ - seqlens_k (optional): per-batch key sequence lengths (accepted but not used here)
704
+ - total_sequence_length (optional): total sequence length
705
+ - cos_cache (optional): [max_seq_len, head_size / 2] or [max_seq_len, head_size]
706
+ - sin_cache (optional): [max_seq_len, head_size / 2] or [max_seq_len, head_size]
707
+
708
+ Attributes:
709
+ - num_heads: number of attention heads
710
+ - kv_num_heads: number of key-value heads (for GQA)
711
+ - scale: scaling factor (default: 1/sqrt(head_size))
712
+ - local_window_size: for sliding window attention
713
+ - do_rotary: whether to apply rotary position embeddings
714
+ - rotary_interleaved: whether rotary is interleaved (GPT-NeoX style vs LLaMA)
715
+
716
+ Outputs:
717
+ - output: [batch, seq_len, num_heads * head_size]
718
+ - present_key: [batch, num_kv_heads, total_seq_len, head_size]
719
+ - present_value: [batch, num_kv_heads, total_seq_len, head_size]
720
+ """
721
+ # Get required inputs
722
+ query = builder.get_value(node.input[0])
723
+ key = builder.get_value(node.input[1])
724
+ value = builder.get_value(node.input[2])
725
+
726
+ # Get optional inputs
727
+ past_key = get_optional_input(builder, node, 3)
728
+ past_value = get_optional_input(builder, node, 4)
729
+ seqlens_k = get_optional_input(builder, node, 5)
730
+ total_seq_len = get_optional_input(builder, node, 6)
731
+ cos_cache = get_optional_input(builder, node, 7)
732
+ sin_cache = get_optional_input(builder, node, 8)
733
+
734
+ # Get attributes
735
+ num_heads = get_attribute(node, "num_heads", 1)
736
+ kv_num_heads = get_attribute(node, "kv_num_heads", num_heads)
737
+ scale = get_attribute(node, "scale", None)
738
+ local_window_size = get_attribute(node, "local_window_size", -1)
739
+ do_rotary = get_attribute(node, "do_rotary", 0)
740
+ rotary_interleaved = get_attribute(node, "rotary_interleaved", 0)
741
+
742
+ def _group_query_attention(
743
+ q: torch.Tensor,
744
+ k: torch.Tensor,
745
+ v: torch.Tensor,
746
+ past_k: torch.Tensor | None,
747
+ past_v: torch.Tensor | None,
748
+ seqlens_k: torch.Tensor | None,
749
+ total_seq_len: torch.Tensor | None,
750
+ cos_cache: torch.Tensor | None,
751
+ sin_cache: torch.Tensor | None,
752
+ n_heads: int,
753
+ n_kv_heads: int,
754
+ attn_scale: float | None,
755
+ window_size: int,
756
+ do_rotary: int,
757
+ rotary_interleaved: int,
758
+ ):
759
+ batch_size, seq_len, _ = q.shape
760
+ head_size = q.shape[-1] // n_heads
761
+ kv_head_size = k.shape[-1] // n_kv_heads
762
+
763
+ # Reshape Q, K, V to [batch, num_heads, seq_len, head_size]
764
+ q = q.view(batch_size, seq_len, n_heads, head_size).transpose(1, 2)
765
+ k = k.view(batch_size, -1, n_kv_heads, kv_head_size).transpose(1, 2)
766
+ v = v.view(batch_size, -1, n_kv_heads, kv_head_size).transpose(1, 2)
767
+
768
+ # Calculate position offset from past cache
769
+ past_seq_len = 0
770
+ if past_k is not None and past_k.numel() > 0:
771
+ past_seq_len = past_k.shape[2]
772
+
773
+ # Apply rotary position embeddings if enabled
774
+ if do_rotary and cos_cache is not None and sin_cache is not None:
775
+ # Get the position indices
776
+ positions = torch.arange(
777
+ past_seq_len, past_seq_len + seq_len, device=q.device
778
+ )
779
+
780
+ # Get cos/sin values for current positions
781
+ cos = cos_cache[positions] # [seq_len, rotary_dim]
782
+ sin = sin_cache[positions] # [seq_len, rotary_dim]
783
+
784
+ # Expand for batch and heads: [1, 1, seq_len, rotary_dim]
785
+ cos = cos.unsqueeze(0).unsqueeze(0)
786
+ sin = sin.unsqueeze(0).unsqueeze(0)
787
+
788
+ rotary_dim = cos.shape[-1]
789
+
790
+ if rotary_interleaved:
791
+ # GPT-NeoX style: [x0, x1, x2, x3, ...] -> rotate pairs
792
+ q_rot = q[..., :rotary_dim]
793
+ q_pass = q[..., rotary_dim:]
794
+ k_rot = k[..., :rotary_dim]
795
+ k_pass = k[..., rotary_dim:]
796
+
797
+ # Apply rotation
798
+ q1, q2 = q_rot[..., ::2], q_rot[..., 1::2]
799
+ k1, k2 = k_rot[..., ::2], k_rot[..., 1::2]
800
+
801
+ cos_half = cos[..., ::2]
802
+ sin_half = sin[..., ::2]
803
+
804
+ q_rot_new = torch.stack(
805
+ [q1 * cos_half - q2 * sin_half, q1 * sin_half + q2 * cos_half],
806
+ dim=-1,
807
+ ).flatten(-2)
808
+ k_rot_new = torch.stack(
809
+ [k1 * cos_half - k2 * sin_half, k1 * sin_half + k2 * cos_half],
810
+ dim=-1,
811
+ ).flatten(-2)
812
+
813
+ q = torch.cat([q_rot_new, q_pass], dim=-1)
814
+ k = torch.cat([k_rot_new, k_pass], dim=-1)
815
+ else:
816
+ # LLaMA style: cos/sin are [seq, rotary_dim]
817
+ # rotary_dim is half the head_size in this format
818
+ # q/k first rotary_dim*2 elements are rotated:
819
+ # q1 = q[..., :rotary_dim], q2 = q[..., rotary_dim:rotary_dim*2]
820
+ # result = (q1*cos - q2*sin, q1*sin + q2*cos)
821
+
822
+ rotary_full = rotary_dim * 2 # total dims that get rotated
823
+ q_rot = q[..., :rotary_full]
824
+ q_pass = q[..., rotary_full:]
825
+ k_rot = k[..., :rotary_full]
826
+ k_pass = k[..., rotary_full:]
827
+
828
+ # Split into first half and second half
829
+ q1, q2 = q_rot[..., :rotary_dim], q_rot[..., rotary_dim:rotary_full]
830
+ k1, k2 = k_rot[..., :rotary_dim], k_rot[..., rotary_dim:rotary_full]
831
+
832
+ # cos/sin are already in the right shape [1, 1, seq_len, rotary_dim]
833
+ q_rot_new = torch.cat(
834
+ [q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1
835
+ )
836
+ k_rot_new = torch.cat(
837
+ [k1 * cos - k2 * sin, k1 * sin + k2 * cos], dim=-1
838
+ )
839
+
840
+ q = torch.cat([q_rot_new, q_pass], dim=-1)
841
+ k = torch.cat([k_rot_new, k_pass], dim=-1)
842
+
843
+ # Handle past key-value cache
844
+ if past_k is not None and past_k.numel() > 0:
845
+ k = torch.cat([past_k, k], dim=2)
846
+ v = torch.cat([past_v, v], dim=2)
847
+
848
+ # Present key-value for caching
849
+ present_k = k
850
+ present_v = v
851
+
852
+ # Expand K, V for GQA (repeat for each head group)
853
+ if n_kv_heads < n_heads:
854
+ n_rep = n_heads // n_kv_heads
855
+ k = k.repeat_interleave(n_rep, dim=1)
856
+ v = v.repeat_interleave(n_rep, dim=1)
857
+
858
+ # Compute attention scale
859
+ if attn_scale is None:
860
+ attn_scale = 1.0 / (head_size**0.5)
861
+
862
+ # Use scaled_dot_product_attention
863
+ # Apply a causal mask only for a multi-token prefill without a KV cache;
864
+ # a single new token attending to the existing cache needs no mask
865
+ is_causal = seq_len > 1 and past_seq_len == 0
866
+ output = torch.nn.functional.scaled_dot_product_attention(
867
+ q, k, v, scale=attn_scale, is_causal=is_causal
868
+ )
869
+
870
+ # Reshape output: [batch, num_heads, seq_len, head_size] -> [batch, seq_len, hidden]
871
+ output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
872
+
873
+ return output, present_k, present_v
874
+
875
+ # Build the call
876
+ result = builder.call_function(
877
+ _group_query_attention,
878
+ args=(
879
+ query,
880
+ key,
881
+ value,
882
+ past_key,
883
+ past_value,
884
+ seqlens_k,
885
+ total_seq_len,
886
+ cos_cache,
887
+ sin_cache,
888
+ num_heads,
889
+ kv_num_heads,
890
+ scale,
891
+ local_window_size,
892
+ do_rotary,
893
+ rotary_interleaved,
894
+ ),
895
+ )
896
+
897
+ # Return tuple output
898
+ return result
899
+
900
+
901
+ # =============================================================================
902
+ # Rotary Embedding (ONNX standard domain, since opset 23)
903
+ # =============================================================================
904
+
905
+
906
+ @register("RotaryEmbedding", since_version=23)
907
+ def rotary_embedding_onnx(
908
+ builder: "GraphBuilder", node: onnx.NodeProto
909
+ ) -> torch.fx.Node:
910
+ """ONNX RotaryEmbedding operator (standard domain, since opset 23).
911
+
912
+ Applies rotary position embeddings (RoPE) to the input tensor based on
913
+ https://arxiv.org/pdf/2104.09864
914
+
915
+ Inputs:
916
+ - X: 4D tensor (batch_size, num_heads, sequence_length, head_size) or
917
+ 3D tensor (batch_size, sequence_length, hidden_size)
918
+ - cos_cache: 2D tensor (max_position_id+1, head_size/2) when position_ids provided,
919
+ or 3D tensor (batch_size, sequence_length, head_size/2) otherwise
920
+ - sin_cache: Same shape as cos_cache
921
+ - position_ids (optional): 2D tensor (batch_size, sequence_length)
922
+
923
+ Attributes:
924
+ - interleaved: Whether to use interleaved pattern. Default is 0 (False).
925
+ - num_heads: Number of attention heads (required for 3D input).
926
+ - rotary_embedding_dim: Partial rotary dimension. Default is 0 (full rotation).
927
+
928
+ Outputs:
929
+ - Y: Tensor with same shape as input.
930
+ """
931
+ # Get inputs
932
+ input_tensor = builder.get_value(node.input[0])
933
+ cos_cache = builder.get_value(node.input[1])
934
+ sin_cache = builder.get_value(node.input[2])
935
+ position_ids = get_optional_input(builder, node, 3)
936
+
937
+ # Get attributes
938
+ interleaved = get_attribute(node, "interleaved", 0)
939
+ num_heads = get_attribute(node, "num_heads", 0)
940
+ rotary_embedding_dim = get_attribute(node, "rotary_embedding_dim", 0)
941
+
942
+ def _rotary_embedding_onnx(
943
+ x: torch.Tensor,
944
+ cos_cache: torch.Tensor,
945
+ sin_cache: torch.Tensor,
946
+ position_ids: torch.Tensor | None,
947
+ interleaved: int,
948
+ num_heads: int,
949
+ rotary_dim: int,
950
+ ) -> torch.Tensor:
951
+ """Apply ONNX-standard rotary position embeddings."""
952
+ original_shape = x.shape
953
+ is_3d = x.dim() == 3
954
+
955
+ # First ensure input has shape [batch_size, seq_len, num_heads, head_size]
956
+ if x.dim() == 4:
957
+ # Input is (batch_size, num_heads, seq_len, head_size)
958
+ # Transpose to (batch_size, seq_len, num_heads, head_size)
959
+ x = x.transpose(1, 2)
960
+ batch_size, seq_len, n_heads, head_size = x.shape
961
+ else:
962
+ # Input is (batch_size, seq_len, hidden_size)
963
+ batch_size, seq_len, hidden_size = x.shape
964
+ assert num_heads != 0, "num_heads must be provided for 3D input"
965
+ head_size = hidden_size // num_heads
966
+ x = x.view(batch_size, seq_len, num_heads, head_size)
967
+ _n_heads = num_heads
968
+
969
+ # Determine rotary_embedding_dim
970
+ if rotary_dim == 0:
971
+ rot_dim = head_size
972
+ else:
973
+ rot_dim = rotary_dim
974
+
975
+ rotary_dim_half = rot_dim // 2
976
+
977
+ # Split into rotary and pass-through parts
978
+ x_rotate = x[..., :rot_dim]
979
+ x_not_rotate = x[..., rot_dim:] if rot_dim < head_size else None
980
+
981
+ # Retrieve sin and cos caches using position ids
982
+ if position_ids is not None:
983
+ # cos_cache/sin_cache shape: (max_pos+1, rotary_dim/2)
984
+ # position_ids shape: (batch_size, seq_len)
985
+ # Result shape: (batch_size, seq_len, rotary_dim/2)
986
+ cos = cos_cache[position_ids]
987
+ sin = sin_cache[position_ids]
988
+ else:
989
+ # cos_cache/sin_cache already have shape (batch_size, seq_len, rotary_dim/2)
990
+ cos = cos_cache
991
+ sin = sin_cache
992
+
993
+ # Validate cache dimensions
994
+ if cos.shape[-1] != rotary_dim_half:
995
+ raise ValueError(
996
+ f"Last dimension of cos cache ({cos.shape[-1]}) does not match "
997
+ f"rotary_embedding_dim/2 ({rotary_dim_half})."
998
+ )
999
+
1000
+ # Add head dimension: (batch_size, seq_len, 1, rotary_dim/2)
1001
+ cos = cos.unsqueeze(2)
1002
+ sin = sin.unsqueeze(2)
1003
+
1004
+ # Apply rotation based on interleaved pattern
1005
+ if interleaved:
1006
+ # Interleaved: x_rotate[..., 0::2] and x_rotate[..., 1::2]
1007
+ x1 = x_rotate[..., 0::2]
1008
+ x2 = x_rotate[..., 1::2]
1009
+
1010
+ # Calculate real and imaginary values
1011
+ real = (cos * x1) - (sin * x2)
1012
+ imag = (sin * x1) + (cos * x2)
1013
+
1014
+ # Interleave back
1015
+ real = real.unsqueeze(-1)
1016
+ imag = imag.unsqueeze(-1)
1017
+ x_rotate_result = torch.cat((real, imag), dim=-1).flatten(-2)
1018
+ else:
1019
+ # Non-interleaved: split in halves
1020
+ x1 = x_rotate[..., :rotary_dim_half]
1021
+ x2 = x_rotate[..., rotary_dim_half:rot_dim]
1022
+
1023
+ # Calculate real and imaginary values
1024
+ real = (cos * x1) - (sin * x2)
1025
+ imag = (sin * x1) + (cos * x2)
1026
+
1027
+ x_rotate_result = torch.cat((real, imag), dim=-1)
1028
+
1029
+ # Concatenate with non-rotated part
1030
+ if x_not_rotate is not None:
1031
+ output = torch.cat((x_rotate_result, x_not_rotate), dim=-1)
1032
+ else:
1033
+ output = x_rotate_result
1034
+
1035
+ # Reshape back to original shape
1036
+ if is_3d:
1037
+ output = output.view(original_shape)
1038
+ else:
1039
+ # Transpose back to (batch_size, num_heads, seq_len, head_size)
1040
+ output = output.transpose(1, 2)
1041
+
1042
+ return output
1043
+
1044
+ return builder.call_function(
1045
+ _rotary_embedding_onnx,
1046
+ args=(
1047
+ input_tensor,
1048
+ cos_cache,
1049
+ sin_cache,
1050
+ position_ids,
1051
+ interleaved,
1052
+ num_heads,
1053
+ rotary_embedding_dim,
1054
+ ),
1055
+ )