liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (68)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  7. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  8. liger_kernel/chunked_loss/jsd_loss.py +23 -7
  9. liger_kernel/ops/cross_entropy.py +118 -62
  10. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  11. liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
  12. liger_kernel/ops/geglu.py +1 -1
  13. liger_kernel/ops/layer_norm.py +124 -89
  14. liger_kernel/ops/llama4_rope.py +225 -0
  15. liger_kernel/ops/poly_norm.py +386 -0
  16. liger_kernel/ops/rms_norm.py +2 -2
  17. liger_kernel/ops/rope.py +1 -1
  18. liger_kernel/ops/swiglu.py +1 -1
  19. liger_kernel/ops/tiled_mlp.py +136 -0
  20. liger_kernel/transformers/__init__.py +50 -0
  21. liger_kernel/transformers/cross_entropy.py +8 -3
  22. liger_kernel/transformers/experimental/__init__.py +5 -0
  23. liger_kernel/transformers/functional.py +38 -6
  24. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  25. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
  26. liger_kernel/transformers/llama4_rope.py +93 -0
  27. liger_kernel/transformers/model/falcon_h1.py +122 -0
  28. liger_kernel/transformers/model/gemma.py +28 -8
  29. liger_kernel/transformers/model/gemma2.py +31 -8
  30. liger_kernel/transformers/model/gemma3.py +100 -110
  31. liger_kernel/transformers/model/glm4.py +18 -5
  32. liger_kernel/transformers/model/glm4v.py +163 -0
  33. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  34. liger_kernel/transformers/model/internvl.py +157 -0
  35. liger_kernel/transformers/model/llama.py +26 -7
  36. liger_kernel/transformers/model/llama4.py +121 -0
  37. liger_kernel/transformers/model/llava.py +18 -6
  38. liger_kernel/transformers/model/loss_utils.py +34 -3
  39. liger_kernel/transformers/model/mistral.py +17 -10
  40. liger_kernel/transformers/model/mixtral.py +24 -9
  41. liger_kernel/transformers/model/mllama.py +18 -7
  42. liger_kernel/transformers/model/olmo2.py +18 -5
  43. liger_kernel/transformers/model/output_classes.py +147 -0
  44. liger_kernel/transformers/model/paligemma.py +41 -5
  45. liger_kernel/transformers/model/phi3.py +24 -159
  46. liger_kernel/transformers/model/qwen2.py +26 -4
  47. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  48. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  49. liger_kernel/transformers/model/qwen3.py +22 -6
  50. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  51. liger_kernel/transformers/model/qwen3_next.py +146 -0
  52. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  53. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  54. liger_kernel/transformers/model/smollm3.py +199 -0
  55. liger_kernel/transformers/model/smolvlm.py +158 -0
  56. liger_kernel/transformers/monkey_patch.py +1090 -116
  57. liger_kernel/transformers/multi_token_attention.py +1 -1
  58. liger_kernel/transformers/poly_norm.py +42 -0
  59. liger_kernel/transformers/rms_norm.py +7 -0
  60. liger_kernel/transformers/rope.py +43 -0
  61. liger_kernel/transformers/tiled_mlp.py +133 -0
  62. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +26 -24
  63. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  64. liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
  65. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  66. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  67. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +0 -0
  68. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/ops/layer_norm.py
@@ -1,4 +1,3 @@
-import math
 import operator
 
 import torch
@@ -43,30 +42,44 @@ def _layer_norm_forward_kernel(
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    Y_ptr += row_idx * Y_row_stride
-    X_ptr += row_idx * X_row_stride
-    Mean_ptr += row_idx * Mean_row_stride
-    RSTD_ptr += row_idx * RSTD_row_stride
-
-    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
-    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
-
-    mean = tl.sum(X_row, axis=0) / n_cols
-    Xmm = tl.where(mask, X_row - mean, 0)
-    var = tl.sum(Xmm * Xmm, axis=0) / n_cols
+    # Pre-load weights and bias in fp32 to avoid repeated conversions
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+    W_f32 = W_row.to(tl.float32)
+    B_f32 = B_row.to(tl.float32)
+
+    # Calculate pointers for this row
+    row_X_ptr = X_ptr + row_idx * X_row_stride
+    row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+    row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+    row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+    # Load input data and convert to fp32 for numerical stability
+    X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+    X_f32 = X_row.to(tl.float32)
+
+    # Compute statistics in fp32 for numerical stability
+    mean = tl.sum(X_f32, axis=0) / n_cols
+    X_centered = X_f32 - mean
+    # Apply mask to variance calculation to exclude contributions from masked elements
+    X_centered_masked = tl.where(mask, X_centered, 0.0)
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
     rstd = rsqrt(var + eps)
 
-    tl.store(Mean_ptr, mean)
-    tl.store(RSTD_ptr, rstd)
+    # Store statistics (convert back to original dtype only once)
+    tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+    tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
 
-    Y_row = Xmm * rstd * W_row + B_row
+    # Fused normalization and affine transformation
+    # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+    Y_f32 = X_centered * rstd * W_f32 + B_f32
 
-    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+    # Store output (single conversion back to original dtype)
+    tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
 
 
 @triton.jit
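
For orientation, the rewritten forward kernel performs the same per-row math as an eager LayerNorm, just with one Triton program per row and statistics computed in fp32. A minimal PyTorch sketch of that math (the helper name and eps value are illustrative, not part of the package):

import torch

def layernorm_forward_reference(X, W, B, eps=1e-6):
    # X: (n_rows, n_cols); statistics in fp32, mirroring the kernel
    X_f32 = X.float()
    mean = X_f32.mean(dim=-1, keepdim=True)
    var = ((X_f32 - mean) ** 2).mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    Y = (X_f32 - mean) * rstd * W.float() + B.float()
    # Cast back to the input dtype once, as the kernel does on store
    return Y.to(X.dtype), mean.squeeze(-1).to(X.dtype), rstd.squeeze(-1).to(X.dtype)
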
@@ -81,73 +94,86 @@ def _layer_norm_backward_kernel(
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
     stride_x,  # stride of each row in input
     stride_dx,  # stride of each row in input grad
-    stride_dw,  # stride of each row in weights grad
-    stride_db,  # stride of each row in bias grad
     stride_dy,  # stride of each row in output grad
-    n_rows,
     n_cols,
-    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     dtype: tl.constexpr,
+    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
-    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
     """
-    row_block_id = tl.program_id(0)
-    row_start = row_block_id * rows_per_program
-    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    row_idx = tl.program_id(0).to(tl.int64)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
-    dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-
-    X_ptr += row_start * stride_x
-    Mean_ptr += row_start
-    RSTD_ptr += row_start
-    DX_ptr += row_start * stride_dx
-    DY_ptr += row_start * stride_dy
-
-    for _ in range(row_start, row_end):
-        x = tl.load(X_ptr + cols, mask=mask, other=0.0)
-        w = tl.load(W_ptr + cols, mask=mask, other=0.0)
-        dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
-        mean = tl.load(Mean_ptr)
-        rstd = tl.load(RSTD_ptr)
-
-        x_hat = (x - mean) * rstd
-        wdy = w * dy
-        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
-        c2 = tl.sum(wdy, axis=0) / n_cols
-        dx = (wdy - (x_hat * c1 + c2)) * rstd
-        tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
-
-        dw_row += dy * x_hat
-        db_row += dy
-
-        X_ptr += stride_x
-        Mean_ptr += 1
-        RSTD_ptr += 1
-        DX_ptr += stride_dx
-        DY_ptr += stride_dy
-
-    tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
-    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
+    # Pre-load weights once (same optimization as forward pass)
+    w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+    w_f32 = w.to(tl.float32)
+
+    # Calculate pointers for this specific row
+    row_X_ptr = X_ptr + row_idx * stride_x
+    row_DX_ptr = DX_ptr + row_idx * stride_dx
+    row_DY_ptr = DY_ptr + row_idx * stride_dy
+    row_Mean_ptr = Mean_ptr + row_idx
+    row_RSTD_ptr = RSTD_ptr + row_idx
+
+    # Load data for this row
+    x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+    dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+    mean = tl.load(row_Mean_ptr)
+    rstd = tl.load(row_RSTD_ptr)
+
+    # Convert to fp32 for numerical stability
+    x_f32 = x.to(tl.float32)
+    dy_f32 = dy.to(tl.float32)
+    mean_f32 = mean.to(tl.float32)
+    rstd_f32 = rstd.to(tl.float32)
+
+    # Compute backward pass for this row
+    x_hat = (x_f32 - mean_f32) * rstd_f32
+    wdy = w_f32 * dy_f32
+    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+    c2 = tl.sum(wdy, axis=0) / n_cols
+    dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+    # Store input gradient
+    tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
+
+    # Accumulate weight and bias gradients using atomic operations
+    dw = dy_f32 * x_hat
+    db = dy_f32
+    tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
+    tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
+    """
+    Args:
+        X: Input tensor of shape (..., hidden_size)
+        W: Weight tensor of shape (hidden_size,)
+        B: Bias tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tuple of (output, input, mean, rstd, block_size, num_warps)
+    """
     shape = X.shape
     dim = shape[-1]
     X = X.view(-1, dim)
     n_rows, n_cols = X.shape
+
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # Allocate output tensors
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+    # Validate input dimensions
     if X.shape[1] != W.shape[0]:
         raise ValueError(
             f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
@@ -159,7 +185,9 @@ def layer_norm_forward(X, W, B, eps):
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
 
-    _layer_norm_forward_kernel[(n_rows,)](
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
+    _layer_norm_forward_kernel[grid](
         Y,
         Y.stride(0),
         X,
@@ -176,35 +204,43 @@ def layer_norm_forward(X, W, B, eps):
         eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
-        **kernel_args,  # XPU-specific optimization
+        **kernel_args,
     )
+
     return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 
 
 def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+    """
+    Args:
+        dY: Gradient of output
+        X: Input tensor
+        W: Weight tensor
+        B: Bias tensor
+        Mean: Pre-computed mean
+        RSTD: Pre-computed reciprocal standard deviation
+
+    Returns:
+        Tuple of (input_grad, weight_grad, bias_grad)
+    """
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-    sm_count = 1
-    if X.device.type == "cuda":
-        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
-    elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
-
+    # Allocate gradient tensors
     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
-    _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
+    # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
+    grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
+    DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+    DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
 
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
-        raise RuntimeError(
-            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
-        )
+        raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
 
-    rows_per_program = math.ceil(n_rows / sm_count)
-    grid = (sm_count,)
+    # Determine dtype for triton operations
     triton_dtype = (
         tl.float32
         if X.dtype == torch.float32
@@ -212,41 +248,40 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         if X.dtype == torch.bfloat16
         else tl.float16
         if X.dtype == torch.float16
-        else tl.float32  # fallback to float32 for other types
+        else tl.float32  # fallback
     )
 
+    # Use float32 for atomic operations if bfloat16 is not supported
+    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+
+    kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
-    kernel_args = {}
     if X.device.type == "xpu":
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
         W,
         Mean,
         RSTD,
         DX,
-        _DW,
-        _DB,
+        DW,
+        DB,
         dY,
         X.stride(0),
         DX.stride(0),
-        _DW.stride(0),
-        _DB.stride(0),
         dY.stride(0),
-        n_rows,
         n_cols,
-        rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
         dtype=triton_dtype,
-        **kernel_args,  # XPU-specific optimization
+        atomic_dtype=atomic_dtype,
+        **kernel_args,
     )
 
-    DW = _DW.sum(dim=0).to(W.dtype)
-    DB = _DB.sum(dim=0).to(W.dtype)
-
     DX = DX.view(*shape)
-    return DX, DW, DB
+    return DX, DW.to(W.dtype), DB.to(W.dtype)
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
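
Taken together, the two wrappers keep the signatures visible in the hunks above: layer_norm_forward(X, W, B, eps) returns (Y, X, Mean, RSTD, BLOCK_SIZE, num_warps) and layer_norm_backward(dY, X, W, B, Mean, RSTD) returns (DX, DW, DB). A usage sketch (shapes, dtype, and eps are illustrative assumptions; a Triton-capable GPU is required):

import torch

from liger_kernel.ops.layer_norm import layer_norm_backward, layer_norm_forward

X = torch.randn(4, 128, 2048, device="cuda", dtype=torch.bfloat16)
W = torch.ones(2048, device="cuda", dtype=torch.bfloat16)
B = torch.zeros(2048, device="cuda", dtype=torch.bfloat16)

# Forward returns the flattened input plus cached statistics reused by backward
Y, X_flat, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps=1e-6)

dY = torch.randn_like(Y)
DX, DW, DB = layer_norm_backward(dY, X_flat, W, B, Mean, RSTD)
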
liger_kernel/ops/llama4_rope.py (new file)
@@ -0,0 +1,225 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
+    # Split or unpack complex frequencies into real and imag parts
+    if freqs_cis.is_complex():
+        freqs_real = freqs_cis.real
+        freqs_imag = freqs_cis.imag
+    else:
+        # Already split: last dim should be 2*head_dim_half
+        if freqs_cis.shape[-1] == 2 * head_dim_half:
+            freqs_real = freqs_cis[..., :head_dim_half]
+            freqs_imag = freqs_cis[..., head_dim_half:]
+        else:
+            raise ValueError(
+                f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, expected last dim = {2 * head_dim_half}"
+            )
+
+    # Canonicalize to shape (seq_len, head_dim_half):
+    # 1) Ensure the last dimension is head_dim_half
+    if freqs_real.shape[-1] != head_dim_half:
+        raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
+    # 2) Flatten all leading dims to a single row dimension
+    freqs_real = freqs_real.reshape(-1, head_dim_half)
+    freqs_imag = freqs_imag.reshape(-1, head_dim_half)
+    # 3) If we have fewer rows than seq_len, allow broadcasting when single row
+    if freqs_real.shape[0] < seq_len:
+        if freqs_real.shape[0] == 1:
+            freqs_real = freqs_real.expand(seq_len, -1)
+            freqs_imag = freqs_imag.expand(seq_len, -1)
+        else:
+            raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
+    # 4) If we have more rows than seq_len (e.g., batch present), take the first seq_len rows
+    elif freqs_real.shape[0] > seq_len:
+        freqs_real = freqs_real[:seq_len]
+        freqs_imag = freqs_imag[:seq_len]
+
+    return freqs_real, freqs_imag
+
+
+def _maybe_to_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return t if t.dtype == dtype else t.to(dtype)
+
+
+def _maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
+    return t if t.is_contiguous() else t.contiguous()
+
+
+def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
+    # Choose compute dtype: use fp32 only when inputs are fp32; otherwise keep input dtype for performance
+    compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
+
+    # Make sure q/k share the same dtype before casting to compute dtype
+    if k.dtype != q.dtype:
+        k = k.to(q.dtype)
+
+    q = _maybe_contiguous(_maybe_to_dtype(q, compute_dtype))
+    k = _maybe_contiguous(_maybe_to_dtype(k, compute_dtype))
+    freqs_real = _maybe_contiguous(_maybe_to_dtype(freqs_real, compute_dtype))
+    freqs_imag = _maybe_contiguous(_maybe_to_dtype(freqs_imag, compute_dtype))
+    return q, k, freqs_real, freqs_imag
+
+
+@triton.jit
+def _llama4_rope_kernel(
+    q_ptr,
+    k_ptr,
+    freqs_real_ptr,
+    freqs_imag_ptr,
+    q_row_stride,
+    k_row_stride,
+    q_head_stride,
+    k_head_stride,
+    freqs_row_stride,
+    seq_len,
+    batch_size,
+    imag_sign,
+    head_dim_half: tl.constexpr,
+    n_q_heads: tl.constexpr,
+    n_k_heads: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    H100-optimized RoPE kernel with improved parallelization across heads and dimensions.
+    Grid: (batch*seq, head)
+    """
+    # 2D grid
+    pid_bs = tl.program_id(0)  # over batch*seq
+    pid_h = tl.program_id(1)  # over heads
+
+    batch_idx = pid_bs // seq_len
+    seq_idx = pid_bs % seq_len
+
+    # Bounds check
+    if batch_idx >= batch_size or seq_idx >= seq_len:
+        return
+
+    # Base pointers for this (batch, seq) position
+    base_offset = batch_idx * seq_len + seq_idx
+    q_base = q_ptr + base_offset * q_row_stride
+    k_base = k_ptr + base_offset * k_row_stride
+
+    # Tiling over dim/2
+    for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
+        d_indices = d_start + tl.arange(0, BLOCK_SIZE)
+        mask_d = d_indices < head_dim_half
+
+        # Load frequencies once per tile (freqs layout: [seq_len, head_dim_half])
+        freq_idx = d_indices
+        freqs_real = tl.load(freqs_real_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        freqs_imag = tl.load(freqs_imag_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        freqs_imag = freqs_imag * imag_sign
+
+        # Process one query head per program in pid_h
+        if pid_h < n_q_heads:
+            q_head_ptr = q_base + pid_h * q_head_stride
+            q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+            q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+            # Complex multiply with FMAs: (a+ib)*(c+id) = (a*c - b*d) + i(a*d + b*c)
+            new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
+            new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
+
+            tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
+            tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)
+
+        # Process one key head per program in pid_h
+        if pid_h < n_k_heads:
+            k_head_ptr = k_base + pid_h * k_head_stride
+            k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+            k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+            new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
+            new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
+
+            tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
+            tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)
+
+
+def _select_kernel_meta(head_dim_half: int):
+    # Heuristic tuning for block size and num_warps
+    if head_dim_half >= 256:
+        return 128, 8
+    if head_dim_half >= 96:
+        return 128, 4
+    if head_dim_half >= 48:
+        return 64, 4
+    if head_dim_half >= 24:
+        return 32, 2
+    return 16, 2
+
+
+def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: float = 1.0):
+    # Save original dtype for casting back
+    original_dtype = q.dtype
+
+    batch_size, seq_len, n_q_heads, head_dim = q.shape
+    _, _, n_k_heads, _ = k.shape
+    head_dim_half = head_dim // 2
+
+    # Prepare frequencies
+    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, seq_len, head_dim_half)
+
+    # Cast to appropriate dtype and make contiguous only when needed
+    q, k, freqs_real, freqs_imag = _cast_and_contiguous(q, k, freqs_real, freqs_imag)
+
+    # H100-optimized meta-params
+    if BLOCK_SIZE is None:
+        BLOCK_SIZE, num_warps = _select_kernel_meta(head_dim_half)
+    else:
+        # Provide a default num_warps if caller pins BLOCK_SIZE
+        _, num_warps = _select_kernel_meta(head_dim_half)
+
+    # 2D grid: one program per (batch, seq, head)
+    n_heads_max = max(n_q_heads, n_k_heads)
+    grid = (batch_size * seq_len, n_heads_max)
+
+    # Launch kernel
+    _llama4_rope_kernel[grid](
+        q,
+        k,
+        freqs_real,
+        freqs_imag,
+        q.stride(1),
+        k.stride(1),
+        q.stride(2),
+        k.stride(2),
+        freqs_real.stride(0),
+        seq_len,
+        batch_size,
+        imag_sign,
+        head_dim_half,
+        n_q_heads,
+        n_k_heads,
+        BLOCK_SIZE,
+        num_warps=num_warps,
+        num_stages=2,
+    )
+
+    # Cast back to original dtype only if it differs from compute dtype
+    if q.dtype != original_dtype:
+        q = q.to(original_dtype)
+    if k.dtype != original_dtype:
+        k = k.to(original_dtype)
+
+    return q, k
+
+
+class LigerLlama4RopeFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
+        q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
+        ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        return q_out, k_out
+
+    @staticmethod
+    def backward(ctx, dq, dk):
+        (freqs_cis,) = ctx.saved_tensors
+        BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
+        # Use imag_sign=-1.0 for conjugate without materializing a new tensor
+        dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
+        return dq_out, dk_out, None
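
As wired above, LigerLlama4RopeFunction.forward takes (q, k, freqs_cis, BLOCK_SIZE=None), accepts complex or real-split frequencies, and backward reuses the same kernel with imag_sign=-1.0 to apply the conjugate rotation. A usage sketch with complex freqs_cis of shape (seq_len, head_dim // 2); the shapes, dtype, and frequency construction are illustrative assumptions:

import torch

from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction

batch, seq_len, n_q_heads, n_k_heads, head_dim = 2, 16, 8, 2, 128
device = "cuda"

q = torch.randn(batch, seq_len, n_q_heads, head_dim, device=device, dtype=torch.bfloat16)
k = torch.randn(batch, seq_len, n_k_heads, head_dim, device=device, dtype=torch.bfloat16)

# Standard RoPE frequencies as a complex tensor; _prepare_freqs also accepts a real
# tensor whose last dim concatenates the real and imaginary halves (head_dim total).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
angles = torch.arange(seq_len, device=device).float()[:, None] * inv_freq[None, :]
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # (seq_len, head_dim // 2)

q_rot, k_rot = LigerLlama4RopeFunction.apply(q, k, freqs_cis)
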