liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +44 -13
  9. liger_kernel/ops/__init__.py +141 -0
  10. liger_kernel/ops/backends/README.md +151 -0
  11. liger_kernel/ops/backends/__init__.py +13 -0
  12. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  13. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  15. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  16. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  17. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  18. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  19. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  20. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  21. liger_kernel/ops/backends/registry.py +61 -0
  22. liger_kernel/ops/cross_entropy.py +130 -64
  23. liger_kernel/ops/dyt.py +5 -4
  24. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  25. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  26. liger_kernel/ops/geglu.py +6 -4
  27. liger_kernel/ops/group_norm.py +7 -7
  28. liger_kernel/ops/grpo_loss.py +3 -1
  29. liger_kernel/ops/kl_div.py +8 -11
  30. liger_kernel/ops/layer_norm.py +135 -80
  31. liger_kernel/ops/llama4_rope.py +225 -0
  32. liger_kernel/ops/poly_norm.py +390 -0
  33. liger_kernel/ops/rms_norm.py +148 -71
  34. liger_kernel/ops/rope.py +1 -1
  35. liger_kernel/ops/swiglu.py +1 -1
  36. liger_kernel/ops/tiled_mlp.py +136 -0
  37. liger_kernel/ops/utils.py +14 -0
  38. liger_kernel/transformers/__init__.py +65 -0
  39. liger_kernel/transformers/auto_model.py +21 -0
  40. liger_kernel/transformers/cross_entropy.py +9 -4
  41. liger_kernel/transformers/dyt.py +1 -1
  42. liger_kernel/transformers/experimental/__init__.py +5 -0
  43. liger_kernel/transformers/experimental/embedding.py +1 -1
  44. liger_kernel/transformers/functional.py +56 -24
  45. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  46. liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
  47. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  48. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  49. liger_kernel/transformers/geglu.py +1 -1
  50. liger_kernel/transformers/group_norm.py +1 -1
  51. liger_kernel/transformers/grpo_loss.py +57 -2
  52. liger_kernel/transformers/jsd.py +1 -1
  53. liger_kernel/transformers/kl_div.py +1 -1
  54. liger_kernel/transformers/layer_norm.py +1 -1
  55. liger_kernel/transformers/llama4_rope.py +93 -0
  56. liger_kernel/transformers/model/exaone4.py +136 -0
  57. liger_kernel/transformers/model/falcon_h1.py +122 -0
  58. liger_kernel/transformers/model/gemma.py +28 -8
  59. liger_kernel/transformers/model/gemma2.py +34 -11
  60. liger_kernel/transformers/model/gemma3.py +102 -112
  61. liger_kernel/transformers/model/glm4.py +18 -5
  62. liger_kernel/transformers/model/glm4v.py +163 -0
  63. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  64. liger_kernel/transformers/model/gpt_oss.py +211 -0
  65. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  66. liger_kernel/transformers/model/internvl.py +157 -0
  67. liger_kernel/transformers/model/llama.py +26 -7
  68. liger_kernel/transformers/model/llama4.py +121 -0
  69. liger_kernel/transformers/model/llava.py +18 -6
  70. liger_kernel/transformers/model/loss_utils.py +34 -3
  71. liger_kernel/transformers/model/mistral.py +17 -10
  72. liger_kernel/transformers/model/mixtral.py +24 -9
  73. liger_kernel/transformers/model/mllama.py +18 -7
  74. liger_kernel/transformers/model/olmo2.py +18 -5
  75. liger_kernel/transformers/model/olmo3.py +142 -0
  76. liger_kernel/transformers/model/output_classes.py +147 -0
  77. liger_kernel/transformers/model/paligemma.py +42 -5
  78. liger_kernel/transformers/model/phi3.py +24 -159
  79. liger_kernel/transformers/model/qwen2.py +26 -4
  80. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  81. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  82. liger_kernel/transformers/model/qwen3.py +22 -6
  83. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  84. liger_kernel/transformers/model/qwen3_next.py +146 -0
  85. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  86. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  87. liger_kernel/transformers/model/smollm3.py +199 -0
  88. liger_kernel/transformers/model/smolvlm.py +158 -0
  89. liger_kernel/transformers/monkey_patch.py +1423 -100
  90. liger_kernel/transformers/multi_token_attention.py +2 -2
  91. liger_kernel/transformers/poly_norm.py +42 -0
  92. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  93. liger_kernel/transformers/rms_norm.py +15 -5
  94. liger_kernel/transformers/rope.py +45 -1
  95. liger_kernel/transformers/softmax.py +1 -1
  96. liger_kernel/transformers/sparsemax.py +1 -1
  97. liger_kernel/transformers/swiglu.py +18 -1
  98. liger_kernel/transformers/tiled_mlp.py +125 -0
  99. liger_kernel/transformers/tvd.py +1 -1
  100. liger_kernel/utils.py +52 -0
  101. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
  102. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  103. liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
  104. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py CHANGED
@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -52,6 +54,7 @@ def _rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -63,17 +66,18 @@ def _rms_norm_forward_kernel(
     3. https://arxiv.org/pdf/1910.07467
     """
 
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    Y_ptr += row_idx * Y_row_stride
-    X_ptr += row_idx * X_row_stride
-    RSTD_ptr += row_idx * RSTD_row_stride
+    y_base = Y_ptr + row_idx * Y_row_stride
+    x_base = X_ptr + row_idx * X_row_stride
+    rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
 
-    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+    X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
@@ -81,7 +85,8 @@ def _rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -94,7 +99,7 @@ def _rms_norm_forward_kernel(
     # We can save time by caching rms with minimal memory overhead
     # because rms is much smaller compared to X_row, as rms is for each row.
    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-    tl.store(RSTD_ptr, rstd)
+    tl.store(rstd_base, rstd)
 
     X_row = X_row * rstd
 
@@ -102,12 +107,15 @@
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
 
-    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+    tl.store(y_base + col_offsets, Y_row, mask=mask)
 
 
 @triton.jit
@@ -128,8 +136,9 @@ def _rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
+    rows_per_program,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -137,61 +146,69 @@ def _rms_norm_backward_kernel(
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
 
-    row_block_id = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
     row_start = row_block_id * rows_per_program
     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-    dY_ptr += row_start * dY_row_stride
-    dX_ptr += row_start * dX_row_stride
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset
 
-    X_ptr += row_start * X_row_stride
-    RSTD_ptr += row_start
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride
 
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
-    W_row = W_row + offset
+        x_base = X_ptr + row_idx * X_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
 
-    for _ in range(row_start, row_end):
-        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
-        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
 
         # Get cached rms
-        rstd_row = tl.load(RSTD_ptr)
+        rstd_row = tl.load(rstd_base)
 
         X_row = X_row.to(tl.float32)
 
         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
 
         dX_row = rstd_row * m
 
         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += dY_row * (X_row * rstd_row)
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)
 
-        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+        tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
 
-        dY_ptr += dY_row_stride
-        dX_ptr += dX_row_stride
-        X_ptr += X_row_stride
-        RSTD_ptr += RSTD_row_stride
-
-    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
 @triton.jit
@@ -209,6 +226,7 @@ def _block_rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -232,7 +250,8 @@ def _block_rms_norm_forward_kernel(
         other=0,
     )
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
@@ -240,7 +259,8 @@ def _block_rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -261,7 +281,10 @@ def _block_rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)[None, :]
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -291,8 +314,8 @@ def _block_rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -307,10 +330,11 @@ def _block_rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     col_mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
-    W_row = W_row + offset
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset
 
     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
         row_idx = start + tl.arange(0, BLOCK_ROW)
@@ -333,13 +357,22 @@ def _block_rms_norm_backward_kernel(
 
         # Different bacward graphs for different casting modes
        if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row[None, :]).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
 
         dX_row = rstd_row[:, None] * m
 
@@ -347,12 +380,13 @@ def _block_rms_norm_backward_kernel(
             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
         )
 
-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
 
         tl.store(
             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
@@ -360,7 +394,8 @@ def _block_rms_norm_backward_kernel(
             mask=row_mask[:, None] & col_mask[None, :],
         )
 
-    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
 
 
 _str_to_casting_mode = {
@@ -389,8 +424,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
-    # Check constraints.
-    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    if W is not None:
+        # Check constraints.
+        assert X.shape[1] == W.shape[0], (
+            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+        )
+        elementwise_affine = True
+    else:
+        elementwise_affine = False
 
     # XPU-specific optimization
     kernel_args = {}
@@ -403,13 +444,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
            RSTD,
             RSTD.stride(0),
             n_cols,
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -423,7 +465,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_rows,
@@ -431,6 +473,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -449,9 +492,16 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()
 
-    # fp32 for numerical stability especially.
-    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    if W is not None:
+        # fp32 for numerical stability especially.
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        elementwise_affine = True
+    else:
+        _dW = None
+        elementwise_affine = False
 
     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
@@ -478,16 +528,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
            X.stride(0),
            torch_to_triton_dtype[X.dtype],
            W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
            RSTD,
            RSTD.stride(0),
            _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
            n_rows,
            n_cols,
            offset,
            rows_per_program,
            casting_mode,
+            elementwise_affine=elementwise_affine,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            **kernel_args,  # XPU-specific optimization
@@ -504,22 +555,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
            X.stride(0),
            torch_to_triton_dtype[X.dtype],
            W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
            RSTD,
            RSTD.stride(0),
            _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
            n_rows,
            n_cols,
            offset,
-            rows_per_program,
            casting_mode,
+            elementwise_affine=elementwise_affine,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
            **kernel_args,  # XPU-specific optimization
        )
    dX = dX.view(*shape)
-    dW = _dW.sum(dim=0).to(W.dtype)
+
+    if elementwise_affine:
+        dW = _dW.sum(dim=0).to(W.dtype)
+    else:
+        dW = None
 
    return dX, dW
 
@@ -553,6 +608,13 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
+        if isinstance(X, torch.distributed.tensor.DTensor):
+            # Input tensor is output of a tensor parallel module and
+            # needs to be gathered to a local tensor to compute
+            # RMSE layer norm on each TP worker.
+            # TODO: support CP.
+            X = X.full_tensor()
+
         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
@@ -560,7 +622,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
         ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-        ctx.save_for_backward(X, W, RSTD)
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
         return Y
 
     @staticmethod
@@ -569,7 +635,18 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         Y: (B, T, H) or (BxT, H)
         """
-        X, W, RSTD = ctx.saved_tensors
+        if ctx.elementwise_affine:
+            X, W, RSTD = ctx.saved_tensors
+        else:
+            X, RSTD = ctx.saved_tensors
+            W = None
+
+        if isinstance(dY, torch.distributed.tensor.DTensor):
+            # Gradients are output of a tensor parallel module and
+            # needs to be gathered to a local tensor for computing RMSE layer.
+            # TODO: support CP.
+            dY = dY.full_tensor()
+
         dX, dW = rms_norm_backward(
             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
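
Note on the rms_norm.py changes above: W is now optional (elementwise_affine is False when W is None, in which case no (offset + W) scaling is applied and dW comes back as None), SM counting gains an NPU branch, and LigerRMSNormFunction gathers DTensor inputs and gradients before computing. A minimal hedged sketch of the weight-free forward path follows; the rms_norm_forward signature and return values come from the diff, while the eps/offset/casting_mode/row_mode values and the CUDA device are illustrative assumptions.

import torch

from liger_kernel.ops.rms_norm import rms_norm_forward

# Hedged smoke test for the new W=None (elementwise_affine=False) path.
# "llama" as a casting-mode string and row_mode=None are assumptions, not taken from the diff;
# a Triton-capable device (here CUDA) is needed to launch the kernel.
X = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
Y, X_out, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, None, 1e-6, 0.0, "llama", None)
assert Y.shape == X.shape  # output is the normalized input, with no learned scale applied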
liger_kernel/ops/rope.py CHANGED
@@ -32,7 +32,7 @@ def _triton_rope(
 
     # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
     # stride: (seq_len * head_dim, head_dim, 1)
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
     # locate start address
     q_ptr = q_ptr + pid * q_row_stride
liger_kernel/ops/swiglu.py CHANGED
@@ -26,7 +26,7 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
     # sigmoid requires type float32
     a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
     b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
-    c_row = silu(a_row) * b_row
+    c_row = silu(a_row).cast(b_row.dtype) * b_row
     tl.store(c_ptr + col_offsets, c_row, mask=mask)
 
 
liger_kernel/ops/tiled_mlp.py ADDED
@@ -0,0 +1,136 @@
+import math
+
+from typing import Callable
+from typing import List
+from typing import Optional
+
+import torch
+
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+class LigerTiledMLPFunction(torch.autograd.Function):
+    """
+    Based on DeepSpeed's TiledMLP:
+    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838
+
+    Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
+    when using very long sequence lengths.
+
+    This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
+    And if you're using activation checkpointing it then occurs thrice.
+
+    Args:
+        fn: the function to call on sharded inputs (e.g., mlp.forward)
+        mlp_module: the MLP nn.Module object
+        x: the input to MLP.forward (hidden_states)
+        shards: how many shards to use
+        compute_params: a list of weights engaged in the compute
+
+    Returns:
+        the computed hidden_states
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        fn: Callable,
+        mlp_module: torch.nn.Module,
+        x: torch.Tensor,
+        shards: int,
+        compute_params: Optional[List[torch.nn.Parameter]] = None,
+    ) -> torch.Tensor:
+        ctx.fn = fn
+        ctx.mlp_module = mlp_module
+        ctx.shards = shards
+        ctx.save_for_backward(x)
+
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
+        with torch.no_grad():
+            output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
+        output_unsharded = torch.cat(output_shards, dim=-2)
+
+        return output_unsharded
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, *grads) -> tuple:
+        fn = ctx.fn
+        (x,) = ctx.saved_tensors
+        mlp_module = ctx.mlp_module
+        shards = ctx.shards
+
+        x_requires_grad = x.requires_grad
+        x = x.detach()
+        # detach() unsets x.requires_grad, so restore it
+        x.requires_grad_(x_requires_grad)
+
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        hidden_size = x.shape[-1]
+        x_shape_orig = x.shape
+
+        # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
+        x = x.view(-1, hidden_size)
+        incoming_grad = grads[0].view(-1, hidden_size)
+        x_grad = torch.zeros_like(x)
+
+        x_shards = list(torch.chunk(x, chunks=shards, dim=0))
+
+        for i, x_shard in enumerate(x_shards):
+            x_shard.requires_grad_(x_requires_grad)
+
+            # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
+            shard_step = x_shards[i].shape[0]
+            shard_offset = i * x_shards[0].shape[0]
+
+            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+
+            with torch.enable_grad():
+                output = fn(mlp_module, x_shard)
+            torch.autograd.backward(output, incoming_grad_shard)
+
+        # unflatten
+        x_grad = x_grad.view(x_shape_orig)
+
+        return (None, None, x_grad, None, None)
+
+
+def apply_tiled_mlp(
+    fn: Callable,
+    mlp_module: torch.nn.Module,
+    x: torch.Tensor,
+    num_shards: Optional[int] = None,
+    compute_params: Optional[List[torch.nn.Parameter]] = None,
+) -> torch.Tensor:
+    """
+    Apply tiled MLP computation for memory efficiency.
+
+    Args:
+        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
+        mlp_module: the MLP nn.Module object
+        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+        num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
+        compute_params: list of parameters for DeepSpeed ZeRO optimization
+
+    Returns:
+        output tensor with the same shape as input
+    """
+    if num_shards is None:
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+        hidden_size = x.shape[-1]
+        seqlen = x.shape[-2]
+        num_shards = math.ceil(seqlen / hidden_size)
+
+    # Ensure num_shards is at least 1
+    num_shards = max(1, num_shards)
+
+    return LigerTiledMLPFunction.apply(
+        fn,
+        mlp_module,
+        x,
+        num_shards,
+        compute_params,
+    )
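
The new tiled_mlp.py re-runs the MLP forward shard-by-shard during backward so that only one sequence shard's activations are alive at a time. A hedged usage sketch follows; ToyMLP and the shapes are illustrative stand-ins, only apply_tiled_mlp and its argument order come from the file above.

import torch
import torch.nn as nn
import torch.nn.functional as F

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp

class ToyMLP(nn.Module):
    # Illustrative stand-in for a transformer MLP block (not part of the diff).
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up = nn.Linear(hidden_size, intermediate_size)
        self.down = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.gelu(self.up(x)))

mlp = ToyMLP(hidden_size=256, intermediate_size=1024)
x = torch.randn(2, 8192, 256, requires_grad=True)

# fn receives (module, shard); with num_shards=None the helper falls back to ceil(seqlen / hidden_size).
out = apply_tiled_mlp(lambda module, t: module(t), mlp, x, num_shards=4)
out.sum().backward()  # gradients flow back into x and the MLP weights shard by shard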
liger_kernel/ops/utils.py CHANGED
@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
             functools.partial(torch.amp.custom_fwd, device_type=device),
             functools.partial(torch.amp.custom_bwd, device_type=device),
         )
+    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
+        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
     return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
 
 
@@ -125,3 +127,15 @@ def element_mul_kernel(
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
         tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
+
+
+def get_npu_core_count(default: int = 20) -> int:
+    """Return NPU vector core count.
+    Fallback to `default` if Triton runtime or NPU device is unavailable.
+    """
+    try:
+        utils = triton.runtime.driver.active.utils
+        props = utils.get_device_properties(0)
+        return int(props.get("num_vectorcore", default))
+    except Exception:
+        return default
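
get_npu_core_count mirrors the per-device SM queries used elsewhere in this release (torch.cuda multi_processor_count, torch.xpu gpu_eu_count) for sizing persistent grids. A small hedged usage sketch; the default of 20 comes from the signature above, the grid usage is illustrative:

from liger_kernel.ops.utils import get_npu_core_count

# Falls back to the default when no NPU-backed Triton driver is active.
grid = (get_npu_core_count(default=20),)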