liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.4.dev20251121224847__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic.

Files changed (73)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +23 -7
  9. liger_kernel/ops/cross_entropy.py +118 -62
  10. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  11. liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
  12. liger_kernel/ops/geglu.py +1 -1
  13. liger_kernel/ops/grpo_loss.py +3 -1
  14. liger_kernel/ops/layer_norm.py +133 -79
  15. liger_kernel/ops/llama4_rope.py +225 -0
  16. liger_kernel/ops/poly_norm.py +386 -0
  17. liger_kernel/ops/rms_norm.py +2 -2
  18. liger_kernel/ops/rope.py +1 -1
  19. liger_kernel/ops/swiglu.py +1 -1
  20. liger_kernel/ops/tiled_mlp.py +136 -0
  21. liger_kernel/transformers/__init__.py +59 -0
  22. liger_kernel/transformers/cross_entropy.py +8 -3
  23. liger_kernel/transformers/experimental/__init__.py +5 -0
  24. liger_kernel/transformers/functional.py +38 -6
  25. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  26. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
  27. liger_kernel/transformers/grpo_loss.py +56 -1
  28. liger_kernel/transformers/llama4_rope.py +93 -0
  29. liger_kernel/transformers/model/falcon_h1.py +122 -0
  30. liger_kernel/transformers/model/gemma.py +28 -8
  31. liger_kernel/transformers/model/gemma2.py +31 -8
  32. liger_kernel/transformers/model/gemma3.py +100 -110
  33. liger_kernel/transformers/model/glm4.py +18 -5
  34. liger_kernel/transformers/model/glm4v.py +163 -0
  35. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  36. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  37. liger_kernel/transformers/model/internvl.py +157 -0
  38. liger_kernel/transformers/model/llama.py +26 -7
  39. liger_kernel/transformers/model/llama4.py +121 -0
  40. liger_kernel/transformers/model/llava.py +18 -6
  41. liger_kernel/transformers/model/loss_utils.py +34 -3
  42. liger_kernel/transformers/model/mistral.py +17 -10
  43. liger_kernel/transformers/model/mixtral.py +24 -9
  44. liger_kernel/transformers/model/mllama.py +18 -7
  45. liger_kernel/transformers/model/olmo2.py +18 -5
  46. liger_kernel/transformers/model/olmo3.py +142 -0
  47. liger_kernel/transformers/model/output_classes.py +147 -0
  48. liger_kernel/transformers/model/paligemma.py +41 -5
  49. liger_kernel/transformers/model/phi3.py +24 -159
  50. liger_kernel/transformers/model/qwen2.py +26 -4
  51. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  52. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  53. liger_kernel/transformers/model/qwen3.py +22 -6
  54. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  55. liger_kernel/transformers/model/qwen3_next.py +146 -0
  56. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  57. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  58. liger_kernel/transformers/model/smollm3.py +199 -0
  59. liger_kernel/transformers/model/smolvlm.py +158 -0
  60. liger_kernel/transformers/monkey_patch.py +1278 -116
  61. liger_kernel/transformers/multi_token_attention.py +1 -1
  62. liger_kernel/transformers/poly_norm.py +42 -0
  63. liger_kernel/transformers/rms_norm.py +7 -0
  64. liger_kernel/transformers/rope.py +43 -0
  65. liger_kernel/transformers/swiglu.py +17 -0
  66. liger_kernel/transformers/tiled_mlp.py +133 -0
  67. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/METADATA +29 -24
  68. liger_kernel_nightly-0.6.4.dev20251121224847.dist-info/RECORD +118 -0
  69. liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
  70. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/LICENSE +0 -0
  71. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/NOTICE +0 -0
  72. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/WHEEL +0 -0
  73. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/top_level.txt +0 -0
@@ -43,111 +43,157 @@ def _layer_norm_forward_kernel(
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols

-    Y_ptr += row_idx * Y_row_stride
-    X_ptr += row_idx * X_row_stride
-    Mean_ptr += row_idx * Mean_row_stride
-    RSTD_ptr += row_idx * RSTD_row_stride
-
-    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
-    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
-
-    mean = tl.sum(X_row, axis=0) / n_cols
-    Xmm = tl.where(mask, X_row - mean, 0)
-    var = tl.sum(Xmm * Xmm, axis=0) / n_cols
+    # Pre-load weights and bias in fp32 to avoid repeated conversions
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+    W_f32 = W_row.to(tl.float32)
+    B_f32 = B_row.to(tl.float32)
+
+    # Calculate pointers for this row
+    row_X_ptr = X_ptr + row_idx * X_row_stride
+    row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+    row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+    row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+    # Load input data and convert to fp32 for numerical stability
+    X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+    X_f32 = X_row.to(tl.float32)
+
+    # Compute statistics in fp32 for numerical stability
+    mean = tl.sum(X_f32, axis=0) / n_cols
+    X_centered = X_f32 - mean
+    # Apply mask to variance calculation to exclude contributions from masked elements
+    X_centered_masked = tl.where(mask, X_centered, 0.0)
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
     rstd = rsqrt(var + eps)

-    tl.store(Mean_ptr, mean)
-    tl.store(RSTD_ptr, rstd)
+    # Store statistics (convert back to original dtype only once)
+    tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+    tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))

-    Y_row = Xmm * rstd * W_row + B_row
+    # Fused normalization and affine transformation
+    # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+    Y_f32 = X_centered * rstd * W_f32 + B_f32

-    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+    # Store output (single conversion back to original dtype)
+    tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)


 @triton.jit
 def _layer_norm_backward_kernel(
     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    stride_x,  # stride of each row in input
     W_ptr,  # pointer to weights, shape (n_cols,)
     Mean_ptr,  # pointer to mean, shape (n_rows,)
+    stride_mean,  # stride of each row in mean
     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    stride_rstd,  # stride of each row in rstd
     DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
-    DW_ptr,  # pointer to weights grad, shape (n_cols,)
-    DB_ptr,  # pointer to bias grad, shape (n_cols,)
-    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
-    stride_x,  # stride of each row in input
     stride_dx,  # stride of each row in input grad
+    DW_ptr,  # pointer to weights grad, shape (n_cols,)
     stride_dw,  # stride of each row in weights grad
+    DB_ptr,  # pointer to bias grad, shape (n_cols,)
     stride_db,  # stride of each row in bias grad
+    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
     stride_dy,  # stride of each row in output grad
     n_rows,
     n_cols,
     rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
-    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
     """
-    row_block_id = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
     row_start = row_block_id * rows_per_program
     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols

-    dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
     db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

-    X_ptr += row_start * stride_x
-    Mean_ptr += row_start
-    RSTD_ptr += row_start
-    DX_ptr += row_start * stride_dx
-    DY_ptr += row_start * stride_dy
+    # Pre-load weights once (same optimization as forward pass)
+    w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+    w_f32 = w.to(tl.float32)
+
+    # Calculate pointers for this specific row
+    row_X_ptr = X_ptr + row_start * stride_x
+    row_DX_ptr = DX_ptr + row_start * stride_dx
+    row_DY_ptr = DY_ptr + row_start * stride_dy
+    row_Mean_ptr = Mean_ptr + row_start
+    row_RSTD_ptr = RSTD_ptr + row_start

     for _ in range(row_start, row_end):
-        x = tl.load(X_ptr + cols, mask=mask, other=0.0)
-        w = tl.load(W_ptr + cols, mask=mask, other=0.0)
-        dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
-        mean = tl.load(Mean_ptr)
-        rstd = tl.load(RSTD_ptr)
-
-        x_hat = (x - mean) * rstd
-        wdy = w * dy
+        # Load data for this row
+        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+        mean = tl.load(row_Mean_ptr)
+        rstd = tl.load(row_RSTD_ptr)
+
+        # Convert to fp32 for numerical stability
+        x_f32 = x.to(tl.float32)
+        dy_f32 = dy.to(tl.float32)
+        mean_f32 = mean.to(tl.float32)
+        rstd_f32 = rstd.to(tl.float32)
+
+        # Compute backward pass for this row
+        x_hat = (x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
         c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
         c2 = tl.sum(wdy, axis=0) / n_cols
-        dx = (wdy - (x_hat * c1 + c2)) * rstd
-        tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)

-        dw_row += dy * x_hat
-        db_row += dy
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db

-        X_ptr += stride_x
-        Mean_ptr += 1
-        RSTD_ptr += 1
-        DX_ptr += stride_dx
-        DY_ptr += stride_dy
+        row_X_ptr += stride_x
+        row_DX_ptr += stride_dx
+        row_DY_ptr += stride_dy
+        row_Mean_ptr += stride_mean
+        row_RSTD_ptr += stride_rstd

-    tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
-    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)


 def layer_norm_forward(X, W, B, eps):
+    """
+    Args:
+        X: Input tensor of shape (..., hidden_size)
+        W: Weight tensor of shape (hidden_size,)
+        B: Bias tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tuple of (output, input, mean, rstd, block_size, num_warps)
+    """
     shape = X.shape
     dim = shape[-1]
     X = X.view(-1, dim)
     n_rows, n_cols = X.shape
+
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # Allocate output tensors
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+    # Validate input dimensions
     if X.shape[1] != W.shape[0]:
         raise ValueError(
             f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
@@ -159,7 +205,9 @@ def layer_norm_forward(X, W, B, eps):
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"

-    _layer_norm_forward_kernel[(n_rows,)](
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
+    _layer_norm_forward_kernel[grid](
         Y,
         Y.stride(0),
         X,
@@ -176,12 +224,25 @@ def layer_norm_forward(X, W, B, eps):
         eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
-        **kernel_args,  # XPU-specific optimization
+        **kernel_args,
     )
+
     return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps


 def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+    """
+    Args:
+        dY: Gradient of output
+        X: Input tensor
+        W: Weight tensor
+        B: Bias tensor
+        Mean: Pre-computed mean
+        RSTD: Pre-computed reciprocal standard deviation
+
+    Returns:
+        Tuple of (input_grad, weight_grad, bias_grad)
+    """
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -193,59 +254,52 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count

-    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
-    _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
+    # fp32 for numerical stability especially.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
-        raise RuntimeError(
-            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
-        )
-
+        raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
     rows_per_program = math.ceil(n_rows / sm_count)
     grid = (sm_count,)
-    triton_dtype = (
-        tl.float32
-        if X.dtype == torch.float32
-        else tl.bfloat16
-        if X.dtype == torch.bfloat16
-        else tl.float16
-        if X.dtype == torch.float16
-        else tl.float32  # fallback to float32 for other types
-    )

+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+
+    kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
-    kernel_args = {}
     if X.device.type == "xpu":
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})

+    # Launch kernel with one thread block per row for optimal performance
     _layer_norm_backward_kernel[grid](
         X,
+        X.stride(0),
         W,
         Mean,
+        Mean.stride(0),
         RSTD,
+        RSTD.stride(0),
         DX,
-        _DW,
-        _DB,
-        dY,
-        X.stride(0),
         DX.stride(0),
+        _DW,
         _DW.stride(0),
+        _DB,
         _DB.stride(0),
+        dY,
         dY.stride(0),
         n_rows,
         n_cols,
-        rows_per_program,
+        rows_per_program=rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
-        dtype=triton_dtype,
-        **kernel_args,  # XPU-specific optimization
+        **kernel_args,
     )

-    DW = _DW.sum(dim=0).to(W.dtype)
-    DB = _DB.sum(dim=0).to(W.dtype)
-
     DX = DX.view(*shape)
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
     return DX, DW, DB


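Taken together, layer_norm_forward now returns the saved statistics alongside the output, and layer_norm_backward accumulates the per-block _DW/_DB buffers in fp32, casting only on the final reduction. A minimal sketch of how the two functions compose, assuming a CUDA device with Triton available and the signatures shown in this hunk; this only illustrates how the pieces fit together, not how the library exposes them:

import torch
from liger_kernel.ops.layer_norm import layer_norm_backward, layer_norm_forward

torch.manual_seed(0)
device = "cuda"
hidden = 1024
x = torch.randn(4, 16, hidden, device=device, requires_grad=True)
w = torch.randn(hidden, device=device, requires_grad=True)
b = torch.randn(hidden, device=device, requires_grad=True)
eps = 1e-6

# Forward returns the output plus everything backward needs (flattened input, mean, rstd).
y, x_flat, mean, rstd, BLOCK_SIZE, num_warps = layer_norm_forward(x, w, b, eps)

# Reference path through PyTorch autograd.
y_ref = torch.nn.functional.layer_norm(x, (hidden,), w, b, eps)
dy = torch.randn_like(y_ref)
y_ref.backward(dy)

# Backward consumes the upstream gradient and the saved statistics.
dx, dw, db = layer_norm_backward(dy, x_flat, w, b, mean, rstd)

print(torch.allclose(y, y_ref, rtol=1e-5, atol=1e-5))
print(torch.allclose(dx, x.grad, rtol=1e-4, atol=1e-4))
print(torch.allclose(dw, w.grad, rtol=1e-3, atol=1e-3))
print(torch.allclose(db, b.grad, rtol=1e-3, atol=1e-3))
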
@@ -0,0 +1,225 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
+    # Split or unpack complex frequencies into real and imag parts
+    if freqs_cis.is_complex():
+        freqs_real = freqs_cis.real
+        freqs_imag = freqs_cis.imag
+    else:
+        # Already split: last dim should be 2*head_dim_half
+        if freqs_cis.shape[-1] == 2 * head_dim_half:
+            freqs_real = freqs_cis[..., :head_dim_half]
+            freqs_imag = freqs_cis[..., head_dim_half:]
+        else:
+            raise ValueError(
+                f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, expected last dim = {2 * head_dim_half}"
+            )
+
+    # Canonicalize to shape (seq_len, head_dim_half):
+    # 1) Ensure the last dimension is head_dim_half
+    if freqs_real.shape[-1] != head_dim_half:
+        raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
+    # 2) Flatten all leading dims to a single row dimension
+    freqs_real = freqs_real.reshape(-1, head_dim_half)
+    freqs_imag = freqs_imag.reshape(-1, head_dim_half)
+    # 3) If we have fewer rows than seq_len, allow broadcasting when single row
+    if freqs_real.shape[0] < seq_len:
+        if freqs_real.shape[0] == 1:
+            freqs_real = freqs_real.expand(seq_len, -1)
+            freqs_imag = freqs_imag.expand(seq_len, -1)
+        else:
+            raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
+    # 4) If we have more rows than seq_len (e.g., batch present), take the first seq_len rows
+    elif freqs_real.shape[0] > seq_len:
+        freqs_real = freqs_real[:seq_len]
+        freqs_imag = freqs_imag[:seq_len]
+
+    return freqs_real, freqs_imag
+
+
+def _maybe_to_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return t if t.dtype == dtype else t.to(dtype)
+
+
+def _maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
+    return t if t.is_contiguous() else t.contiguous()
+
+
+def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
+    # Choose compute dtype: use fp32 only when inputs are fp32; otherwise keep input dtype for performance
+    compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
+
+    # Make sure q/k share the same dtype before casting to compute dtype
+    if k.dtype != q.dtype:
+        k = k.to(q.dtype)
+
+    q = _maybe_contiguous(_maybe_to_dtype(q, compute_dtype))
+    k = _maybe_contiguous(_maybe_to_dtype(k, compute_dtype))
+    freqs_real = _maybe_contiguous(_maybe_to_dtype(freqs_real, compute_dtype))
+    freqs_imag = _maybe_contiguous(_maybe_to_dtype(freqs_imag, compute_dtype))
+    return q, k, freqs_real, freqs_imag
+
+
+@triton.jit
+def _llama4_rope_kernel(
+    q_ptr,
+    k_ptr,
+    freqs_real_ptr,
+    freqs_imag_ptr,
+    q_row_stride,
+    k_row_stride,
+    q_head_stride,
+    k_head_stride,
+    freqs_row_stride,
+    seq_len,
+    batch_size,
+    imag_sign,
+    head_dim_half: tl.constexpr,
+    n_q_heads: tl.constexpr,
+    n_k_heads: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    H100-optimized RoPE kernel with improved parallelization across heads and dimensions.
+    Grid: (batch*seq, head)
+    """
+    # 2D grid
+    pid_bs = tl.program_id(0)  # over batch*seq
+    pid_h = tl.program_id(1)  # over heads
+
+    batch_idx = pid_bs // seq_len
+    seq_idx = pid_bs % seq_len
+
+    # Bounds check
+    if batch_idx >= batch_size or seq_idx >= seq_len:
+        return
+
+    # Base pointers for this (batch, seq) position
+    base_offset = batch_idx * seq_len + seq_idx
+    q_base = q_ptr + base_offset * q_row_stride
+    k_base = k_ptr + base_offset * k_row_stride
+
+    # Tiling over dim/2
+    for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
+        d_indices = d_start + tl.arange(0, BLOCK_SIZE)
+        mask_d = d_indices < head_dim_half
+
+        # Load frequencies once per tile (freqs layout: [seq_len, head_dim_half])
+        freq_idx = d_indices
+        freqs_real = tl.load(freqs_real_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        freqs_imag = tl.load(freqs_imag_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        freqs_imag = freqs_imag * imag_sign
+
+        # Process one query head per program in pid_h
+        if pid_h < n_q_heads:
+            q_head_ptr = q_base + pid_h * q_head_stride
+            q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+            q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+            # Complex multiply with FMAs: (a+ib)*(c+id) = (a*c - b*d) + i(a*d + b*c)
+            new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
+            new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
+
+            tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
+            tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)
+
+        # Process one key head per program in pid_h
+        if pid_h < n_k_heads:
+            k_head_ptr = k_base + pid_h * k_head_stride
+            k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+            k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+            new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
+            new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
+
+            tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
+            tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)
+
+
+def _select_kernel_meta(head_dim_half: int):
+    # Heuristic tuning for block size and num_warps
+    if head_dim_half >= 256:
+        return 128, 8
+    if head_dim_half >= 96:
+        return 128, 4
+    if head_dim_half >= 48:
+        return 64, 4
+    if head_dim_half >= 24:
+        return 32, 2
+    return 16, 2
+
+
+def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: float = 1.0):
+    # Save original dtype for casting back
+    original_dtype = q.dtype
+
+    batch_size, seq_len, n_q_heads, head_dim = q.shape
+    _, _, n_k_heads, _ = k.shape
+    head_dim_half = head_dim // 2
+
+    # Prepare frequencies
+    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, seq_len, head_dim_half)
+
+    # Cast to appropriate dtype and make contiguous only when needed
+    q, k, freqs_real, freqs_imag = _cast_and_contiguous(q, k, freqs_real, freqs_imag)
+
+    # H100-optimized meta-params
+    if BLOCK_SIZE is None:
+        BLOCK_SIZE, num_warps = _select_kernel_meta(head_dim_half)
+    else:
+        # Provide a default num_warps if caller pins BLOCK_SIZE
+        _, num_warps = _select_kernel_meta(head_dim_half)
+
+    # 2D grid: one program per (batch, seq, head)
+    n_heads_max = max(n_q_heads, n_k_heads)
+    grid = (batch_size * seq_len, n_heads_max)
+
+    # Launch kernel
+    _llama4_rope_kernel[grid](
+        q,
+        k,
+        freqs_real,
+        freqs_imag,
+        q.stride(1),
+        k.stride(1),
+        q.stride(2),
+        k.stride(2),
+        freqs_real.stride(0),
+        seq_len,
+        batch_size,
+        imag_sign,
+        head_dim_half,
+        n_q_heads,
+        n_k_heads,
+        BLOCK_SIZE,
+        num_warps=num_warps,
+        num_stages=2,
+    )
+
+    # Cast back to original dtype only if it differs from compute dtype
+    if q.dtype != original_dtype:
+        q = q.to(original_dtype)
+    if k.dtype != original_dtype:
+        k = k.to(original_dtype)
+
+    return q, k
+
+
+class LigerLlama4RopeFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
+        q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
+        ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        return q_out, k_out
+
+    @staticmethod
+    def backward(ctx, dq, dk):
+        (freqs_cis,) = ctx.saved_tensors
+        BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
+        # Use imag_sign=-1.0 for conjugate without materializing a new tensor
+        dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
+        return dq_out, dk_out, None
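
The new kernel treats each head dimension as interleaved (real, imag) pairs: even lanes become x_real*f_real - x_imag*f_imag and odd lanes x_real*f_imag + x_imag*f_real, i.e. a complex multiply by freqs_cis, with imag_sign=-1.0 selecting the conjugate rotation used in backward. Below is an eager-mode reference for that transform plus a comparison call, assuming a CUDA device and a complex freqs_cis of shape (seq_len, head_dim // 2); the helper name, shapes, and frequency construction are illustrative, not part of the package:

import torch
from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction

def llama4_rope_reference(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq_len, n_heads, head_dim) with interleaved real/imag pairs along head_dim
    b, s, h, d = x.shape
    x_c = torch.view_as_complex(x.float().reshape(b, s, h, d // 2, 2))
    # Broadcast frequencies over batch and heads: (1, seq_len, 1, head_dim // 2)
    f = freqs_cis.reshape(1, s, 1, d // 2)
    out = torch.view_as_real(x_c * f).reshape(b, s, h, d)
    return out.to(x.dtype)

# Comparison against the Triton path (the kernel rotates q and k in place, so clone the inputs).
q = torch.randn(2, 8, 4, 64, device="cuda", dtype=torch.float32)
k = torch.randn(2, 8, 2, 64, device="cuda", dtype=torch.float32)
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 32, device="cuda", dtype=torch.float32) / 32))
angles = torch.outer(torch.arange(8, device="cuda", dtype=torch.float32), inv_freq)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64, (seq_len, head_dim // 2)

q_ref = llama4_rope_reference(q, freqs_cis)
q_out, k_out = LigerLlama4RopeFunction.apply(q.clone(), k.clone(), freqs_cis)
print(torch.allclose(q_out, q_ref, atol=1e-5))
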