liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (97)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
  12. liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  17. liger_kernel/ops/backends/registry.py +61 -0
  18. liger_kernel/ops/cross_entropy.py +75 -12
  19. liger_kernel/ops/dyt.py +5 -2
  20. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  21. liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
  22. liger_kernel/ops/geglu.py +5 -3
  23. liger_kernel/ops/group_norm.py +2 -1
  24. liger_kernel/ops/grpo_loss.py +3 -1
  25. liger_kernel/ops/layer_norm.py +86 -66
  26. liger_kernel/ops/poly_norm.py +390 -0
  27. liger_kernel/ops/rms_norm.py +131 -49
  28. liger_kernel/ops/tiled_mlp.py +136 -0
  29. liger_kernel/ops/utils.py +14 -0
  30. liger_kernel/transformers/__init__.py +30 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +9 -4
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +48 -25
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +57 -2
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/falcon_h1.py +19 -5
  48. liger_kernel/transformers/model/gemma.py +17 -6
  49. liger_kernel/transformers/model/gemma2.py +14 -5
  50. liger_kernel/transformers/model/gemma3.py +26 -12
  51. liger_kernel/transformers/model/glm4.py +16 -4
  52. liger_kernel/transformers/model/glm4v.py +16 -4
  53. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  54. liger_kernel/transformers/model/gpt_oss.py +211 -0
  55. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  56. liger_kernel/transformers/model/internvl.py +12 -5
  57. liger_kernel/transformers/model/llama.py +14 -5
  58. liger_kernel/transformers/model/llama4.py +16 -4
  59. liger_kernel/transformers/model/llava.py +12 -4
  60. liger_kernel/transformers/model/loss_utils.py +31 -3
  61. liger_kernel/transformers/model/mistral.py +15 -6
  62. liger_kernel/transformers/model/mixtral.py +16 -7
  63. liger_kernel/transformers/model/mllama.py +12 -4
  64. liger_kernel/transformers/model/olmo2.py +16 -4
  65. liger_kernel/transformers/model/olmo3.py +142 -0
  66. liger_kernel/transformers/model/output_classes.py +147 -0
  67. liger_kernel/transformers/model/paligemma.py +23 -5
  68. liger_kernel/transformers/model/phi3.py +14 -7
  69. liger_kernel/transformers/model/qwen2.py +16 -3
  70. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  71. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  72. liger_kernel/transformers/model/qwen3.py +20 -5
  73. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  74. liger_kernel/transformers/model/qwen3_next.py +146 -0
  75. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  76. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  77. liger_kernel/transformers/model/smollm3.py +15 -6
  78. liger_kernel/transformers/model/smolvlm.py +158 -0
  79. liger_kernel/transformers/monkey_patch.py +702 -48
  80. liger_kernel/transformers/multi_token_attention.py +1 -1
  81. liger_kernel/transformers/poly_norm.py +42 -0
  82. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  83. liger_kernel/transformers/rms_norm.py +15 -3
  84. liger_kernel/transformers/rope.py +45 -1
  85. liger_kernel/transformers/softmax.py +1 -1
  86. liger_kernel/transformers/sparsemax.py +1 -1
  87. liger_kernel/transformers/swiglu.py +18 -1
  88. liger_kernel/transformers/tiled_mlp.py +133 -0
  89. liger_kernel/transformers/tvd.py +1 -1
  90. liger_kernel/utils.py +52 -0
  91. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
  92. liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
  93. liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
  94. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
  95. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
  96. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
  97. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py CHANGED
@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -52,6 +54,7 @@ def _rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -73,7 +76,8 @@ def _rms_norm_forward_kernel(
 
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
@@ -81,7 +85,8 @@ def _rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -102,7 +107,10 @@ def _rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -128,8 +136,9 @@ def _rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
+    rows_per_program,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -143,7 +152,8 @@ def _rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
     dY_ptr += row_start * dY_row_stride
     dX_ptr += row_start * dX_row_stride
@@ -151,8 +161,9 @@ def _rms_norm_backward_kernel(
     X_ptr += row_start * X_row_stride
     RSTD_ptr += row_start
 
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
-    W_row = W_row + offset
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset
 
     for _ in range(row_start, row_end):
         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
@@ -165,24 +176,34 @@ def _rms_norm_backward_kernel(
 
         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
 
         dX_row = rstd_row * m
 
         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += dY_row * (X_row * rstd_row)
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)
 
         tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
 
@@ -191,7 +212,8 @@ def _rms_norm_backward_kernel(
         X_ptr += X_row_stride
         RSTD_ptr += RSTD_row_stride
 
-    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
 @triton.jit
@@ -209,6 +231,7 @@ def _block_rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -232,7 +255,8 @@ def _block_rms_norm_forward_kernel(
         other=0,
     )
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
 
     # On Llama, only rstd is computed on fp32
    if casting_mode == _CASTING_MODE_LLAMA:
@@ -240,7 +264,8 @@ def _block_rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -261,7 +286,10 @@ def _block_rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)[None, :]
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -291,8 +319,8 @@ def _block_rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -307,10 +335,11 @@ def _block_rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     col_mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
-    W_row = W_row + offset
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset
 
     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
         row_idx = start + tl.arange(0, BLOCK_ROW)
@@ -333,13 +362,22 @@ def _block_rms_norm_backward_kernel(
 
         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row[None, :]).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
 
         dX_row = rstd_row[:, None] * m
 
@@ -347,12 +385,13 @@ def _block_rms_norm_backward_kernel(
             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
         )
 
-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
 
         tl.store(
             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
@@ -360,7 +399,8 @@ def _block_rms_norm_backward_kernel(
             mask=row_mask[:, None] & col_mask[None, :],
         )
 
-    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
 
 
 _str_to_casting_mode = {
@@ -389,8 +429,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
-    # Check constraints.
-    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    if W is not None:
+        # Check constraints.
+        assert X.shape[1] == W.shape[0], (
+            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+        )
+        elementwise_affine = True
+    else:
+        elementwise_affine = False
 
     # XPU-specific optimization
     kernel_args = {}
@@ -403,13 +449,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_cols,
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -423,7 +470,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_rows,
@@ -431,6 +478,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -449,9 +497,16 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()
 
-    # fp32 for numerical stability especially.
-    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    if W is not None:
+        # fp32 for numerical stability especially.
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        elementwise_affine = True
+    else:
+        _dW = None
+        elementwise_affine = False
 
     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
@@ -478,16 +533,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
            X.stride(0),
            torch_to_triton_dtype[X.dtype],
            W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
            RSTD,
            RSTD.stride(0),
            _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
            n_rows,
            n_cols,
            offset,
            rows_per_program,
            casting_mode,
+            elementwise_affine=elementwise_affine,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            **kernel_args,  # XPU-specific optimization
@@ -504,22 +560,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
            X.stride(0),
            torch_to_triton_dtype[X.dtype],
            W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
            RSTD,
            RSTD.stride(0),
            _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
            n_rows,
            n_cols,
            offset,
-            rows_per_program,
            casting_mode,
+            elementwise_affine=elementwise_affine,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            **kernel_args,  # XPU-specific optimization
        )
     dX = dX.view(*shape)
-    dW = _dW.sum(dim=0).to(W.dtype)
+
+    if elementwise_affine:
+        dW = _dW.sum(dim=0).to(W.dtype)
+    else:
+        dW = None
 
     return dX, dW
 
@@ -553,6 +613,13 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
+        if isinstance(X, torch.distributed.tensor.DTensor):
+            # Input tensor is output of a tensor parallel module and
+            # needs to be gathered to a local tensor to compute
+            # RMSE layer norm on each TP worker.
+            # TODO: support CP.
+            X = X.full_tensor()
+
         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
@@ -560,7 +627,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
         ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-        ctx.save_for_backward(X, W, RSTD)
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
         return Y
 
     @staticmethod
@@ -569,7 +640,18 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         Y: (B, T, H) or (BxT, H)
         """
-        X, W, RSTD = ctx.saved_tensors
+        if ctx.elementwise_affine:
+            X, W, RSTD = ctx.saved_tensors
+        else:
+            X, RSTD = ctx.saved_tensors
+            W = None
+
+        if isinstance(dY, torch.distributed.tensor.DTensor):
+            # Gradients are output of a tensor parallel module and
+            # needs to be gathered to a local tensor for computing RMSE layer.
+            # TODO: support CP.
+            dY = dY.full_tensor()
+
         dX, dW = rms_norm_backward(
             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
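
The net effect of the rms_norm.py changes above is that the weight tensor may now be omitted: with `W=None` the kernels compile out every weight load/store via the new `elementwise_affine` constexpr, and the backward returns `dW = None`. A minimal sketch of both paths follows; the positional argument order of `LigerRMSNormFunction.apply` is inferred from the calls visible in this diff and may not match the released API exactly.

# Illustrative sketch only -- apply() argument order inferred from the diff above.
import torch

from liger_kernel.ops.rms_norm import LigerRMSNormFunction

x = torch.randn(4, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.ones(64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Affine path: per-program dW partials are accumulated in fp32 and reduced to W.dtype.
y_affine = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False, None)

# Weight-free path added in this diff: W=None disables elementwise_affine, the output
# is just the RMS-normalized input, and backward produces dW = None.
y_plain = LigerRMSNormFunction.apply(x, None, 1e-6, 0.0, "llama", False, None)
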
liger_kernel/ops/tiled_mlp.py ADDED
@@ -0,0 +1,136 @@
+import math
+
+from typing import Callable
+from typing import List
+from typing import Optional
+
+import torch
+
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+class LigerTiledMLPFunction(torch.autograd.Function):
+    """
+    Based on DeepSpeed's TiledMLP:
+    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838
+
+    Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
+    when using very long sequence lengths.
+
+    This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
+    And if you're using activation checkpointing it then occurs thrice.
+
+    Args:
+        fn: the function to call on sharded inputs (e.g., mlp.forward)
+        mlp_module: the MLP nn.Module object
+        x: the input to MLP.forward (hidden_states)
+        shards: how many shards to use
+        compute_params: a list of weights engaged in the compute
+
+    Returns:
+        the computed hidden_states
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        fn: Callable,
+        mlp_module: torch.nn.Module,
+        x: torch.Tensor,
+        shards: int,
+        compute_params: Optional[List[torch.nn.Parameter]] = None,
+    ) -> torch.Tensor:
+        ctx.fn = fn
+        ctx.mlp_module = mlp_module
+        ctx.shards = shards
+        ctx.save_for_backward(x)
+
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
+        with torch.no_grad():
+            output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
+        output_unsharded = torch.cat(output_shards, dim=-2)
+
+        return output_unsharded
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, *grads) -> tuple:
+        fn = ctx.fn
+        (x,) = ctx.saved_tensors
+        mlp_module = ctx.mlp_module
+        shards = ctx.shards
+
+        x_requires_grad = x.requires_grad
+        x = x.detach()
+        # detach() unsets x.requires_grad, so restore it
+        x.requires_grad_(x_requires_grad)
+
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        hidden_size = x.shape[-1]
+        x_shape_orig = x.shape
+
+        # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
+        x = x.view(-1, hidden_size)
+        incoming_grad = grads[0].view(-1, hidden_size)
+        x_grad = torch.zeros_like(x)
+
+        x_shards = list(torch.chunk(x, chunks=shards, dim=0))
+
+        for i, x_shard in enumerate(x_shards):
+            x_shard.requires_grad_(x_requires_grad)
+
+            # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
+            shard_step = x_shards[i].shape[0]
+            shard_offset = i * x_shards[0].shape[0]
+
+            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+
+            with torch.enable_grad():
+                output = fn(mlp_module, x_shard)
+                torch.autograd.backward(output, incoming_grad_shard)
+
+        # unflatten
+        x_grad = x_grad.view(x_shape_orig)
+
+        return (None, None, x_grad, None, None)
+
+
+def apply_tiled_mlp(
+    fn: Callable,
+    mlp_module: torch.nn.Module,
+    x: torch.Tensor,
+    num_shards: Optional[int] = None,
+    compute_params: Optional[List[torch.nn.Parameter]] = None,
+) -> torch.Tensor:
+    """
+    Apply tiled MLP computation for memory efficiency.
+
+    Args:
+        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
+        mlp_module: the MLP nn.Module object
+        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+        num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
+        compute_params: list of parameters for DeepSpeed ZeRO optimization
+
+    Returns:
+        output tensor with the same shape as input
+    """
+    if num_shards is None:
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+        hidden_size = x.shape[-1]
+        seqlen = x.shape[-2]
+        num_shards = math.ceil(seqlen / hidden_size)
+
+    # Ensure num_shards is at least 1
+    num_shards = max(1, num_shards)
+
+    return LigerTiledMLPFunction.apply(
+        fn,
+        mlp_module,
+        x,
+        num_shards,
+        compute_params,
+    )
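
For orientation, here is a hedged usage sketch of the new `apply_tiled_mlp` helper; the `ToyMLP` module and all shapes are hypothetical and not part of this package.

# Illustrative sketch only -- ToyMLP and the shapes below are hypothetical.
import torch
import torch.nn as nn
import torch.nn.functional as F

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp


class ToyMLP(nn.Module):
    """Stand-in for a transformer SwiGLU MLP block."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


mlp = ToyMLP(hidden_size=256, intermediate_size=1024)
x = torch.randn(2, 8192, 256, requires_grad=True)

# fn receives (module, shard), so the same callable serves dense and MoE-expert MLPs.
out = apply_tiled_mlp(fn=lambda module, shard: module(shard), mlp_module=mlp, x=x, num_shards=8)

# Activations are recomputed shard by shard in backward instead of being held for the
# full sequence, trading extra compute for a lower peak-memory footprint.
out.sum().backward()
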
liger_kernel/ops/utils.py CHANGED
@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
             functools.partial(torch.amp.custom_fwd, device_type=device),
             functools.partial(torch.amp.custom_bwd, device_type=device),
         )
+    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
+        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
     return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
 
 
@@ -125,3 +127,15 @@ def element_mul_kernel(
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
         tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
+
+
+def get_npu_core_count(default: int = 20) -> int:
+    """Return NPU vector core count.
+    Fallback to `default` if Triton runtime or NPU device is unavailable.
+    """
+    try:
+        utils = triton.runtime.driver.active.utils
+        props = utils.get_device_properties(0)
+        return int(props.get("num_vectorcore", default))
+    except Exception:
+        return default
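
The `get_npu_core_count` helper parallels the CUDA and XPU multiprocessor lookups used for grid sizing elsewhere in this diff (see the `sm_count` dispatch in `rms_norm_backward`). A hypothetical call site, not taken from the package, might look like this:

# Hypothetical call site, not part of the diff: pick a grid size from the device's compute units.
import torch

from liger_kernel.ops.utils import get_npu_core_count


def compute_unit_count(device: torch.device) -> int:
    if device.type == "cuda":
        return torch.cuda.get_device_properties(device).multi_processor_count
    if device.type == "xpu":
        return torch.xpu.get_device_properties(device).gpu_eu_count
    if device.type == "npu":
        # Falls back to a conservative default when the Triton NPU driver is unavailable.
        return get_npu_core_count(default=20)
    return 1
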