liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (55)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  6. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  7. liger_kernel/ops/dyt.py +111 -179
  8. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  9. liger_kernel/ops/geglu.py +1 -1
  10. liger_kernel/ops/grpo_loss.py +310 -0
  11. liger_kernel/ops/multi_token_attention.py +207 -0
  12. liger_kernel/ops/rms_norm.py +265 -54
  13. liger_kernel/ops/softmax.py +201 -0
  14. liger_kernel/ops/sparsemax.py +179 -0
  15. liger_kernel/ops/swiglu.py +1 -1
  16. liger_kernel/transformers/__init__.py +8 -0
  17. liger_kernel/transformers/dyt.py +5 -3
  18. liger_kernel/transformers/fsdp.py +55 -0
  19. liger_kernel/transformers/functional.py +70 -0
  20. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  21. liger_kernel/transformers/grpo_loss.py +98 -0
  22. liger_kernel/transformers/model/gemma.py +25 -16
  23. liger_kernel/transformers/model/gemma2.py +27 -14
  24. liger_kernel/transformers/model/gemma3.py +62 -106
  25. liger_kernel/transformers/model/glm4.py +16 -13
  26. liger_kernel/transformers/model/llama.py +81 -18
  27. liger_kernel/transformers/model/llama4.py +108 -0
  28. liger_kernel/transformers/model/llava.py +95 -132
  29. liger_kernel/transformers/model/mistral.py +13 -14
  30. liger_kernel/transformers/model/mixtral.py +16 -15
  31. liger_kernel/transformers/model/mllama.py +16 -14
  32. liger_kernel/transformers/model/olmo2.py +16 -13
  33. liger_kernel/transformers/model/paligemma.py +8 -9
  34. liger_kernel/transformers/model/phi3.py +25 -16
  35. liger_kernel/transformers/model/qwen2.py +24 -15
  36. liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
  37. liger_kernel/transformers/model/qwen2_vl.py +38 -106
  38. liger_kernel/transformers/model/qwen3.py +11 -9
  39. liger_kernel/transformers/model/qwen3_moe.py +132 -0
  40. liger_kernel/transformers/monkey_patch.py +424 -81
  41. liger_kernel/transformers/multi_token_attention.py +64 -0
  42. liger_kernel/transformers/rms_norm.py +40 -4
  43. liger_kernel/transformers/softmax.py +12 -0
  44. liger_kernel/transformers/sparsemax.py +16 -0
  45. liger_kernel/transformers/swiglu.py +21 -0
  46. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  47. liger_kernel/utils.py +11 -0
  48. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
  49. liger_kernel-0.6.0.dist-info/RECORD +97 -0
  50. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  51. liger_kernel/transformers/gema3_rms.py +0 -8
  52. liger_kernel-0.5.9.dist-info/RECORD +0 -84
  53. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  54. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  55. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
@@ -194,6 +194,175 @@ def _rms_norm_backward_kernel(
      tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


+ @triton.jit
+ def _block_rms_norm_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,
+     W_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
+     n_rows,
+     n_cols,
+     eps,
+     offset,
+     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+     BLOCK_SIZE: tl.constexpr,
+     BLOCK_ROW: tl.constexpr,
+ ):
+     """
+     y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N)
+
+     Reference:
+     1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+     2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+     3. https://arxiv.org/pdf/1910.07467
+     """
+
+     row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     row_mask = row_idx < n_rows
+     col_mask = col_offsets < n_cols
+
+     X_row = tl.load(
+         X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+         mask=row_mask[:, None] & col_mask[None, :],
+         other=0,
+     )
+     X_row_dtype = X_row.dtype
+     W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+     # On Llama, only rstd is computed in fp32
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(tl.float32)
+
+     # Gemma computes everything in fp32, then casts the output back to the original dtype
+     if casting_mode == _CASTING_MODE_GEMMA:
+         W_row = W_row.to(tl.float32)
+         X_row = X_row.to(tl.float32)
+
+     if casting_mode == _CASTING_MODE_NONE:
+         eps = eps.to(X_row_dtype)
+         offset = offset.to(X_row_dtype)
+
+     mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+     rstd = rsqrt(mean_square + eps)
+
+     # Caching rstd adds minimal memory overhead, because rstd is per-row and
+     # much smaller than X_row, while saving the backward pass 4 operations
+     # (*, sum, /, sqrt).
+     tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+     X_row = X_row * rstd[:, None]
+
+     # On Llama, the multiplication with the weight is done in the original dtype
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(X_row_dtype)
+
+     Y_row = X_row * (offset + W_row)[None, :]
+
+     if casting_mode == _CASTING_MODE_GEMMA:
+         Y_row = Y_row.to(X_row_dtype)
+
+     tl.store(
+         Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+         Y_row,
+         mask=row_mask[:, None] & col_mask[None, :],
+     )
+
+
+ @triton.jit
+ def _block_rms_norm_backward_kernel(
+     dY_ptr,
+     dY_row_stride,
+     dX_ptr,
+     dX_row_stride,
+     X_ptr,
+     X_row_stride,
+     X_dtype: tl.constexpr,
+     W_ptr,
+     W_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
+     dW_ptr,
+     dW_row_stride,
+     n_rows,
+     n_cols,
+     offset,
+     rows_per_program: tl.constexpr,
+     casting_mode: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     BLOCK_ROW: tl.constexpr,
+ ):
+     """
+     dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x].
+     * means element-wise multiplication, whereas dot means dot product.
+     dw = sum(dy * (x / RMS)); the summation is over the BxT dimension.
+     """
+
+     pid = tl.program_id(0).cast(tl.int64)
+     NUM_SMS = tl.num_programs(0)
+
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     col_mask = col_offsets < n_cols
+
+     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+     W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+     W_row = W_row + offset
+
+     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+         row_idx = start + tl.arange(0, BLOCK_ROW)
+         row_mask = row_idx < n_rows
+         dY_row = tl.load(
+             dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+             mask=row_mask[:, None] & col_mask[None, :],
+             other=0.0,
+         )
+         X_row = tl.load(
+             X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+             mask=row_mask[:, None] & col_mask[None, :],
+             other=0.0,
+         )
+
+         # Get cached rms
+         rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+         X_row = X_row.to(tl.float32)
+
+         # Different backward graphs for different casting modes
+         if casting_mode == _CASTING_MODE_LLAMA:
+             m = (dY_row * W_row[None, :]).to(tl.float32)
+         elif casting_mode == _CASTING_MODE_GEMMA:
+             dY_row = dY_row.to(tl.float32)
+             m = dY_row * W_row[None, :]
+         else:
+             m = dY_row * W_row[None, :]
+
+         dX_row = rstd_row[:, None] * m
+
+         dX_row += (rstd_row[:, None]) * (
+             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+         )
+
+         # calculate the gradient of W
+         if casting_mode == _CASTING_MODE_LLAMA:
+             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
+         else:
+             # here X_row is already in fp32 (see previous if block)
+             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+         tl.store(
+             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+             dX_row,
+             mask=row_mask[:, None] & col_mask[None, :],
+         )
+
+     tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
+
  _str_to_casting_mode = {
      "llama": _CASTING_MODE_LLAMA.value,
      "gemma": _CASTING_MODE_GEMMA.value,
@@ -201,7 +370,7 @@ _str_to_casting_mode = {
  }


- def rms_norm_forward(X, W, eps, offset, casting_mode):
+ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
      if not isinstance(casting_mode, int):
          assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
          casting_mode = _str_to_casting_mode[casting_mode]
@@ -227,27 +396,49 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
      kernel_args = {}
      if X.device.type == "xpu":
          kernel_args["grf_mode"] = "large"
-     _rms_norm_forward_kernel[(n_rows,)](
-         Y,
-         Y.stride(0),
-         X,
-         X.stride(0),
-         W,
-         W.stride(0),
-         RSTD,
-         RSTD.stride(0),
-         n_cols,
-         eps,
-         offset,
-         casting_mode,
-         BLOCK_SIZE=BLOCK_SIZE,
-         num_warps=num_warps,
-         **kernel_args,  # XPU-specific optimization
-     )
+     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+         _rms_norm_forward_kernel[(n_rows,)](
+             Y,
+             Y.stride(0),
+             X,
+             X.stride(0),
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             n_cols,
+             eps,
+             offset,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
+     else:
+         BLOCK_ROW = 16
+         kernel_args["BLOCK_ROW"] = BLOCK_ROW
+         _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+             Y,
+             Y.stride(0),
+             X,
+             X.stride(0),
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             n_rows,
+             n_cols,
+             eps,
+             offset,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
      return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


- def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
      shape = dY.shape
      dim = shape[-1]
      dY = dY.view(-1, dim)
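The dispatch above keeps the original one-row-per-program kernel for wide rows (BLOCK_SIZE > 256), small batches (n_rows < 32768), or an explicit row_mode, and otherwise launches the blocked kernel over 16 rows per program. A plain-Python sketch of the same predicate and the resulting grids (the function name is illustrative, not package code):

import math

def pick_forward_kernel(BLOCK_SIZE, n_rows, row_mode):
    # same predicate as in rms_norm_forward above: wide rows, small batches,
    # or an explicit row_mode keep the original one-row-per-program kernel
    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
        return "_rms_norm_forward_kernel", (n_rows,)
    BLOCK_ROW = 16  # the blocked kernel covers 16 rows per program
    return "_block_rms_norm_forward_kernel", (math.ceil(n_rows / BLOCK_ROW),)

print(pick_forward_kernel(1024, 4096, None))   # wide rows -> row kernel, grid (4096,)
print(pick_forward_kernel(128, 65536, None))   # many short rows -> blocked kernel, grid (4096,)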
@@ -277,29 +468,56 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
      if X.device.type == "xpu":
          kernel_args["grf_mode"] = "large"

-     _rms_norm_backward_kernel[grid](
-         dY,
-         dY.stride(0),
-         dX,
-         dX.stride(0),
-         X,
-         X.stride(0),
-         torch_to_triton_dtype[X.dtype],
-         W,
-         W.stride(0),
-         RSTD,
-         RSTD.stride(0),
-         _dW,
-         _dW.stride(0),
-         n_rows,
-         n_cols,
-         offset,
-         rows_per_program,
-         casting_mode,
-         BLOCK_SIZE=BLOCK_SIZE,
-         num_warps=num_warps,
-         **kernel_args,  # XPU-specific optimization
-     )
+     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+         _rms_norm_backward_kernel[grid](
+             dY,
+             dY.stride(0),
+             dX,
+             dX.stride(0),
+             X,
+             X.stride(0),
+             torch_to_triton_dtype[X.dtype],
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             _dW,
+             _dW.stride(0),
+             n_rows,
+             n_cols,
+             offset,
+             rows_per_program,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
+     else:
+         BLOCK_ROW = 16
+         kernel_args["BLOCK_ROW"] = BLOCK_ROW
+         _block_rms_norm_backward_kernel[grid](
+             dY,
+             dY.stride(0),
+             dX,
+             dX.stride(0),
+             X,
+             X.stride(0),
+             torch_to_triton_dtype[X.dtype],
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             _dW,
+             _dW.stride(0),
+             n_rows,
+             n_cols,
+             offset,
+             rows_per_program,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
      dX = dX.view(*shape)
      dW = _dW.sum(dim=0).to(W.dtype)

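Both backward paths have each program write one partial dW row, leaving the final reduction to the host (`dW = _dW.sum(dim=0)` above). A toy PyTorch illustration of that split-reduction pattern (simplified to one row at a time rather than blocks of BLOCK_ROW rows; not package code):

import torch

n_programs, hidden = 8, 64
row_grads = torch.randn(1024, hidden)  # stand-in for the per-row dW contributions

# each "program" accumulates a strided slice of the rows, as the kernel loop does
_dW = torch.stack([row_grads[pid::n_programs].sum(0) for pid in range(n_programs)])

# host-side reduction, mirroring `dW = _dW.sum(dim=0).to(W.dtype)` above
dW = _dW.sum(dim=0)
assert torch.allclose(dW, row_grads.sum(0), atol=1e-3)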
@@ -330,15 +548,16 @@ class LigerRMSNormFunction(torch.autograd.Function):

      @staticmethod
      @ensure_contiguous
-     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
+     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
          """
          X: (B, T, H) or (BxT, H)
          W: (H,)
          """
-         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
+         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
          ctx.offset = offset
          ctx.casting_mode = casting_mode
          ctx.in_place = in_place
+         ctx.row_mode = row_mode
          ctx.BLOCK_SIZE = BLOCK_SIZE
          ctx.num_warps = num_warps
          ctx.save_for_backward(X, W, RSTD)
@@ -352,14 +571,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
          """
          X, W, RSTD = ctx.saved_tensors
          dX, dW = rms_norm_backward(
-             dY,
-             X,
-             W,
-             RSTD,
-             ctx.offset,
-             ctx.casting_mode,
-             ctx.BLOCK_SIZE,
-             ctx.num_warps,
-             ctx.in_place,
+             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
          )
-         return dX, dW, None, None, None, None
+         return dX, dW, None, None, None, None, None
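Taken together, the autograd entry point gains one trailing argument. A usage sketch, assuming a CUDA device with Triton available (`apply` takes positional arguments only):

import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)

# positional args follow forward(): X, W, eps, offset, casting_mode, in_place, row_mode
y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", True, None)
y.backward(torch.ones_like(y))

Leaving row_mode as None defers to the size heuristic above, while a truthy row_mode pins the pre-0.6.0 one-row-per-program kernels.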
liger_kernel/ops/softmax.py
@@ -0,0 +1,201 @@
+ from typing import Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def _softmax_single_block_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+     mask = offs < n_cols
+
+     x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
+     m = tl.max(x, axis=0)
+     e = tl.exp(x - m)
+     d = tl.sum(e, axis=0)
+     y = e / d
+     tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")
+
+
+ @triton.jit
+ def _softmax_multi_block_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+
+     m = tl.float32(-float("inf"))
+     d = tl.float32(0.0)
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+         blk_max = tl.max(xblk, axis=0)
+         new_m = tl.max(m, blk_max)
+         d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
+         m = new_m
+
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+         yblk = tl.exp(xblk - m) / d
+         tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
+
+
+ @triton.jit
+ def _softmax_single_block_backward_kernel(
+     dy_ptr,
+     dy_stride,
+     y_ptr,
+     y_stride,
+     dx_ptr,
+     dx_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+     mask = offs < n_cols
+
+     dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
+     y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
+     dot = tl.sum(dy * y, axis=0)
+     dx = y * (dy - dot)
+     tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")
+
+
+ @triton.jit
+ def _softmax_multi_block_backward_kernel(
+     dy_ptr,
+     dy_stride,
+     y_ptr,
+     y_stride,
+     dx_ptr,
+     dx_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+     acc = tl.float32(0.0)
+
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+         y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+         acc += tl.sum(dy_blk * y_blk, axis=0)
+
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+         y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+         dx_blk = y_blk * (dy_blk - acc)
+         tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
+
+
+ def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
+     *batch, n_cols = x.shape
+     x2d = x.contiguous().view(-1, n_cols)
+     n_rows = x2d.shape[0]
+
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+     y2d = torch.empty_like(x2d)
+
+     if n_cols <= BLOCK_SIZE:
+         _softmax_single_block_forward_kernel[(n_rows,)](
+             y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+         )
+         multi_block_launch = False
+     else:
+         _softmax_multi_block_forward_kernel[(n_rows,)](
+             y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+         )
+         multi_block_launch = True
+
+     return y2d.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
+
+
+ def _softmax_backward(
+     dy: torch.Tensor,
+     y: torch.Tensor,
+     BLOCK_SIZE: int,
+     num_warps: int,
+     multi_block_launch: bool,
+ ) -> torch.Tensor:
+     *batch, n_cols = dy.shape
+     dy2d = dy.contiguous().view(-1, n_cols)
+     y2d = y.contiguous().view(-1, n_cols)
+     n_rows = dy2d.shape[0]
+     dx2d = torch.empty_like(dy2d)
+
+     if not multi_block_launch and n_cols <= BLOCK_SIZE:
+         _softmax_single_block_backward_kernel[(n_rows,)](
+             dy2d,
+             dy2d.stride(0),
+             y2d,
+             y2d.stride(0),
+             dx2d,
+             dx2d.stride(0),
+             n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+     else:
+         _softmax_multi_block_backward_kernel[(n_rows,)](
+             dy2d,
+             dy2d.stride(0),
+             y2d,
+             y2d.stride(0),
+             dx2d,
+             dx2d.stride(0),
+             n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+
+     return dx2d.view(*batch, n_cols)
+
+
+ class LigerSoftmaxFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, input_: torch.Tensor):
+         y, BLOCK_SIZE, num_warps, multi_block_launch = _softmax_forward(input_)
+         ctx.save_for_backward(y)
+         ctx.BLOCK_SIZE = BLOCK_SIZE
+         ctx.num_warps = num_warps
+         ctx.multi_block_launch = multi_block_launch
+         return y
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output):
+         (y,) = ctx.saved_tensors
+         dx = _softmax_backward(
+             grad_output,
+             y,
+             ctx.BLOCK_SIZE,
+             ctx.num_warps,
+             ctx.multi_block_launch,
+         )
+         return dx
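The new softmax is exposed through LigerSoftmaxFunction, which saves the forward output and reuses it in the backward (dx = y * (dy - sum(dy * y))); rows longer than one block fall back to the two-pass online kernels. A usage sketch, assuming a CUDA device with Triton available:

import torch
from liger_kernel.ops.softmax import LigerSoftmaxFunction

x = torch.randn(8, 4096, device="cuda", requires_grad=True)

y = LigerSoftmaxFunction.apply(x)  # softmax over the last dimension
assert torch.allclose(y, torch.softmax(x, dim=-1), atol=1e-6)

# backward reuses the saved y: dx = y * (dy - sum(dy * y, dim=-1))
y.backward(torch.randn_like(y))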