liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/layer_norm.py

@@ -8,8 +8,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -43,118 +44,171 @@ def _layer_norm_forward_kernel(
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    Y_ptr += row_idx * Y_row_stride
-    X_ptr += row_idx * X_row_stride
-    Mean_ptr += row_idx * Mean_row_stride
-    RSTD_ptr += row_idx * RSTD_row_stride
-
-    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
-    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
-
-    mean = tl.sum(X_row, axis=0) / n_cols
-    Xmm = tl.where(mask, X_row - mean, 0)
-    var = tl.sum(Xmm * Xmm, axis=0) / n_cols
+    # Pre-load weights and bias in fp32 to avoid repeated conversions
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+    W_f32 = W_row.to(tl.float32)
+    B_f32 = B_row.to(tl.float32)
+
+    # Calculate pointers for this row
+    row_X_ptr = X_ptr + row_idx * X_row_stride
+    row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+    row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+    row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+    # Load input data and convert to fp32 for numerical stability
+    X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+    X_f32 = X_row.to(tl.float32)
+
+    # Compute statistics in fp32 for numerical stability
+    mean = tl.sum(X_f32, axis=0) / n_cols
+    X_centered = X_f32 - mean
+    # Apply mask to variance calculation to exclude contributions from masked elements
+    X_centered_masked = tl.where(mask, X_centered, 0.0)
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
     rstd = rsqrt(var + eps)
 
-    tl.store(Mean_ptr, mean)
-    tl.store(RSTD_ptr, rstd)
+    # Store statistics (convert back to original dtype only once)
+    tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+    tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
 
-    Y_row = Xmm * rstd * W_row + B_row
+    # Fused normalization and affine transformation
+    # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+    Y_f32 = X_centered * rstd * W_f32 + B_f32
 
-    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+    # Store output (single conversion back to original dtype)
+    tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
 
 
 @triton.jit
 def _layer_norm_backward_kernel(
     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    stride_x,  # stride of each row in input
     W_ptr,  # pointer to weights, shape (n_cols,)
     Mean_ptr,  # pointer to mean, shape (n_rows,)
+    stride_mean,  # stride of each row in mean
     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    stride_rstd,  # stride of each row in rstd
     DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
-    DW_ptr,  # pointer to weights grad, shape (n_cols,)
-    DB_ptr,  # pointer to bias grad, shape (n_cols,)
-    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
-    stride_x,  # stride of each row in input
     stride_dx,  # stride of each row in input grad
+    DW_ptr,  # pointer to weights grad, shape (n_cols,)
     stride_dw,  # stride of each row in weights grad
+    DB_ptr,  # pointer to bias grad, shape (n_cols,)
     stride_db,  # stride of each row in bias grad
+    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
     stride_dy,  # stride of each row in output grad
     n_rows,
     n_cols,
     rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
-    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
     """
-    row_block_id = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
     row_start = row_block_id * rows_per_program
     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
-    dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
     db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-    X_ptr += row_start * stride_x
-    Mean_ptr += row_start
-    RSTD_ptr += row_start
-    DX_ptr += row_start * stride_dx
-    DY_ptr += row_start * stride_dy
+    # Pre-load weights once (same optimization as forward pass)
+    w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+    w_f32 = w.to(tl.float32)
+
+    # Calculate pointers for this specific row
+    row_X_ptr = X_ptr + row_start * stride_x
+    row_DX_ptr = DX_ptr + row_start * stride_dx
+    row_DY_ptr = DY_ptr + row_start * stride_dy
+    row_Mean_ptr = Mean_ptr + row_start
+    row_RSTD_ptr = RSTD_ptr + row_start
 
     for _ in range(row_start, row_end):
-        x = tl.load(X_ptr + cols, mask=mask, other=0.0)
-        w = tl.load(W_ptr + cols, mask=mask, other=0.0)
-        dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
-        mean = tl.load(Mean_ptr)
-        rstd = tl.load(RSTD_ptr)
-
-        x_hat = (x - mean) * rstd
-        wdy = w * dy
+        # Load data for this row
+        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+        mean = tl.load(row_Mean_ptr)
+        rstd = tl.load(row_RSTD_ptr)
+
+        # Convert to fp32 for numerical stability
+        x_f32 = x.to(tl.float32)
+        dy_f32 = dy.to(tl.float32)
+        mean_f32 = mean.to(tl.float32)
+        rstd_f32 = rstd.to(tl.float32)
+
+        # Compute backward pass for this row
+        x_hat = (x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
         c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
         c2 = tl.sum(wdy, axis=0) / n_cols
-        dx = (wdy - (x_hat * c1 + c2)) * rstd
-        tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
 
-        dw_row += dy * x_hat
-        db_row += dy
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)
 
-        X_ptr += stride_x
-        Mean_ptr += 1
-        RSTD_ptr += 1
-        DX_ptr += stride_dx
-        DY_ptr += stride_dy
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db
 
-    tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
-    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
+        row_X_ptr += stride_x
+        row_DX_ptr += stride_dx
+        row_DY_ptr += stride_dy
+        row_Mean_ptr += stride_mean
+        row_RSTD_ptr += stride_rstd
+
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
+    """
+    Args:
+        X: Input tensor of shape (..., hidden_size)
+        W: Weight tensor of shape (hidden_size,)
+        B: Bias tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tuple of (output, input, mean, rstd, block_size, num_warps)
+    """
     shape = X.shape
     dim = shape[-1]
     X = X.view(-1, dim)
     n_rows, n_cols = X.shape
+
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # Allocate output tensors
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+    # Validate input dimensions
     if X.shape[1] != W.shape[0]:
         raise ValueError(
             f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
            f"must match weight size (W.shape[0]={W.shape[0]})"
         )
 
-    _layer_norm_forward_kernel[(n_rows,)](
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
+    _layer_norm_forward_kernel[grid](
         Y,
         Y.stride(0),
         X,
@@ -171,11 +225,25 @@ def layer_norm_forward(X, W, B, eps):
         eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,
     )
+
     return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 
 
 def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+    """
+    Args:
+        dY: Gradient of output
+        X: Input tensor
+        W: Weight tensor
+        B: Bias tensor
+        Mean: Pre-computed mean
+        RSTD: Pre-computed reciprocal standard deviation
+
+    Returns:
+        Tuple of (input_grad, weight_grad, bias_grad)
+    """
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -185,54 +253,54 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     if X.device.type == "cuda":
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
-    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
-    _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
+    # fp32 for numerical stability especially.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
 
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
-        raise RuntimeError(
-            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
-        )
-
+        raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
     rows_per_program = math.ceil(n_rows / sm_count)
     grid = (sm_count,)
-    triton_dtype = (
-        tl.float32
-        if X.dtype == torch.float32
-        else tl.bfloat16
-        if X.dtype == torch.bfloat16
-        else tl.float16
-        if X.dtype == torch.float16
-        else tl.float32  # fallback to float32 for other types
-    )
+
+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+
+    kernel_args = {"num_warps": num_warps}
+    # XPU-specific optimization
+    if X.device.type == "xpu":
+        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+
+    # Launch kernel with one thread block per row for optimal performance
     _layer_norm_backward_kernel[grid](
         X,
+        X.stride(0),
         W,
         Mean,
+        Mean.stride(0),
         RSTD,
+        RSTD.stride(0),
         DX,
-        _DW,
-        _DB,
-        dY,
-        X.stride(0),
         DX.stride(0),
+        _DW,
         _DW.stride(0),
+        _DB,
         _DB.stride(0),
+        dY,
         dY.stride(0),
         n_rows,
         n_cols,
-        rows_per_program,
+        rows_per_program=rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
-        dtype=triton_dtype,
+        **kernel_args,
     )
 
-    DW = _DW.sum(dim=0).to(W.dtype)
-    DB = _DB.sum(dim=0).to(W.dtype)
-
     DX = DX.view(*shape)
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
     return DX, DW, DB
 
 
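For orientation, a minimal host-side usage sketch (not part of the diff) of the two entry points touched above. It assumes liger-kernel-nightly is installed on a machine with a Triton-capable CUDA GPU and relies only on the layer_norm_forward / layer_norm_backward signatures visible in these hunks; shapes, dtypes, and eps are illustrative.

import torch

from liger_kernel.ops.layer_norm import layer_norm_backward
from liger_kernel.ops.layer_norm import layer_norm_forward

hidden = 2048
x = torch.randn(4, 128, hidden, device="cuda", dtype=torch.bfloat16)
w = torch.ones(hidden, device="cuda", dtype=torch.bfloat16)
b = torch.zeros(hidden, device="cuda", dtype=torch.bfloat16)

# Forward: statistics are computed in fp32 inside the kernel; Y, Mean, and RSTD keep the input dtype.
y, x_flat, mean, rstd, block_size, num_warps = layer_norm_forward(x, w, b, eps=1e-6)

# Backward: per-program partial dW/dB are accumulated in fp32 buffers and reduced on the host.
dy = torch.randn_like(y)
dx, dw, db = layer_norm_backward(dy, x_flat, w, b, mean, rstd)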
liger_kernel/ops/llama4_rope.py (new file)

@@ -0,0 +1,225 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
+    # Split or unpack complex frequencies into real and imag parts
+    if freqs_cis.is_complex():
+        freqs_real = freqs_cis.real
+        freqs_imag = freqs_cis.imag
+    else:
+        # Already split: last dim should be 2*head_dim_half
+        if freqs_cis.shape[-1] == 2 * head_dim_half:
+            freqs_real = freqs_cis[..., :head_dim_half]
+            freqs_imag = freqs_cis[..., head_dim_half:]
+        else:
+            raise ValueError(
+                f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, expected last dim = {2 * head_dim_half}"
+            )
+
+    # Canonicalize to shape (seq_len, head_dim_half):
+    # 1) Ensure the last dimension is head_dim_half
+    if freqs_real.shape[-1] != head_dim_half:
+        raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
+    # 2) Flatten all leading dims to a single row dimension
+    freqs_real = freqs_real.reshape(-1, head_dim_half)
+    freqs_imag = freqs_imag.reshape(-1, head_dim_half)
+    # 3) If we have fewer rows than seq_len, allow broadcasting when single row
+    if freqs_real.shape[0] < seq_len:
+        if freqs_real.shape[0] == 1:
+            freqs_real = freqs_real.expand(seq_len, -1)
+            freqs_imag = freqs_imag.expand(seq_len, -1)
+        else:
+            raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
+    # 4) If we have more rows than seq_len (e.g., batch present), take the first seq_len rows
+    elif freqs_real.shape[0] > seq_len:
+        freqs_real = freqs_real[:seq_len]
+        freqs_imag = freqs_imag[:seq_len]
+
+    return freqs_real, freqs_imag
+
+
+def _maybe_to_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return t if t.dtype == dtype else t.to(dtype)
+
+
+def _maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
+    return t if t.is_contiguous() else t.contiguous()
+
+
+def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
+    # Choose compute dtype: use fp32 only when inputs are fp32; otherwise keep input dtype for performance
+    compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
+
+    # Make sure q/k share the same dtype before casting to compute dtype
+    if k.dtype != q.dtype:
+        k = k.to(q.dtype)
+
+    q = _maybe_contiguous(_maybe_to_dtype(q, compute_dtype))
+    k = _maybe_contiguous(_maybe_to_dtype(k, compute_dtype))
+    freqs_real = _maybe_contiguous(_maybe_to_dtype(freqs_real, compute_dtype))
+    freqs_imag = _maybe_contiguous(_maybe_to_dtype(freqs_imag, compute_dtype))
+    return q, k, freqs_real, freqs_imag
+
+
+@triton.jit
+def _llama4_rope_kernel(
+    q_ptr,
+    k_ptr,
+    freqs_real_ptr,
+    freqs_imag_ptr,
+    q_row_stride,
+    k_row_stride,
+    q_head_stride,
+    k_head_stride,
+    freqs_row_stride,
+    seq_len,
+    batch_size,
+    imag_sign,
+    head_dim_half: tl.constexpr,
+    n_q_heads: tl.constexpr,
+    n_k_heads: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    H100-optimized RoPE kernel with improved parallelization across heads and dimensions.
+    Grid: (batch*seq, head)
+    """
+    # 2D grid
+    pid_bs = tl.program_id(0)  # over batch*seq
+    pid_h = tl.program_id(1)  # over heads
+
+    batch_idx = pid_bs // seq_len
+    seq_idx = pid_bs % seq_len
+
+    # Bounds check
+    if batch_idx >= batch_size or seq_idx >= seq_len:
+        return
+
+    # Base pointers for this (batch, seq) position
+    base_offset = batch_idx * seq_len + seq_idx
+    q_base = q_ptr + base_offset * q_row_stride
+    k_base = k_ptr + base_offset * k_row_stride
+
+    # Tiling over dim/2
+    for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
+        d_indices = d_start + tl.arange(0, BLOCK_SIZE)
+        mask_d = d_indices < head_dim_half
+
+        # Load frequencies once per tile (freqs layout: [seq_len, head_dim_half])
+        freq_idx = d_indices
+        freqs_real = tl.load(freqs_real_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        freqs_imag = tl.load(freqs_imag_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        freqs_imag = freqs_imag * imag_sign
+
+        # Process one query head per program in pid_h
+        if pid_h < n_q_heads:
+            q_head_ptr = q_base + pid_h * q_head_stride
+            q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+            q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+            # Complex multiply with FMAs: (a+ib)*(c+id) = (a*c - b*d) + i(a*d + b*c)
+            new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
+            new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
+
+            tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
+            tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)
+
+        # Process one key head per program in pid_h
+        if pid_h < n_k_heads:
+            k_head_ptr = k_base + pid_h * k_head_stride
+            k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+            k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+            new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
+            new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
+
+            tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
+            tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)
+
+
+def _select_kernel_meta(head_dim_half: int):
+    # Heuristic tuning for block size and num_warps
+    if head_dim_half >= 256:
+        return 128, 8
+    if head_dim_half >= 96:
+        return 128, 4
+    if head_dim_half >= 48:
+        return 64, 4
+    if head_dim_half >= 24:
+        return 32, 2
+    return 16, 2
+
+
+def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: float = 1.0):
+    # Save original dtype for casting back
+    original_dtype = q.dtype
+
+    batch_size, seq_len, n_q_heads, head_dim = q.shape
+    _, _, n_k_heads, _ = k.shape
+    head_dim_half = head_dim // 2
+
+    # Prepare frequencies
+    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, seq_len, head_dim_half)
+
+    # Cast to appropriate dtype and make contiguous only when needed
+    q, k, freqs_real, freqs_imag = _cast_and_contiguous(q, k, freqs_real, freqs_imag)
+
+    # H100-optimized meta-params
+    if BLOCK_SIZE is None:
+        BLOCK_SIZE, num_warps = _select_kernel_meta(head_dim_half)
+    else:
+        # Provide a default num_warps if caller pins BLOCK_SIZE
+        _, num_warps = _select_kernel_meta(head_dim_half)
+
+    # 2D grid: one program per (batch, seq, head)
+    n_heads_max = max(n_q_heads, n_k_heads)
+    grid = (batch_size * seq_len, n_heads_max)
+
+    # Launch kernel
+    _llama4_rope_kernel[grid](
+        q,
+        k,
+        freqs_real,
+        freqs_imag,
+        q.stride(1),
+        k.stride(1),
+        q.stride(2),
+        k.stride(2),
+        freqs_real.stride(0),
+        seq_len,
+        batch_size,
+        imag_sign,
+        head_dim_half,
+        n_q_heads,
+        n_k_heads,
+        BLOCK_SIZE,
+        num_warps=num_warps,
+        num_stages=2,
+    )
+
+    # Cast back to original dtype only if it differs from compute dtype
+    if q.dtype != original_dtype:
+        q = q.to(original_dtype)
+    if k.dtype != original_dtype:
+        k = k.to(original_dtype)
+
+    return q, k
+
+
+class LigerLlama4RopeFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
+        q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
+        ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        return q_out, k_out
+
+    @staticmethod
+    def backward(ctx, dq, dk):
+        (freqs_cis,) = ctx.saved_tensors
+        BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
+        # Use imag_sign=-1.0 for conjugate without materializing a new tensor
+        dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
+        return dq_out, dk_out, None
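
For orientation, a minimal usage sketch (not part of the diff) of the new autograd entry point. It assumes a Triton-capable CUDA GPU; shapes are illustrative, and freqs_cis is built here as a complex tensor of shape (seq_len, head_dim // 2), one of the layouts _prepare_freqs accepts. Note that for non-fp32 inputs the kernel rotates q and k in place.

import torch

from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction

batch, seq_len, n_q_heads, n_k_heads, head_dim = 2, 16, 8, 2, 128

q = torch.randn(batch, seq_len, n_q_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seq_len, n_k_heads, head_dim, device="cuda", dtype=torch.bfloat16)

# Complex rotation factors: one angle per (position, adjacent channel pair).
angles = torch.rand(seq_len, head_dim // 2, device="cuda", dtype=torch.float32)
freqs_cis = torch.polar(torch.ones_like(angles), angles)

# q/k are interpreted as interleaved (real, imag) pairs along head_dim.
q_rot, k_rot = LigerLlama4RopeFunction.apply(q, k, freqs_cis)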