onnx2fx 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,682 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Microsoft domain (com.microsoft) attention and transformer operators.
+
+ This module implements attention-related operators for the com.microsoft domain,
+ commonly used by ONNX Runtime optimized models.
+ """
+
+ from typing import TYPE_CHECKING
+
+ import onnx
+ import torch
+
+ from ..op_registry import register
+ from ..utils.attributes import get_attribute
+ from ..utils.op_helpers import get_optional_input
+
+ if TYPE_CHECKING:
+     from ..graph_builder import GraphBuilder
+
+
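# Standalone note on the helpers imported above: the handlers below use
# get_attribute(node, name, default) as "return the named ONNX attribute's value,
# or the default when it is absent", and get_optional_input(builder, node, i) as
# "return the FX value for input slot i, or None when the slot is empty".
# A rough sketch of the attribute lookup using only the public onnx API follows;
# the real helpers in ..utils may differ (this is an assumption, not their code).
import onnx
from onnx import helper


def _get_attribute_sketch(node: onnx.NodeProto, name: str, default=None):
    # Linear scan over the node's attributes; decode with onnx.helper.
    for attr in node.attribute:
        if attr.name == name:
            return helper.get_attribute_value(attr)
    return default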
+ # =============================================================================
+ # Embedding and LayerNorm variants (com.microsoft domain)
+ # =============================================================================
+
+
+ @register("SkipLayerNormalization", domain="com.microsoft")
+ def skip_layer_normalization_msft(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> torch.fx.Node:
+     """Skip connection + LayerNorm (common in transformers).
+
+     Microsoft domain version (com.microsoft).
+     """
+     x = builder.get_value(node.input[0])
+     skip = builder.get_value(node.input[1])
+     gamma = builder.get_value(node.input[2])
+     beta = get_optional_input(builder, node, 3)
+     bias = get_optional_input(builder, node, 4)
+
+     epsilon = get_attribute(node, "epsilon", 1e-5)
+
+     def _skip_layer_norm(
+         inp: torch.Tensor,
+         sk: torch.Tensor,
+         g: torch.Tensor,
+         b: torch.Tensor | None,
+         bi: torch.Tensor | None,
+         eps: float,
+     ) -> torch.Tensor:
+         hidden = inp + sk
+         if bi is not None:
+             hidden = hidden + bi
+         return torch.nn.functional.layer_norm(
+             hidden, hidden.shape[-1:], weight=g, bias=b, eps=eps
+         )
+
+     return builder.call_function(
+         _skip_layer_norm, args=(x, skip, gamma, beta, bias, epsilon)
+     )
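# Standalone sanity sketch (not tied to the builder machinery above): the handler
# decomposes SkipLayerNormalization as (input + skip [+ bias]) followed by a plain
# LayerNorm over the last axis. The check below recomputes that LayerNorm by hand,
# using made-up shapes, to show what F.layer_norm does with these arguments.
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
skip = torch.randn(2, 4, 8)
gamma = torch.randn(8)
beta = torch.randn(8)
eps = 1e-5

hidden = x + skip
mean = hidden.mean(dim=-1, keepdim=True)
var = hidden.var(dim=-1, unbiased=False, keepdim=True)
manual = (hidden - mean) / torch.sqrt(var + eps) * gamma + beta

fused = F.layer_norm(hidden, hidden.shape[-1:], weight=gamma, bias=beta, eps=eps)
assert torch.allclose(manual, fused, atol=1e-5)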
+
+
+ @register("EmbedLayerNormalization", domain="com.microsoft")
+ def embed_layer_normalization_msft(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> torch.fx.Node:
+     """Embedding + LayerNorm (common in BERT-like models).
+
+     Microsoft domain version (com.microsoft).
+     """
+     input_ids = builder.get_value(node.input[0])
+     segment_ids = get_optional_input(builder, node, 1)
+     word_embedding = builder.get_value(node.input[2])
+     position_embedding = builder.get_value(node.input[3])
+     segment_embedding = get_optional_input(builder, node, 4)
+     gamma = get_optional_input(builder, node, 5)
+     beta = get_optional_input(builder, node, 6)
+
+     epsilon = get_attribute(node, "epsilon", 1e-5)
+
+     def _embed_layer_norm(
+         ids: torch.Tensor,
+         seg_ids: torch.Tensor | None,
+         word_emb: torch.Tensor,
+         pos_emb: torch.Tensor,
+         seg_emb: torch.Tensor | None,
+         g: torch.Tensor | None,
+         b: torch.Tensor | None,
+         eps: float,
+     ) -> torch.Tensor:
+         # Word embedding lookup
+         word_embed = torch.nn.functional.embedding(ids, word_emb)
+
+         # Position embedding (assume sequential positions)
+         seq_len = ids.shape[1]
+         pos_embed = pos_emb[:seq_len].unsqueeze(0).expand(ids.shape[0], -1, -1)
+
+         hidden = word_embed + pos_embed
+
+         # Add segment embedding if present
+         if seg_emb is not None and seg_ids is not None:
+             seg_embed = torch.nn.functional.embedding(seg_ids, seg_emb)
+             hidden = hidden + seg_embed
+
+         # Layer normalization
+         if g is not None:
+             hidden = torch.nn.functional.layer_norm(
+                 hidden, hidden.shape[-1:], weight=g, bias=b, eps=eps
+             )
+
+         return hidden
+
+     return builder.call_function(
+         _embed_layer_norm,
+         args=(
+             input_ids,
+             segment_ids,
+             word_embedding,
+             position_embedding,
+             segment_embedding,
+             gamma,
+             beta,
+             epsilon,
+         ),
+     )
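# Minimal shape sketch of the embedding sum performed above, independent of the
# builder. Vocabulary, position-table, and sequence sizes are made up for
# illustration only.
import torch
import torch.nn.functional as F

batch, seq, hidden, vocab, n_segments = 2, 6, 16, 100, 2
input_ids = torch.randint(0, vocab, (batch, seq))
segment_ids = torch.randint(0, n_segments, (batch, seq))
word_embedding = torch.randn(vocab, hidden)
position_embedding = torch.randn(512, hidden)  # (max positions, hidden)
segment_embedding = torch.randn(n_segments, hidden)

hidden_states = (
    F.embedding(input_ids, word_embedding)
    + position_embedding[:seq].unsqueeze(0).expand(batch, -1, -1)
    + F.embedding(segment_ids, segment_embedding)
)
hidden_states = F.layer_norm(hidden_states, (hidden,), eps=1e-5)
assert hidden_states.shape == (batch, seq, hidden)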
+
+
+ # =============================================================================
+ # Microsoft Attention operator (com.microsoft domain)
+ # =============================================================================
+
+
+ @register("Attention", domain="com.microsoft")
+ def microsoft_attention(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
+     """Microsoft Attention operator (com.microsoft domain).
+
+     Multi-Head Attention that can be either unidirectional (like GPT-2) or
+     bidirectional (like BERT). The weights for the input projections of Q, K and V
+     are merged into a single tensor.
+
+     Inputs:
+         input: Input tensor with shape (batch_size, sequence_length, input_hidden_size)
+         weights: Merged Q/K/V weights with shape (input_hidden_size, hidden_size + hidden_size + v_hidden_size)
+         bias (optional): Bias tensor with shape (hidden_size + hidden_size + v_hidden_size)
+         mask_index (optional): Attention mask
+         past (optional): Past state for key and value
+         attention_bias (optional): Additional bias added to the QK^T attention scores
+         past_sequence_length (optional): Past sequence length
+
+     Attributes:
+         num_heads (required): Number of attention heads
+         unidirectional: Whether every token can only attend to previous tokens (default 0)
+         scale: Custom scale factor (default 1/sqrt(head_size))
+         mask_filter_value: Value to fill in attention mask (default -10000.0)
+
+     Outputs:
+         output: 3D output tensor with shape (batch_size, sequence_length, v_hidden_size)
+         present (optional): Past state for key and value
+     """
+     # Get inputs
+     input_tensor = builder.get_value(node.input[0])
+     weights = builder.get_value(node.input[1])
+     bias = get_optional_input(builder, node, 2)
+     mask_index = get_optional_input(builder, node, 3)
+     past = get_optional_input(builder, node, 4)
+     attention_bias = get_optional_input(builder, node, 5)
+
+     # Get attributes
+     num_heads = get_attribute(node, "num_heads", None)
+     if num_heads is None:
+         raise ValueError("num_heads attribute is required for Microsoft Attention")
+     unidirectional = get_attribute(node, "unidirectional", 0)
+     scale = get_attribute(node, "scale", None)
+
+     def _microsoft_attention(
+         inp: torch.Tensor,
+         w: torch.Tensor,
+         b: torch.Tensor | None,
+         mask: torch.Tensor | None,
+         past_kv: torch.Tensor | None,
+         attn_bias: torch.Tensor | None,
+         n_heads: int,
+         is_causal: bool,
+         attn_scale: float | None,
+     ) -> torch.Tensor:
+         batch_size, seq_len, _ = inp.shape
+
+         # Project input to Q, K, V using the merged weights
+         # weights shape: (input_hidden_size, 3 * hidden_size)
+         qkv = torch.matmul(inp, w)
+         if b is not None:
+             qkv = qkv + b
+
+         # Split into Q, K, V (assumes hidden_size == v_hidden_size)
+         q, k, v = qkv.chunk(3, dim=-1)
+
+         # Reshape to (batch, num_heads, seq_len, head_size) for multi-head attention
+         head_size = q.shape[-1] // n_heads
+         q = q.view(batch_size, seq_len, n_heads, head_size).transpose(1, 2)
+         k = k.view(batch_size, seq_len, n_heads, head_size).transpose(1, 2)
+         v = v.view(batch_size, seq_len, n_heads, head_size).transpose(1, 2)
+
+         # Use scaled_dot_product_attention.
+         # Note: past_kv and attn_bias are accepted but not applied, and mask is
+         # forwarded as-is (raw mask_index formats such as per-batch sequence
+         # lengths are not converted). SDPA does not accept attn_mask together
+         # with is_causal, so the mask is ignored in the unidirectional case.
+         output = torch.nn.functional.scaled_dot_product_attention(
+             q,
+             k,
+             v,
+             attn_mask=None if is_causal else mask,
+             is_causal=is_causal,
+             scale=attn_scale,
+         )
+
+         # Merge heads back: (batch, num_heads, seq_len, head_size) -> (batch, seq_len, hidden)
+         return output.transpose(1, 2).reshape(batch_size, seq_len, -1)
+
+     is_causal = unidirectional == 1
+
+     return builder.call_function(
+         _microsoft_attention,
+         args=(
+             input_tensor,
+             weights,
+             bias,
+             mask_index,
+             past,
+             attention_bias,
+             num_heads,
+             is_causal,
+             scale,
+         ),
+     )
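# Quick standalone check of the merged-projection convention the docstring above
# describes: multiplying by the merged (input_hidden, 3 * hidden) weight and chunking
# the result is the same as three separate Q/K/V projections with the chunked weight
# columns. Shapes and head count are made up for illustration.
import torch

hidden = 16
x = torch.randn(2, 5, hidden)
w = torch.randn(hidden, 3 * hidden)

wq, wk, wv = w.chunk(3, dim=-1)
q_separate = x @ wq

q_merged, k_merged, v_merged = (x @ w).chunk(3, dim=-1)
assert torch.allclose(q_merged, q_separate, atol=1e-5)

# Reshaping into heads as the handler does:
# (batch, seq, hidden) -> (batch, heads, seq, head_size)
num_heads = 4
q_heads = q_merged.view(2, 5, num_heads, hidden // num_heads).transpose(1, 2)
assert q_heads.shape == (2, num_heads, 5, hidden // num_heads)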
+
+
+ # =============================================================================
+ # Simplified LayerNormalization variants (com.microsoft domain)
+ # =============================================================================
+
+
+ @register("SimplifiedLayerNormalization", domain="com.microsoft")
+ def simplified_layer_normalization_msft(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> torch.fx.Node:
+     """Simplified Layer Normalization (RMSNorm).
+
+     This is LayerNormalization without bias and mean subtraction.
+     Formula: output = x / sqrt(mean(x^2) + epsilon) * scale
+
+     Microsoft domain version (com.microsoft).
+     """
+     x = builder.get_value(node.input[0])
+     scale = builder.get_value(node.input[1])
+
+     axis = get_attribute(node, "axis", -1)
+     epsilon = get_attribute(node, "epsilon", 1e-5)
+
+     def _simplified_layer_norm(x, scale, axis, epsilon):
+         # Simplified LayerNorm (RMSNorm)
+         # output = x * rsqrt(mean(x^2) + epsilon) * scale
+         if axis < 0:
+             axis_pos = x.dim() + axis
+         else:
+             axis_pos = axis
+
+         # Normalize over every dim from `axis` to the last (keepdim for broadcasting)
+         dims = list(range(axis_pos, x.dim()))
+
+         # Mean of squares, then reciprocal square root
+         variance = x.pow(2).mean(dim=dims, keepdim=True)
+         x_normalized = x * torch.rsqrt(variance + epsilon)
+
+         return x_normalized * scale
+
+     return builder.call_function(_simplified_layer_norm, args=(x, scale, axis, epsilon))
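# Standalone check of the RMSNorm formula stated in the docstring above against
# PyTorch's built-in F.rms_norm, which only exists in newer PyTorch releases
# (roughly 2.4+), hence the hasattr guard. Shapes are made up for illustration.
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
weight = torch.randn(8)
eps = 1e-5

manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
if hasattr(F, "rms_norm"):
    assert torch.allclose(manual, F.rms_norm(x, (8,), weight=weight, eps=eps), atol=1e-5)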
+
+
+ @register("SkipSimplifiedLayerNormalization", domain="com.microsoft")
+ def skip_simplified_layer_normalization_msft(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> torch.fx.Node:
+     """Skip connection + Simplified Layer Normalization (RMSNorm).
+
+     Microsoft domain version (com.microsoft).
+     """
+     x = builder.get_value(node.input[0])
+     skip = builder.get_value(node.input[1])
+     scale = builder.get_value(node.input[2])
+     bias = get_optional_input(builder, node, 3)
+
+     epsilon = get_attribute(node, "epsilon", 1e-5)
+
+     def _skip_simplified_layer_norm(x, skip, scale, bias, epsilon):
+         # Add skip connection
+         hidden = x + skip
+         if bias is not None:
+             hidden = hidden + bias
+
+         # Simplified LayerNorm (RMSNorm)
+         variance = hidden.pow(2).mean(dim=-1, keepdim=True)
+         hidden_normalized = hidden * torch.rsqrt(variance + epsilon)
+
+         return hidden_normalized * scale
+
+     return builder.call_function(
+         _skip_simplified_layer_norm, args=(x, skip, scale, bias, epsilon)
+     )
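# For reference, the skip variant above is just the residual add followed by the same
# RMSNorm over the last axis; a two-line decomposition with made-up shapes.
import torch

x, skip, scale, eps = torch.randn(2, 4, 8), torch.randn(2, 4, 8), torch.randn(8), 1e-5
h = x + skip
out = h * torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + eps) * scale
assert out.shape == (2, 4, 8)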
+
+
+ # =============================================================================
+ # GroupQueryAttention (com.microsoft domain)
+ # =============================================================================
+
+
+ @register("GroupQueryAttention", domain="com.microsoft")
+ def group_query_attention_msft(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> torch.fx.Node:
+     """Group Query Attention (GQA) - used in LLaMA, Mistral, etc.
+
+     Microsoft domain version (com.microsoft).
+
+     Inputs:
+         - query: [batch, seq_len, num_heads * head_size]
+         - key: [batch, kv_seq_len, num_kv_heads * head_size]
+         - value: [batch, kv_seq_len, num_kv_heads * head_size]
+         - past_key (optional): [batch, num_kv_heads, past_seq_len, head_size]
+         - past_value (optional): [batch, num_kv_heads, past_seq_len, head_size]
+         - seqlens_k (optional): cumulative sequence lengths for keys
+         - total_sequence_length (optional): total sequence length
+         - cos_cache (optional): [max_seq_len, head_size / 2] or [max_seq_len, head_size]
+         - sin_cache (optional): [max_seq_len, head_size / 2] or [max_seq_len, head_size]
+
+     Attributes:
+         - num_heads: number of attention heads
+         - kv_num_heads: number of key-value heads (for GQA)
+         - scale: scaling factor (default: 1/sqrt(head_size))
+         - local_window_size: for sliding window attention
+         - do_rotary: whether to apply rotary position embeddings
+         - rotary_interleaved: rotary layout, interleaved pairs (1) or half-split, LLaMA-style (0)
+
+     Outputs:
+         - output: [batch, seq_len, num_heads * head_size]
+         - present_key: [batch, num_kv_heads, total_seq_len, head_size]
+         - present_value: [batch, num_kv_heads, total_seq_len, head_size]
+     """
+     # Get required inputs
+     query = builder.get_value(node.input[0])
+     key = builder.get_value(node.input[1])
+     value = builder.get_value(node.input[2])
+
+     # Get optional inputs
+     past_key = get_optional_input(builder, node, 3)
+     past_value = get_optional_input(builder, node, 4)
+     seqlens_k = get_optional_input(builder, node, 5)
+     total_seq_len = get_optional_input(builder, node, 6)
+     cos_cache = get_optional_input(builder, node, 7)
+     sin_cache = get_optional_input(builder, node, 8)
+
+     # Get attributes
+     num_heads = get_attribute(node, "num_heads", 1)
+     kv_num_heads = get_attribute(node, "kv_num_heads", num_heads)
+     scale = get_attribute(node, "scale", None)
+     local_window_size = get_attribute(node, "local_window_size", -1)
+     do_rotary = get_attribute(node, "do_rotary", 0)
+     rotary_interleaved = get_attribute(node, "rotary_interleaved", 0)
+
+     def _group_query_attention(
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         past_k: torch.Tensor | None,
+         past_v: torch.Tensor | None,
+         seqlens_k: torch.Tensor | None,
+         total_seq_len: torch.Tensor | None,
+         cos_cache: torch.Tensor | None,
+         sin_cache: torch.Tensor | None,
+         n_heads: int,
+         n_kv_heads: int,
+         attn_scale: float | None,
+         window_size: int,
+         do_rotary: int,
+         rotary_interleaved: int,
+     ):
+         # Note: seqlens_k, total_seq_len and window_size (sliding-window attention)
+         # are accepted for signature compatibility but are not applied below.
+         batch_size, seq_len, _ = q.shape
+         head_size = q.shape[-1] // n_heads
+         kv_head_size = k.shape[-1] // n_kv_heads
+
+         # Reshape Q, K, V to [batch, num_heads, seq_len, head_size]
+         q = q.view(batch_size, seq_len, n_heads, head_size).transpose(1, 2)
+         k = k.view(batch_size, -1, n_kv_heads, kv_head_size).transpose(1, 2)
+         v = v.view(batch_size, -1, n_kv_heads, kv_head_size).transpose(1, 2)
+
+         # Calculate position offset from past cache
+         past_seq_len = 0
+         if past_k is not None and past_k.numel() > 0:
+             past_seq_len = past_k.shape[2]
+
+         # Apply rotary position embeddings if enabled
+         if do_rotary and cos_cache is not None and sin_cache is not None:
+             # Get the position indices
+             positions = torch.arange(
+                 past_seq_len, past_seq_len + seq_len, device=q.device
+             )
+
+             # Get cos/sin values for current positions
+             cos = cos_cache[positions]  # [seq_len, rotary_dim]
+             sin = sin_cache[positions]  # [seq_len, rotary_dim]
+
+             # Expand for batch and heads: [1, 1, seq_len, rotary_dim]
+             cos = cos.unsqueeze(0).unsqueeze(0)
+             sin = sin.unsqueeze(0).unsqueeze(0)
+
+             # The cache stores one cos/sin value per rotated pair, so the number
+             # of channels that get rotated is 2 * rotary_dim.
+             rotary_dim = cos.shape[-1]
+             rotary_full = rotary_dim * 2
+
+             if rotary_interleaved:
+                 # Interleaved layout: [x0, x1, x2, x3, ...] -> rotate consecutive pairs
+                 q_rot = q[..., :rotary_full]
+                 q_pass = q[..., rotary_full:]
+                 k_rot = k[..., :rotary_full]
+                 k_pass = k[..., rotary_full:]
+
+                 # Apply rotation to (even, odd) pairs; cos/sin already hold one
+                 # entry per pair
+                 q1, q2 = q_rot[..., ::2], q_rot[..., 1::2]
+                 k1, k2 = k_rot[..., ::2], k_rot[..., 1::2]
+
+                 q_rot_new = torch.stack(
+                     [q1 * cos - q2 * sin, q1 * sin + q2 * cos],
+                     dim=-1,
+                 ).flatten(-2)
+                 k_rot_new = torch.stack(
+                     [k1 * cos - k2 * sin, k1 * sin + k2 * cos],
+                     dim=-1,
+                 ).flatten(-2)
+
+                 q = torch.cat([q_rot_new, q_pass], dim=-1)
+                 k = torch.cat([k_rot_new, k_pass], dim=-1)
+             else:
+                 # Half-split (LLaMA) layout: the first rotary_full channels are
+                 # rotated as two halves:
+                 #   q1 = q[..., :rotary_dim], q2 = q[..., rotary_dim:rotary_full]
+                 #   result = (q1*cos - q2*sin, q1*sin + q2*cos)
+                 q_rot = q[..., :rotary_full]
+                 q_pass = q[..., rotary_full:]
+                 k_rot = k[..., :rotary_full]
+                 k_pass = k[..., rotary_full:]
+
+                 # Split into first half and second half
+                 q1, q2 = q_rot[..., :rotary_dim], q_rot[..., rotary_dim:rotary_full]
+                 k1, k2 = k_rot[..., :rotary_dim], k_rot[..., rotary_dim:rotary_full]
+
+                 # cos/sin are already in the right shape [1, 1, seq_len, rotary_dim]
+                 q_rot_new = torch.cat(
+                     [q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1
+                 )
+                 k_rot_new = torch.cat(
+                     [k1 * cos - k2 * sin, k1 * sin + k2 * cos], dim=-1
+                 )
+
+                 q = torch.cat([q_rot_new, q_pass], dim=-1)
+                 k = torch.cat([k_rot_new, k_pass], dim=-1)
+
+         # Handle past key-value cache
+         if past_k is not None and past_v is not None and past_k.numel() > 0:
+             k = torch.cat([past_k, k], dim=2)
+             v = torch.cat([past_v, v], dim=2)
+
+         # Present key-value for caching
+         present_k = k
+         present_v = v
+
+         # Expand K, V for GQA (repeat each KV head for its group of query heads)
+         if n_kv_heads < n_heads:
+             n_rep = n_heads // n_kv_heads
+             k = k.repeat_interleave(n_rep, dim=1)
+             v = v.repeat_interleave(n_rep, dim=1)
+
+         # Compute attention scale
+         if attn_scale is None:
+             attn_scale = 1.0 / (head_size**0.5)
+
+         # Use scaled_dot_product_attention.
+         # With a past cache the common case is single-token decoding (seq_len == 1),
+         # where no causal mask is needed; prefill without a cache gets a causal mask.
+         # Chunked prefill with a cache would need an explicit mask, which is not
+         # built here.
+         is_causal = seq_len > 1 and past_seq_len == 0
+         output = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, scale=attn_scale, is_causal=is_causal
+         )
+
+         # Reshape output: [batch, num_heads, seq_len, head_size] -> [batch, seq_len, hidden]
+         output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
+
+         return output, present_k, present_v
+
+     # Build the call
+     result = builder.call_function(
+         _group_query_attention,
+         args=(
+             query,
+             key,
+             value,
+             past_key,
+             past_value,
+             seqlens_k,
+             total_seq_len,
+             cos_cache,
+             sin_cache,
+             num_heads,
+             kv_num_heads,
+             scale,
+             local_window_size,
+             do_rotary,
+             rotary_interleaved,
+         ),
+     )
+
+     # Return tuple output (output, present_key, present_value)
+     return result
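# Small standalone sketch of the GQA head-expansion step used above: with 8 query
# heads and 2 KV heads, each KV head is shared by a group of 4 consecutive query
# heads, which repeat_interleave(n_rep, dim=1) reproduces. Sizes are illustrative.
import torch

batch, n_heads, n_kv_heads, seq, head_size = 1, 8, 2, 5, 16
n_rep = n_heads // n_kv_heads

k = torch.randn(batch, n_kv_heads, seq, head_size)
k_expanded = k.repeat_interleave(n_rep, dim=1)

assert k_expanded.shape == (batch, n_heads, seq, head_size)
# Query heads 0..3 all attend against KV head 0, heads 4..7 against KV head 1.
assert torch.equal(k_expanded[:, 0], k[:, 0])
assert torch.equal(k_expanded[:, 3], k[:, 0])
assert torch.equal(k_expanded[:, 4], k[:, 1])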
+
+
+ # =============================================================================
+ # Rotary Embedding (com.microsoft domain)
+ # =============================================================================
+
+
+ @register("RotaryEmbedding", domain="com.microsoft")
+ def rotary_embedding_msft(
+     builder: "GraphBuilder", node: onnx.NodeProto
+ ) -> torch.fx.Node:
+     """Rotary Position Embedding (RoPE) operator.
+
+     Applies rotary position embeddings to the input tensor. The positions are
+     represented as rotation matrices that are multiplied to query and key
+     before the inner product of query and key is taken.
+
+     Microsoft domain version (com.microsoft).
+
+     Inputs:
+         - input: 3D tensor with shape (batch_size, sequence_length, hidden_size)
+           or 4D with shape (batch_size, num_heads, sequence_length, head_size)
+         - position_ids: 1D tensor with shape (1) or 2D tensor with shape
+           (batch_size, sequence_length)
+         - cos_cache: 2D tensor with shape (max_sequence_length, head_size / 2)
+           or (max_sequence_length, rotary_embedding_dim / 2)
+         - sin_cache: 2D tensor with shape (max_sequence_length, head_size / 2)
+           or (max_sequence_length, rotary_embedding_dim / 2)
+
+     Attributes:
+         - interleaved: Indicates whether the input has real and imaginary parts
+           interleaved. Default is 0 (False).
+         - num_heads: Number of attention heads. Default is 0.
+         - rotary_embedding_dim: Rotary embedding dimension. Default is 0.
+         - scale: Custom scale. Default is 1.0.
+
+     Outputs:
+         - output: tensor with same shape as input.
+     """
+     # Get inputs
+     input_tensor = builder.get_value(node.input[0])
+     position_ids = builder.get_value(node.input[1])
+     cos_cache = builder.get_value(node.input[2])
+     sin_cache = builder.get_value(node.input[3])
+
+     # Get attributes
+     interleaved = get_attribute(node, "interleaved", 0)
+     num_heads = get_attribute(node, "num_heads", 0)
+     rotary_embedding_dim = get_attribute(node, "rotary_embedding_dim", 0)
+     scale = get_attribute(node, "scale", 1.0)
+
+     def _rotary_embedding(
+         x: torch.Tensor,
+         pos_ids: torch.Tensor,
+         cos_cache: torch.Tensor,
+         sin_cache: torch.Tensor,
+         interleaved: int,
+         num_heads: int,
+         rotary_dim: int,
+         scale: float,
+     ) -> torch.Tensor:
+         """Apply rotary position embeddings."""
+         original_shape = x.shape
+         is_3d = x.dim() == 3
+
+         if is_3d:
+             # Input is (batch_size, seq_len, hidden_size)
+             batch_size, seq_len, hidden_size = x.shape
+
+             # Determine head_size and num_heads
+             if num_heads > 0:
+                 head_size = hidden_size // num_heads
+                 actual_num_heads = num_heads
+             else:
+                 # Infer head_size from cos_cache dimension
+                 # cos_cache has shape (max_seq, rotary_dim/2)
+                 rotary_half_dim = cos_cache.shape[-1]
+                 head_size = rotary_half_dim * 2  # rotary_dim == head_size typically
+                 actual_num_heads = hidden_size // head_size
+
+             # Reshape to (batch, num_heads, seq, head_size)
+             x = x.view(batch_size, seq_len, actual_num_heads, head_size).transpose(1, 2)
+         else:
+             # Input is (batch_size, num_heads, seq_len, head_size)
+             batch_size, actual_num_heads, seq_len, head_size = x.shape
+
+         # Get cos/sin values for positions
+         # position_ids can be (1,) scalar or (batch, seq) or (seq,)
+         if pos_ids.dim() == 1:
+             if pos_ids.numel() == 1:
+                 # Single position offset - generate sequence
+                 start_pos = pos_ids.item()
+                 positions = torch.arange(
+                     start_pos, start_pos + seq_len, device=x.device, dtype=torch.long
+                 )
+             else:
+                 positions = pos_ids
+         else:
+             # (batch, seq) - use first batch for now (they should be the same)
+             positions = pos_ids[0] if pos_ids.shape[0] > 1 else pos_ids.squeeze(0)
+
+         # Gather cos/sin from cache based on positions
+         cos = cos_cache[positions]  # (seq_len, rotary_dim/2)
+         sin = sin_cache[positions]  # (seq_len, rotary_dim/2)
+
+         # Determine rotary dimension
+         if rotary_dim > 0:
+             rot_dim = rotary_dim
+         else:
+             rot_dim = cos.shape[-1] * 2  # cos/sin cache is half the rotary dim
+
+         # Expand cos/sin for batch and heads: (1, 1, seq_len, rotary_dim/2)
+         cos = cos.unsqueeze(0).unsqueeze(0)
+         sin = sin.unsqueeze(0).unsqueeze(0)
+
+         # Apply scale if specified
+         if scale != 1.0:
+             x = x * scale
+
+         # Split into rotary and pass-through parts
+         x_rot = x[..., :rot_dim]
+         x_pass = x[..., rot_dim:] if rot_dim < x.shape[-1] else None
+
+         if interleaved:
+             # Interleaved format: [x0, y0, x1, y1, ...] pairs
+             # Rotate pairs: (x, y) -> (x*cos - y*sin, x*sin + y*cos)
+             x1 = x_rot[..., ::2]  # Even indices
+             x2 = x_rot[..., 1::2]  # Odd indices
+
+             # Make sure cos/sin match the half dimension
+             cos_half = cos[..., : x1.shape[-1]]
+             sin_half = sin[..., : x1.shape[-1]]
+
+             # Apply rotation
+             x_rot_new = torch.stack(
+                 [x1 * cos_half - x2 * sin_half, x1 * sin_half + x2 * cos_half], dim=-1
+             ).flatten(-2)
+         else:
+             # Non-interleaved format: first half real, second half imaginary
+             # x = [x1, x2] where x1 and x2 are halves
+             half_dim = rot_dim // 2
+             x1 = x_rot[..., :half_dim]
+             x2 = x_rot[..., half_dim:rot_dim]
+
+             # Apply rotation: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos)
+             x_rot_new = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
+
+         # Concatenate with pass-through part
+         if x_pass is not None:
+             x_out = torch.cat([x_rot_new, x_pass], dim=-1)
+         else:
+             x_out = x_rot_new
+
+         # Reshape back to original shape
+         if is_3d:
+             # Always reshape back from (batch, num_heads, seq, head_size) to (batch, seq, hidden)
+             x_out = x_out.transpose(1, 2).contiguous().view(original_shape)
+
+         return x_out
+
+     return builder.call_function(
+         _rotary_embedding,
+         args=(
+             input_tensor,
+             position_ids,
+             cos_cache,
+             sin_cache,
+             interleaved,
+             num_heads,
+             rotary_embedding_dim,
+             scale,
+         ),
+     )
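# Standalone check of the half-split rotation used above: applying
# (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos) with cos/sin taken from the same
# positions is a pure rotation of each (x1_i, x2_i) pair, so it preserves the
# per-position norm of the rotated channels. The cache construction below uses the
# standard RoPE inverse-frequency formula with made-up sizes.
import torch

seq, half = 6, 4
inv_freq = 1.0 / (10000 ** (torch.arange(half, dtype=torch.float32) / half))
angles = torch.arange(seq, dtype=torch.float32)[:, None] * inv_freq[None, :]
cos, sin = angles.cos(), angles.sin()  # (seq, half), as the cos/sin caches store them

x = torch.randn(seq, 2 * half)
x1, x2 = x[..., :half], x[..., half:]
rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

assert torch.allclose(rotated.norm(dim=-1), x.norm(dim=-1), atol=1e-5)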