liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +25 -9
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +124 -64
  17. liger_kernel/ops/dyt.py +115 -180
  18. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  19. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  20. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  21. liger_kernel/ops/geglu.py +3 -2
  22. liger_kernel/ops/group_norm.py +2 -1
  23. liger_kernel/ops/grpo_loss.py +312 -0
  24. liger_kernel/ops/jsd.py +2 -1
  25. liger_kernel/ops/kl_div.py +13 -6
  26. liger_kernel/ops/layer_norm.py +146 -78
  27. liger_kernel/ops/llama4_rope.py +225 -0
  28. liger_kernel/ops/multi_token_attention.py +207 -0
  29. liger_kernel/ops/poly_norm.py +390 -0
  30. liger_kernel/ops/rms_norm.py +283 -56
  31. liger_kernel/ops/rope.py +1 -1
  32. liger_kernel/ops/softmax.py +201 -0
  33. liger_kernel/ops/sparsemax.py +179 -0
  34. liger_kernel/ops/swiglu.py +1 -1
  35. liger_kernel/ops/tiled_mlp.py +136 -0
  36. liger_kernel/ops/utils.py +2 -0
  37. liger_kernel/transformers/__init__.py +205 -19
  38. liger_kernel/transformers/cross_entropy.py +9 -4
  39. liger_kernel/transformers/dyt.py +6 -4
  40. liger_kernel/transformers/experimental/__init__.py +5 -0
  41. liger_kernel/transformers/experimental/embedding.py +1 -1
  42. liger_kernel/transformers/fsdp.py +55 -0
  43. liger_kernel/transformers/functional.py +122 -20
  44. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  45. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  46. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  47. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  48. liger_kernel/transformers/geglu.py +1 -1
  49. liger_kernel/transformers/group_norm.py +1 -1
  50. liger_kernel/transformers/grpo_loss.py +153 -0
  51. liger_kernel/transformers/jsd.py +1 -1
  52. liger_kernel/transformers/kl_div.py +1 -1
  53. liger_kernel/transformers/layer_norm.py +1 -1
  54. liger_kernel/transformers/llama4_rope.py +93 -0
  55. liger_kernel/transformers/model/falcon_h1.py +122 -0
  56. liger_kernel/transformers/model/gemma.py +50 -25
  57. liger_kernel/transformers/model/gemma2.py +55 -23
  58. liger_kernel/transformers/model/gemma3.py +117 -120
  59. liger_kernel/transformers/model/glm4.py +141 -0
  60. liger_kernel/transformers/model/glm4v.py +163 -0
  61. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  62. liger_kernel/transformers/model/gpt_oss.py +211 -0
  63. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  64. liger_kernel/transformers/model/internvl.py +157 -0
  65. liger_kernel/transformers/model/llama.py +102 -25
  66. liger_kernel/transformers/model/llama4.py +121 -0
  67. liger_kernel/transformers/model/llava.py +111 -136
  68. liger_kernel/transformers/model/loss_utils.py +50 -12
  69. liger_kernel/transformers/model/mistral.py +36 -23
  70. liger_kernel/transformers/model/mixtral.py +45 -25
  71. liger_kernel/transformers/model/mllama.py +39 -22
  72. liger_kernel/transformers/model/olmo2.py +40 -20
  73. liger_kernel/transformers/model/olmo3.py +142 -0
  74. liger_kernel/transformers/model/output_classes.py +147 -0
  75. liger_kernel/transformers/model/paligemma.py +50 -14
  76. liger_kernel/transformers/model/phi3.py +47 -177
  77. liger_kernel/transformers/model/qwen2.py +48 -21
  78. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  79. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  80. liger_kernel/transformers/model/qwen3.py +136 -0
  81. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  82. liger_kernel/transformers/model/qwen3_next.py +146 -0
  83. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  84. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  85. liger_kernel/transformers/model/smollm3.py +199 -0
  86. liger_kernel/transformers/model/smolvlm.py +158 -0
  87. liger_kernel/transformers/monkey_patch.py +1678 -160
  88. liger_kernel/transformers/multi_token_attention.py +64 -0
  89. liger_kernel/transformers/poly_norm.py +42 -0
  90. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  91. liger_kernel/transformers/rms_norm.py +48 -5
  92. liger_kernel/transformers/rope.py +45 -1
  93. liger_kernel/transformers/softmax.py +12 -0
  94. liger_kernel/transformers/sparsemax.py +16 -0
  95. liger_kernel/transformers/swiglu.py +39 -1
  96. liger_kernel/transformers/tiled_mlp.py +133 -0
  97. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  98. liger_kernel/transformers/tvd.py +1 -1
  99. liger_kernel/utils.py +36 -0
  100. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
  101. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  102. liger_kernel/transformers/gema3_rms.py +0 -8
  103. liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
  104. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py CHANGED
@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
  from liger_kernel.ops.utils import torch_to_triton_dtype
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
      try:
          # typical import path with dispatch available
          from triton.language.extra.libdevice import rsqrt
@@ -63,7 +65,7 @@ def _rms_norm_forward_kernel(
      3. https://arxiv.org/pdf/1910.07467
      """

-     row_idx = tl.program_id(0)
+     row_idx = tl.program_id(0).to(tl.int64)
      col_offsets = tl.arange(0, BLOCK_SIZE)
      mask = col_offsets < n_cols

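Note: `tl.program_id` returns an int32, so without the `.to(tl.int64)` cast the element offset `row_idx * X_row_stride` can overflow past 2^31 - 1 once the flattened batch gets large. A back-of-envelope check in plain Python (our own illustration, not part of the wheel):

    # Largest row index whose element offset still fits in int32, assuming a
    # contiguous (n_rows, hidden_size) activation addressed as row * stride.
    hidden_size = 8192                # hypothetical hidden dim
    int32_max = 2**31 - 1
    print(int32_max // hidden_size)   # 262143: ~256K rows (batch * seq_len) is enough to overflow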
@@ -137,7 +139,7 @@ def _rms_norm_backward_kernel(
      dw = sum(dy * (x / RMS)). summation over BxT dimension
      """

-     row_block_id = tl.program_id(0)
+     row_block_id = tl.program_id(0).to(tl.int64)
      row_start = row_block_id * rows_per_program
      row_end = min((row_block_id + 1) * rows_per_program, n_rows)
      col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -194,6 +196,176 @@ def _rms_norm_backward_kernel(
      tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


+ @triton.jit
+ def _block_rms_norm_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,
+     W_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
+     n_rows,
+     n_cols,
+     eps,
+     offset,
+     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+     BLOCK_SIZE: tl.constexpr,
+     BLOCK_ROW: tl.constexpr,
+ ):
+     """
+     y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N)
+
+     Reference:
+     1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+     2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+     3. https://arxiv.org/pdf/1910.07467
+     """
+
+     row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     row_mask = row_idx < n_rows
+     col_mask = col_offsets < n_cols
+
+     X_row = tl.load(
+         X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+         mask=row_mask[:, None] & col_mask[None, :],
+         other=0,
+     )
+     X_row_dtype = X_row.dtype
+     W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+     # On Llama, only rstd is computed on fp32
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(tl.float32)
+
+     # Gemma computes everything on fp32, and then casts back the output to the original dtype
+     if casting_mode == _CASTING_MODE_GEMMA:
+         W_row = W_row.to(tl.float32)
+         X_row = X_row.to(tl.float32)
+
+     if casting_mode == _CASTING_MODE_NONE:
+         eps = eps.to(X_row_dtype)
+         offset = offset.to(X_row_dtype)
+
+     mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+     rstd = rsqrt(mean_square + eps)
+
+     # Caching rstd has minimal memory overhead (one value per row, much smaller
+     # than X_row) and lets the backward pass skip recomputing 4 operations
+     # (*, sum, /, sqrt).
+     tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+     X_row = X_row * rstd[:, None]
+
+     # On Llama, the multiplication with the weight is done on the original dtype
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(X_row_dtype)
+
+     Y_row = X_row * (offset + W_row)[None, :]
+
+     if casting_mode == _CASTING_MODE_GEMMA:
+         Y_row = Y_row.to(X_row_dtype)
+
+     tl.store(
+         Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+         Y_row,
+         mask=row_mask[:, None] & col_mask[None, :],
+     )
+
+
+ @triton.jit
+ def _block_rms_norm_backward_kernel(
+     dY_ptr,
+     dY_row_stride,
+     dX_ptr,
+     dX_row_stride,
+     X_ptr,
+     X_row_stride,
+     X_dtype: tl.constexpr,
+     W_ptr,
+     W_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
+     dW_ptr,
+     dW_row_stride,
+     n_rows,
+     n_cols,
+     offset,
+     rows_per_program: tl.constexpr,
+     casting_mode: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     BLOCK_ROW: tl.constexpr,
+ ):
+     """
+     dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+     dw = sum(dy * (x / RMS)). summation over BxT dimension
+     """
+
+     pid = tl.program_id(0).cast(tl.int64)
+     NUM_SMS = tl.num_programs(0)
+
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     col_mask = col_offsets < n_cols
+
+     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+     W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+     W_row = W_row + offset
+
+     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+         row_idx = start + tl.arange(0, BLOCK_ROW)
+         row_mask = row_idx < n_rows
+         dY_row = tl.load(
+             dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+             mask=row_mask[:, None] & col_mask[None, :],
+             other=0.0,
+         )
+         X_row = tl.load(
+             X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+             mask=row_mask[:, None] & col_mask[None, :],
+             other=0.0,
+         )
+
+         # Get cached rms
+         rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+         X_row = X_row.to(tl.float32)
+
+         # Different backward graphs for different casting modes
+         if casting_mode == _CASTING_MODE_LLAMA:
+             m = (dY_row * W_row[None, :]).to(tl.float32)
+
+         elif casting_mode == _CASTING_MODE_GEMMA:
+             dY_row = dY_row.to(tl.float32)
+             m = dY_row * W_row[None, :]
+         else:
+             m = dY_row * W_row[None, :]
+
+         dX_row = rstd_row[:, None] * m
+
+         dX_row += (rstd_row[:, None]) * (
+             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+         )
+
+         # calculate the gradient of W
+         if casting_mode == _CASTING_MODE_LLAMA:
+             # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+             dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+         else:
+             # here X_row is already in fp32 (see previous if block)
+             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+         tl.store(
+             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+             dX_row,
+             mask=row_mask[:, None] & col_mask[None, :],
+         )
+
+     tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
+
  _str_to_casting_mode = {
      "llama": _CASTING_MODE_LLAMA.value,
      "gemma": _CASTING_MODE_GEMMA.value,
@@ -201,7 +373,7 @@ _str_to_casting_mode = {
  }


- def rms_norm_forward(X, W, eps, offset, casting_mode):
+ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
      if not isinstance(casting_mode, int):
          assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
          casting_mode = _str_to_casting_mode[casting_mode]
@@ -223,26 +395,53 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
      # Check constraints.
      assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"

-     _rms_norm_forward_kernel[(n_rows,)](
-         Y,
-         Y.stride(0),
-         X,
-         X.stride(0),
-         W,
-         W.stride(0),
-         RSTD,
-         RSTD.stride(0),
-         n_cols,
-         eps,
-         offset,
-         casting_mode,
-         BLOCK_SIZE=BLOCK_SIZE,
-         num_warps=num_warps,
-     )
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+         _rms_norm_forward_kernel[(n_rows,)](
+             Y,
+             Y.stride(0),
+             X,
+             X.stride(0),
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             n_cols,
+             eps,
+             offset,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
+     else:
+         BLOCK_ROW = 16
+         kernel_args["BLOCK_ROW"] = BLOCK_ROW
+         _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+             Y,
+             Y.stride(0),
+             X,
+             X.stride(0),
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             n_rows,
+             n_cols,
+             eps,
+             offset,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
      return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


- def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
      shape = dY.shape
      dim = shape[-1]
      dY = dY.view(-1, dim)
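Note: forward and backward share the same dispatch rule: the original one-row-per-program kernel handles wide rows (`BLOCK_SIZE > 256`), small inputs (`n_rows < 32768`), or an explicit `row_mode`; otherwise the block kernel processes BLOCK_ROW = 16 rows per program. Restated as a toy predicate (our own code):

    def uses_block_kernel(block_size: int, n_rows: int, row_mode) -> bool:
        # The block kernel is the else-branch of the launcher's condition.
        return not (block_size > 256 or n_rows < 4096 * 8 or row_mode)

    print(uses_block_kernel(128, 65536, None))   # True
    print(uses_block_kernel(1024, 65536, None))  # False: row too wide
    print(uses_block_kernel(128, 1024, None))    # False: too few rows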
@@ -252,7 +451,9 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
      if X.device.type == "cuda":
          sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
      elif X.device.type == "xpu":
-         sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+     elif X.device.type == "npu":
+         sm_count = get_npu_multi_processor_count()

      # fp32, especially for numerical stability.
      _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
@@ -267,28 +468,61 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
      else:
          dX = torch.zeros_like(dY)

-     _rms_norm_backward_kernel[grid](
-         dY,
-         dY.stride(0),
-         dX,
-         dX.stride(0),
-         X,
-         X.stride(0),
-         torch_to_triton_dtype[X.dtype],
-         W,
-         W.stride(0),
-         RSTD,
-         RSTD.stride(0),
-         _dW,
-         _dW.stride(0),
-         n_rows,
-         n_cols,
-         offset,
-         rows_per_program,
-         casting_mode,
-         BLOCK_SIZE=BLOCK_SIZE,
-         num_warps=num_warps,
-     )
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+
+     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+         _rms_norm_backward_kernel[grid](
+             dY,
+             dY.stride(0),
+             dX,
+             dX.stride(0),
+             X,
+             X.stride(0),
+             torch_to_triton_dtype[X.dtype],
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             _dW,
+             _dW.stride(0),
+             n_rows,
+             n_cols,
+             offset,
+             rows_per_program,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
+     else:
+         BLOCK_ROW = 16
+         kernel_args["BLOCK_ROW"] = BLOCK_ROW
+         _block_rms_norm_backward_kernel[grid](
+             dY,
+             dY.stride(0),
+             dX,
+             dX.stride(0),
+             X,
+             X.stride(0),
+             torch_to_triton_dtype[X.dtype],
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             _dW,
+             _dW.stride(0),
+             n_rows,
+             n_cols,
+             offset,
+             rows_per_program,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
      dX = dX.view(*shape)
      dW = _dW.sum(dim=0).to(W.dtype)

@@ -319,15 +553,16 @@ class LigerRMSNormFunction(torch.autograd.Function):

      @staticmethod
      @ensure_contiguous
-     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
+     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
          """
          X: (B, T, H) or (BxT, H)
          W: (H,)
          """
-         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
+         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
          ctx.offset = offset
          ctx.casting_mode = casting_mode
          ctx.in_place = in_place
+         ctx.row_mode = row_mode
          ctx.BLOCK_SIZE = BLOCK_SIZE
          ctx.num_warps = num_warps
          ctx.save_for_backward(X, W, RSTD)
@@ -341,14 +576,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
          """
          X, W, RSTD = ctx.saved_tensors
          dX, dW = rms_norm_backward(
-             dY,
-             X,
-             W,
-             RSTD,
-             ctx.offset,
-             ctx.casting_mode,
-             ctx.BLOCK_SIZE,
-             ctx.num_warps,
-             ctx.in_place,
+             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
          )
-         return dX, dW, None, None, None, None
+         return dX, dW, None, None, None, None, None
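Note: a quick smoke test of the extended autograd signature (illustrative; assumes a CUDA device, and the extra trailing None returned by backward matches the new row_mode argument):

    import torch
    from liger_kernel.ops.rms_norm import LigerRMSNormFunction

    x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)

    # forward(ctx, X, W, eps, offset, casting_mode, in_place, row_mode)
    y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False, None)
    y.sum().backward()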
liger_kernel/ops/rope.py CHANGED
@@ -32,7 +32,7 @@ def _triton_rope(

      # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
      # stride: (seq_len * head_dim, head_dim, 1)
-     pid = tl.program_id(0)
+     pid = tl.program_id(0).to(tl.int64)

      # locate start address
      q_ptr = q_ptr + pid * q_row_stride
liger_kernel/ops/softmax.py ADDED
@@ -0,0 +1,201 @@
+ from typing import Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def _softmax_single_block_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+     mask = offs < n_cols
+
+     x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
+     m = tl.max(x, axis=0)
+     e = tl.exp(x - m)
+     d = tl.sum(e, axis=0)
+     y = e / d
+     tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")
+
+
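Note: the single-block kernel is the standard numerically stable softmax: subtract the row max, exponentiate, normalize. An eager reference (our own code, not from the wheel):

    import torch

    def softmax_ref(x: torch.Tensor) -> torch.Tensor:
        # Row-wise, numerically stable: exp(x - max) / sum(exp(x - max)).
        m = x.max(dim=-1, keepdim=True).values
        e = (x - m).exp()
        return e / e.sum(dim=-1, keepdim=True)

    x = torch.randn(4, 1000)
    print(torch.allclose(softmax_ref(x), torch.softmax(x, dim=-1), atol=1e-6))  # True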
+ @triton.jit
+ def _softmax_multi_block_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+
+     m = tl.float32(-float("inf"))
+     d = tl.float32(0.0)
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+         blk_max = tl.max(xblk, axis=0)
+         new_m = tl.max(m, blk_max)
+         d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
+         m = new_m
+
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+         yblk = tl.exp(xblk - m) / d
+         tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
+
+
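Note: the multi-block path is the classic two-pass online softmax: the first loop streams BLOCK_SIZE tiles while keeping a running max m and a denominator d rescaled so that d always equals sum(exp(x - m)) over everything seen so far; the second pass normalizes. The recurrence in plain Python (illustrative, our own naming):

    import math

    def online_softmax_stats(xs, block=4):
        # After each tile: d == sum(exp(x - m)) over all elements seen so far.
        m, d = -math.inf, 0.0
        for i in range(0, len(xs), block):
            blk = xs[i:i + block]
            new_m = max(m, max(blk))
            d = d * math.exp(m - new_m) + sum(math.exp(v - new_m) for v in blk)
            m = new_m
        return m, d

    xs = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5]
    m, d = online_softmax_stats(xs)
    ref = sum(math.exp(v - max(xs)) for v in xs)
    print(abs(d - ref) < 1e-12)  # True: matches the one-shot denominator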
+ @triton.jit
+ def _softmax_single_block_backward_kernel(
+     dy_ptr,
+     dy_stride,
+     y_ptr,
+     y_stride,
+     dx_ptr,
+     dx_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+     mask = offs < n_cols
+
+     dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
+     y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
+     dot = tl.sum(dy * y, axis=0)
+     dx = y * (dy - dot)
+     tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")
+
+
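Note: the backward kernels implement the softmax Jacobian identity: for y = softmax(x), dx = y * (dy - <dy, y>). A quick autograd check (our own code):

    import torch

    x = torch.randn(3, 17, dtype=torch.float64, requires_grad=True)
    y = torch.softmax(x, dim=-1)
    dy = torch.randn_like(y)
    y.backward(dy)

    with torch.no_grad():
        dx = y * (dy - (dy * y).sum(dim=-1, keepdim=True))
    print(torch.allclose(x.grad, dx))  # True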
+ @triton.jit
+ def _softmax_multi_block_backward_kernel(
+     dy_ptr,
+     dy_stride,
+     y_ptr,
+     y_stride,
+     dx_ptr,
+     dx_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     row_id = tl.program_id(0)
+     offs = tl.arange(0, BLOCK_SIZE)
+     acc = tl.float32(0.0)
+
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+         y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+         acc += tl.sum(dy_blk * y_blk, axis=0)
+
+     for start in tl.range(0, n_cols, BLOCK_SIZE):
+         idx = start + offs
+         mask = idx < n_cols
+         dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+         y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+         dx_blk = y_blk * (dy_blk - acc)
+         tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
+
+
+ def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
+     *batch, n_cols = x.shape
+     x2d = x.contiguous().view(-1, n_cols)
+     n_rows = x2d.shape[0]
+
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+     y2d = torch.empty_like(x2d)
+
+     if n_cols <= BLOCK_SIZE:
+         _softmax_single_block_forward_kernel[(n_rows,)](
+             y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+         )
+         multi_block_launch = False
+     else:
+         _softmax_multi_block_forward_kernel[(n_rows,)](
+             y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+         )
+         multi_block_launch = True
+
+     return y2d.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
+
+
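Note: the forward launcher records which path it took so the backward follows the same one. multi_block_launch only becomes True when a row does not fit in a single BLOCK_SIZE tile (how calculate_settings picks BLOCK_SIZE is out of scope here; the values below are hypothetical):

    def picks_multi_block(n_cols: int, block_size: int) -> bool:
        # Mirrors the dispatch in _softmax_forward: single-block iff the
        # whole row fits in one tile.
        return n_cols > block_size

    print(picks_multi_block(n_cols=1024, block_size=1024))      # False: single-block
    print(picks_multi_block(n_cols=100_000, block_size=65536))  # True: streamed in tiles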
+ def _softmax_backward(
+     dy: torch.Tensor,
+     y: torch.Tensor,
+     BLOCK_SIZE: int,
+     num_warps: int,
+     multi_block_launch: bool,
+ ) -> torch.Tensor:
+     *batch, n_cols = dy.shape
+     dy2d = dy.contiguous().view(-1, n_cols)
+     y2d = y.contiguous().view(-1, n_cols)
+     n_rows = dy2d.shape[0]
+     dx2d = torch.empty_like(dy2d)
+
+     if not multi_block_launch and n_cols <= BLOCK_SIZE:
+         _softmax_single_block_backward_kernel[(n_rows,)](
+             dy2d,
+             dy2d.stride(0),
+             y2d,
+             y2d.stride(0),
+             dx2d,
+             dx2d.stride(0),
+             n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+     else:
+         _softmax_multi_block_backward_kernel[(n_rows,)](
+             dy2d,
+             dy2d.stride(0),
+             y2d,
+             y2d.stride(0),
+             dx2d,
+             dx2d.stride(0),
+             n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+
+     return dx2d.view(*batch, n_cols)
+
+
+ class LigerSoftmaxFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, input_: torch.Tensor):
+         y, BLOCK_SIZE, num_warps, multi_block_launch = _softmax_forward(input_)
+         ctx.save_for_backward(y)
+         ctx.BLOCK_SIZE = BLOCK_SIZE
+         ctx.num_warps = num_warps
+         ctx.multi_block_launch = multi_block_launch
+         return y
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output):
+         (y,) = ctx.saved_tensors
+         dx = _softmax_backward(
+             grad_output,
+             y,
+             ctx.BLOCK_SIZE,
+             ctx.num_warps,
+             ctx.multi_block_launch,
+         )
+         return dx
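Note: a smoke test of the new autograd function against torch.softmax (illustrative; assumes a CUDA device and a working Triton install):

    import torch
    from liger_kernel.ops.softmax import LigerSoftmaxFunction

    x = torch.randn(8, 4096, device="cuda", requires_grad=True)
    y = LigerSoftmaxFunction.apply(x)
    print(torch.allclose(y, torch.softmax(x.detach(), dim=-1), atol=1e-6))

    y.sum().backward()  # exercises the matching single-/multi-block backward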