liger-kernel 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/ops/dyt.py +0 -2
  5. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  6. liger_kernel/ops/geglu.py +1 -1
  7. liger_kernel/ops/multi_token_attention.py +207 -0
  8. liger_kernel/ops/rms_norm.py +265 -54
  9. liger_kernel/ops/softmax.py +201 -0
  10. liger_kernel/ops/sparsemax.py +62 -50
  11. liger_kernel/ops/swiglu.py +1 -1
  12. liger_kernel/transformers/__init__.py +3 -0
  13. liger_kernel/transformers/functional.py +62 -0
  14. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  15. liger_kernel/transformers/model/gemma.py +25 -8
  16. liger_kernel/transformers/model/gemma2.py +27 -8
  17. liger_kernel/transformers/model/gemma3.py +62 -98
  18. liger_kernel/transformers/model/glm4.py +16 -7
  19. liger_kernel/transformers/model/llama.py +25 -7
  20. liger_kernel/transformers/model/llama4.py +108 -0
  21. liger_kernel/transformers/model/llava.py +95 -124
  22. liger_kernel/transformers/model/mistral.py +13 -8
  23. liger_kernel/transformers/model/mixtral.py +16 -7
  24. liger_kernel/transformers/model/mllama.py +16 -7
  25. liger_kernel/transformers/model/olmo2.py +16 -7
  26. liger_kernel/transformers/model/paligemma.py +8 -1
  27. liger_kernel/transformers/model/phi3.py +25 -8
  28. liger_kernel/transformers/model/qwen2.py +24 -7
  29. liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
  30. liger_kernel/transformers/model/qwen2_vl.py +38 -100
  31. liger_kernel/transformers/model/qwen3.py +11 -3
  32. liger_kernel/transformers/model/qwen3_moe.py +10 -6
  33. liger_kernel/transformers/monkey_patch.py +304 -70
  34. liger_kernel/transformers/multi_token_attention.py +64 -0
  35. liger_kernel/transformers/rms_norm.py +40 -4
  36. liger_kernel/transformers/softmax.py +12 -0
  37. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
  38. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
  39. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  40. liger_kernel/transformers/gema3_rms.py +0 -8
  41. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  42. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  43. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py CHANGED
@@ -40,7 +40,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
-    c_row = geglu_a * b_row
+    c_row = geglu_a.cast(b_row.dtype) * b_row
     tl.store(c + col_offsets, c_row, mask=mask)
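The geglu fix above casts the activation back to b_row's dtype before the multiply; the GELU branch is typically computed in higher precision inside the kernel, and a mixed fp32 × bf16 product would otherwise be promoted to fp32 before the store. A minimal sketch of that promotion rule in plain PyTorch, for illustration only (not Liger code):

import torch

# Hypothetical values: the activation is held in fp32, the gate stays in the
# input dtype, as in a mixed-precision GEGLU forward pass.
geglu_a = torch.randn(8, dtype=torch.float32)
b_row = torch.randn(8, dtype=torch.bfloat16)

print((geglu_a * b_row).dtype)                  # torch.float32 (promoted)
print((geglu_a.to(b_row.dtype) * b_row).dtype)  # torch.bfloat16 (matches the output buffer)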
liger_kernel/ops/multi_token_attention.py ADDED
@@ -0,0 +1,207 @@
+ import torch
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+
+ from torch.nn.modules.utils import _pair
+
+ from liger_kernel.ops.softmax import _softmax_forward
+ from liger_kernel.ops.sparsemax import _sparsemax_backward
+ from liger_kernel.ops.sparsemax import _sparsemax_forward
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def _mask_fwd_kernel(
+     scores_ptr,
+     out_ptr,
+     stride_b,
+     stride_m,
+     stride_n,
+     L,
+     mask_val: tl.constexpr,
+     BLOCK: tl.constexpr,
+     num_warps: tl.constexpr,
+ ):
+     row_block = tl.program_id(0)
+     col_block = tl.program_id(1)
+     batch_id = tl.program_id(2)
+
+     row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
+     col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
+     in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
+
+     base = scores_ptr + batch_id * stride_b
+     offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
+     future = col_idx[None, :] > row_idx[:, None]
+     mask_load = in_bounds & ~future
+     out = tl.load(base + offs, mask=mask_load, other=mask_val, cache_modifier=".ca")
+     tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".cs")
+
+
+ @triton.jit
+ def _mask_bwd_kernel(
+     grad_in_ptr, out_ptr, stride_b, stride_m, stride_n, L, BLOCK: tl.constexpr, num_warps: tl.constexpr
+ ):
+     row_block = tl.program_id(0)
+     col_block = tl.program_id(1)
+     batch_id = tl.program_id(2)
+
+     row_idx = row_block * BLOCK + tl.arange(0, BLOCK)
+     col_idx = col_block * BLOCK + tl.arange(0, BLOCK)
+     in_bounds = (row_idx[:, None] < L) & (col_idx[None, :] < L)
+
+     base = grad_in_ptr + batch_id * stride_b
+     offs = row_idx[:, None] * stride_m + col_idx[None, :] * stride_n
+     grad_vals = tl.load(base + offs, mask=in_bounds, other=0.0, cache_modifier=".ca")
+
+     future = col_idx[None, :] > row_idx[:, None]
+     zero = tl.zeros(grad_vals.shape, dtype=grad_vals.dtype)
+     out = tl.where(future, zero, grad_vals)
+
+     tl.store(out_ptr + batch_id * stride_b + offs, out, mask=in_bounds, cache_modifier=".wb")
+
+
+ def _mask_inf_forward(scores: torch.Tensor) -> torch.Tensor:
+     *batch, L, _ = scores.shape
+     N = int(torch.prod(torch.tensor(batch))) if batch else 1
+     scores_f = scores.view(N, L, L)
+     out = torch.empty_like(scores_f)
+
+     sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
+     BLOCK_SIZE, num_warps = calculate_settings(L)
+     grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+     _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=-1e9, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+     return out.view(*batch, L, L)
+
+
+ def _mask_inf_backward(grad: torch.Tensor) -> torch.Tensor:
+     *batch, L, _ = grad.shape
+     N = int(torch.prod(torch.tensor(batch))) if batch else 1
+     grad_f = grad.view(N, L, L)
+     out = torch.empty_like(grad_f)
+
+     sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
+     BLOCK_SIZE, num_warps = calculate_settings(L)
+     grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+     _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+     return out.view(*batch, L, L)
+
+
+ def _mask_zero_forward(scores: torch.Tensor) -> torch.Tensor:
+     *batch, L, _ = scores.shape
+     N = int(torch.prod(torch.tensor(batch))) if batch else 1
+     scores_f = scores.view(N, L, L)
+     out = torch.empty_like(scores_f)
+
+     sb, sm, sn = scores_f.stride(0), scores_f.stride(1), scores_f.stride(2)
+     BLOCK_SIZE, num_warps = calculate_settings(L)
+     grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+     _mask_fwd_kernel[grid](scores_f, out, sb, sm, sn, L, mask_val=0.0, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+     return out.view(*batch, L, L)
+
+
+ def _mask_zero_backward(grad: torch.Tensor) -> torch.Tensor:
+     *batch, L, _ = grad.shape
+     N = int(torch.prod(torch.tensor(batch))) if batch else 1
+     grad_f = grad.view(N, L, L)
+     out = torch.empty_like(grad_f)
+
+     sb, sm, sn = grad_f.stride(0), grad_f.stride(1), grad_f.stride(2)
+     BLOCK_SIZE, num_warps = calculate_settings(L)
+     grid = (triton.cdiv(L, BLOCK_SIZE), triton.cdiv(L, BLOCK_SIZE), N)
+     _mask_bwd_kernel[grid](grad_f, out, sb, sm, sn, L, BLOCK=BLOCK_SIZE, num_warps=num_warps)
+     return out.view(*batch, L, L)
+
+
+ class LigerMultiTokenAttentionFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, scores, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, sparse=False):
+         scores_inf = _mask_inf_forward(scores)
+
+         out_flat_sparse = None
+         activation_output = None
+
+         ctx.sparse = sparse
+
+         if sparse:
+             if scores_inf.dtype != torch.float32:
+                 raise RuntimeError("Liger sparse multi-token attention currently only supports fp32 input scores")
+             probs_sparse, out_flat_sparse = _sparsemax_forward(scores_inf, dim=-1)
+             activation_output = probs_sparse
+             ctx.save_for_backward(scores_inf, activation_output, out_flat_sparse, weight, bias)
+             ctx.out_flat_sparse_saved = True
+         else:
+             probs_softmax, _, _, _ = _softmax_forward(scores_inf)
+             activation_output = probs_softmax
+             ctx.save_for_backward(scores_inf, activation_output, weight, bias)
+             ctx.out_flat_sparse_saved = False
+
+         out_conv = F.conv2d(
+             activation_output,
+             weight,
+             bias,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+         )
+
+         out = _mask_zero_forward(out_conv)
+
+         ctx.stride = _pair(stride)
+         ctx.padding = _pair(padding)
+         ctx.dilation = _pair(dilation)
+         ctx.groups = groups
+         ctx.dim = -1
+
+         return out
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_out):
+         if ctx.out_flat_sparse_saved:
+             scores_inf, activation_output, out_flat_sparse, weight, bias = ctx.saved_tensors
+         else:
+             scores_inf, activation_output, weight, bias = ctx.saved_tensors
+             out_flat_sparse = None
+
+         use_sparsemax = ctx.sparse
+         dim = ctx.dim
+         stride, padding, dilation, groups = (ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
+
+         grad_conv = _mask_zero_backward(grad_out)
+
+         grad_probs = F.conv_transpose2d(
+             grad_conv, weight, None, stride=stride, padding=padding, dilation=dilation, groups=groups
+         )
+
+         grad_weight = torch.nn.grad.conv2d_weight(
+             input=activation_output,
+             weight_size=weight.shape,
+             grad_output=grad_conv,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+         )
+         grad_bias = None
+         if bias is not None:
+             grad_bias = grad_conv.sum(dim=(0, 2, 3))
+
+         grad_scores_inf = None
+         if use_sparsemax:
+             if not ctx.out_flat_sparse_saved or out_flat_sparse is None:
+                 raise RuntimeError("Internal error: Sparse flag is set but sparse tensor was not saved.")
+             grad_scores_inf = _sparsemax_backward(grad_probs, out_flat_sparse, dim=dim)
+         else:
+             grad_probs_cont = grad_probs
+             probs_cont = activation_output
+             dot = (grad_probs_cont * probs_cont).sum(dim=-1, keepdim=True)
+             grad_scores_inf = probs_cont * (grad_probs_cont - dot)
+
+         grad_scores = _mask_inf_backward(grad_scores_inf)
+
+         return (grad_scores, grad_weight, grad_bias, None, None, None, None, None)
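For orientation, the new op masks future positions of the (L, L) score map to -1e9, applies softmax (or sparsemax when sparse=True), mixes the resulting attention maps with a conv2d, and re-applies the causal mask as zeros; the backward pass mirrors each step. Below is a hedged usage sketch of the autograd function as defined above; the shapes, device, and convolution hyperparameters are illustrative assumptions, not values prescribed by the package:

import torch

from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction

# Hypothetical shapes: batch of 2, 4 attention "channels", sequence length 64,
# and a 3x3 mixing kernel. padding=1 keeps the (L, L) score map the same size,
# so the causal re-masking stays aligned with the input grid.
B, C, L, k = 2, 4, 64, 3
scores = torch.randn(B, C, L, L, device="cuda", dtype=torch.float32, requires_grad=True)
weight = torch.randn(C, C, k, k, device="cuda", dtype=torch.float32, requires_grad=True)
bias = torch.randn(C, device="cuda", dtype=torch.float32, requires_grad=True)

# Positional args: scores, weight, bias, stride, padding, dilation, groups, sparse
out = LigerMultiTokenAttentionFunction.apply(scores, weight, bias, 1, 1, 1, 1, False)
out.sum().backward()  # gradients flow back to scores, weight, and bias
print(out.shape)      # torch.Size([2, 4, 64, 64])

The padding choice matters because the zero-masking helper views the convolution output as square (L, L) maps; a size-changing convolution would break that assumption.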
liger_kernel/ops/rms_norm.py CHANGED
@@ -194,6 +194,175 @@ def _rms_norm_backward_kernel(
     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
+ @triton.jit
+ def _block_rms_norm_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,
+     W_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
+     n_rows,
+     n_cols,
+     eps,
+     offset,
+     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+     BLOCK_SIZE: tl.constexpr,
+     BLOCK_ROW: tl.constexpr,
+ ):
+     """
+     y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N)
+
+     Reference:
+     1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+     2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+     3. https://arxiv.org/pdf/1910.07467
+     """
+
+     row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     row_mask = row_idx < n_rows
+     col_mask = col_offsets < n_cols
+
+     X_row = tl.load(
+         X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+         mask=row_mask[:, None] & col_mask[None, :],
+         other=0,
+     )
+     X_row_dtype = X_row.dtype
+     W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+     # On Llama, only rstd is computed in fp32
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(tl.float32)
+
+     # Gemma computes everything in fp32, then casts the output back to the original dtype
+     if casting_mode == _CASTING_MODE_GEMMA:
+         W_row = W_row.to(tl.float32)
+         X_row = X_row.to(tl.float32)
+
+     if casting_mode == _CASTING_MODE_NONE:
+         eps = eps.to(X_row_dtype)
+         offset = offset.to(X_row_dtype)
+
+     mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+     rstd = rsqrt(mean_square + eps)
+
+     # We can save time by caching rms with minimal memory overhead,
+     # because rms is much smaller than X_row: there is one value per row.
+     # On the compute side, it saves 4 operations (*, sum, /, sqrt) in the backward pass.
+     tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+     X_row = X_row * rstd[:, None]
+
+     # On Llama, the multiplication with the weight is done in the original dtype
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(X_row_dtype)
+
+     Y_row = X_row * (offset + W_row)[None, :]
+
+     if casting_mode == _CASTING_MODE_GEMMA:
+         Y_row = Y_row.to(X_row_dtype)
+
+     tl.store(
+         Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+         Y_row,
+         mask=row_mask[:, None] & col_mask[None, :],
+     )
+
+
+ @triton.jit
+ def _block_rms_norm_backward_kernel(
+     dY_ptr,
+     dY_row_stride,
+     dX_ptr,
+     dX_row_stride,
+     X_ptr,
+     X_row_stride,
+     X_dtype: tl.constexpr,
+     W_ptr,
+     W_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
+     dW_ptr,
+     dW_row_stride,
+     n_rows,
+     n_cols,
+     offset,
+     rows_per_program: tl.constexpr,
+     casting_mode: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     BLOCK_ROW: tl.constexpr,
+ ):
+     """
+     dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+     dw = sum(dy * (x / RMS)). summation over the BxT dimension
+     """
+
+     pid = tl.program_id(0).cast(tl.int64)
+     NUM_SMS = tl.num_programs(0)
+
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     col_mask = col_offsets < n_cols
+
+     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+     W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+     W_row = W_row + offset
+
+     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+         row_idx = start + tl.arange(0, BLOCK_ROW)
+         row_mask = row_idx < n_rows
+         dY_row = tl.load(
+             dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+             mask=row_mask[:, None] & col_mask[None, :],
+             other=0.0,
+         )
+         X_row = tl.load(
+             X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+             mask=row_mask[:, None] & col_mask[None, :],
+             other=0.0,
+         )
+
+         # Get cached rms
+         rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+         X_row = X_row.to(tl.float32)
+
+         # Different backward graphs for different casting modes
+         if casting_mode == _CASTING_MODE_LLAMA:
+             m = (dY_row * W_row[None, :]).to(tl.float32)
+
+         elif casting_mode == _CASTING_MODE_GEMMA:
+             dY_row = dY_row.to(tl.float32)
+             m = dY_row * W_row[None, :]
+         else:
+             m = dY_row * W_row[None, :]
+
+         dX_row = rstd_row[:, None] * m
+
+         dX_row += (rstd_row[:, None]) * (
+             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+         )
+
+         # calculate the gradient of W
+         if casting_mode == _CASTING_MODE_LLAMA:
+             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
+         else:
+             # here X_row is already in fp32 (see the previous if blocks)
+             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+         tl.store(
+             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+             dX_row,
+             mask=row_mask[:, None] & col_mask[None, :],
+         )
+
+     tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
+
 _str_to_casting_mode = {
     "llama": _CASTING_MODE_LLAMA.value,
     "gemma": _CASTING_MODE_GEMMA.value,
@@ -201,7 +370,7 @@ _str_to_casting_mode = {
 }
 
 
- def rms_norm_forward(X, W, eps, offset, casting_mode):
+ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     if not isinstance(casting_mode, int):
         assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
@@ -227,27 +396,49 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     kernel_args = {}
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
-     _rms_norm_forward_kernel[(n_rows,)](
-         Y,
-         Y.stride(0),
-         X,
-         X.stride(0),
-         W,
-         W.stride(0),
-         RSTD,
-         RSTD.stride(0),
-         n_cols,
-         eps,
-         offset,
-         casting_mode,
-         BLOCK_SIZE=BLOCK_SIZE,
-         num_warps=num_warps,
-         **kernel_args,  # XPU-specific optimization
-     )
+     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+         _rms_norm_forward_kernel[(n_rows,)](
+             Y,
+             Y.stride(0),
+             X,
+             X.stride(0),
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             n_cols,
+             eps,
+             offset,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
+     else:
+         BLOCK_ROW = 16
+         kernel_args["BLOCK_ROW"] = BLOCK_ROW
+         _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+             Y,
+             Y.stride(0),
+             X,
+             X.stride(0),
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             n_rows,
+             n_cols,
+             eps,
+             offset,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
- def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -277,29 +468,56 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
 
-     _rms_norm_backward_kernel[grid](
-         dY,
-         dY.stride(0),
-         dX,
-         dX.stride(0),
-         X,
-         X.stride(0),
-         torch_to_triton_dtype[X.dtype],
-         W,
-         W.stride(0),
-         RSTD,
-         RSTD.stride(0),
-         _dW,
-         _dW.stride(0),
-         n_rows,
-         n_cols,
-         offset,
-         rows_per_program,
-         casting_mode,
-         BLOCK_SIZE=BLOCK_SIZE,
-         num_warps=num_warps,
-         **kernel_args,  # XPU-specific optimization
-     )
+     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+         _rms_norm_backward_kernel[grid](
+             dY,
+             dY.stride(0),
+             dX,
+             dX.stride(0),
+             X,
+             X.stride(0),
+             torch_to_triton_dtype[X.dtype],
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             _dW,
+             _dW.stride(0),
+             n_rows,
+             n_cols,
+             offset,
+             rows_per_program,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
+     else:
+         BLOCK_ROW = 16
+         kernel_args["BLOCK_ROW"] = BLOCK_ROW
+         _block_rms_norm_backward_kernel[grid](
+             dY,
+             dY.stride(0),
+             dX,
+             dX.stride(0),
+             X,
+             X.stride(0),
+             torch_to_triton_dtype[X.dtype],
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             _dW,
+             _dW.stride(0),
+             n_rows,
+             n_cols,
+             offset,
+             rows_per_program,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
     dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
 
@@ -330,15 +548,16 @@ class LigerRMSNormFunction(torch.autograd.Function):
 
     @staticmethod
     @ensure_contiguous
-     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
+     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
         """
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
+         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
         ctx.in_place = in_place
+         ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
         ctx.save_for_backward(X, W, RSTD)
@@ -352,14 +571,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         X, W, RSTD = ctx.saved_tensors
         dX, dW = rms_norm_backward(
-             dY,
-             X,
-             W,
-             RSTD,
-             ctx.offset,
-             ctx.casting_mode,
-             ctx.BLOCK_SIZE,
-             ctx.num_warps,
-             ctx.in_place,
+             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
-         return dX, dW, None, None, None, None
+         return dX, dW, None, None, None, None, None
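Taken together, the rms_norm changes thread a new trailing row_mode argument from LigerRMSNormFunction.forward down to the kernel launchers: when row_mode is falsy, BLOCK_SIZE is at most 256, and there are at least 4096 * 8 rows, the new _block_rms_norm_* kernels process BLOCK_ROW = 16 rows per program instead of one. Below is a hedged usage sketch against the function as shown above; the shapes, dtype, and eps are illustrative assumptions:

import torch

from liger_kernel.ops.rms_norm import LigerRMSNormFunction

# Illustrative shapes only: many rows (B * T = 65536) and a small hidden size,
# so BLOCK_SIZE stays <= 256, which is the regime where the blocked kernel is
# selected when row_mode is left as None.
B, T, H = 32, 2048, 256
X = torch.randn(B, T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
W = torch.randn(H, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Positional args: X, W, eps, offset, casting_mode, in_place, row_mode
Y = LigerRMSNormFunction.apply(X, W, 1e-6, 0.0, "llama", False, None)
Y.sum().backward()  # backward now returns seven gradients (dX, dW, and five None placeholders)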