liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (115):
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py CHANGED
@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
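
The new `is_npu_available()` guard keeps the Triton-3.x libdevice import path off Ascend NPUs, where that module layout is not available. A minimal sketch of the pattern; the fallback and `else` branches here are illustrative assumptions, not the file's actual code:

import operator

import triton.language as tl

from liger_kernel.ops.utils import compare_version
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        # assumed fallback for Triton builds without the libdevice extra
        from triton.language.math import rsqrt
else:
    # assumed: NPU and pre-3.0 builds fall back to the core rsqrt
    rsqrt = tl.rsqrt
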
@@ -52,6 +54,7 @@ def _rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -63,17 +66,18 @@ def _rms_norm_forward_kernel(
     3. https://arxiv.org/pdf/1910.07467
     """
 
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    Y_ptr += row_idx * Y_row_stride
-    X_ptr += row_idx * X_row_stride
-    RSTD_ptr += row_idx * RSTD_row_stride
+    y_base = Y_ptr + row_idx * Y_row_stride
+    x_base = X_ptr + row_idx * X_row_stride
+    rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
 
-    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+    X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
@@ -81,7 +85,8 @@ def _rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -94,7 +99,7 @@ def _rms_norm_forward_kernel(
     # We can save time by caching rms with minimal memory overhead
     # because rms is much smaller compared to X_row, as rms is for each row.
     # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-    tl.store(RSTD_ptr, rstd)
+    tl.store(rstd_base, rstd)
 
     X_row = X_row * rstd
 
@@ -102,12 +107,15 @@ def _rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
 
-    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+    tl.store(y_base + col_offsets, Y_row, mask=mask)
 
 
 @triton.jit
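
Taken together, the forward hunks implement the docstring's y_i = (x_i / RMS) * (offset + w_i), with the weight path now optional. A rough eager-mode equivalent (a hypothetical helper, not part of the package; the casting modes differ only in where the fp32 casts happen):

import torch

def rms_norm_reference(X, W=None, eps=1e-6, offset=0.0):
    # X: (n_rows, n_cols); W: (n_cols,) or None (elementwise_affine=False)
    rstd = torch.rsqrt(X.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    Y = X.float() * rstd  # normalize each row by its root-mean-square
    if W is not None:
        Y = Y * (offset + W.float())  # offset=1.0 expresses Gemma-style (1 + w) scaling
    return Y.to(X.dtype)
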
@@ -128,8 +136,9 @@ def _rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
+    rows_per_program,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -137,61 +146,256 @@ def _rms_norm_backward_kernel(
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
 
-    row_block_id = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
     row_start = row_block_id * rows_per_program
     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-    dY_ptr += row_start * dY_row_stride
-    dX_ptr += row_start * dX_row_stride
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset
 
-    X_ptr += row_start * X_row_stride
-    RSTD_ptr += row_start
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride
 
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
-    W_row = W_row + offset
+        x_base = X_ptr + row_idx * X_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
 
-    for _ in range(row_start, row_end):
-        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
-        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
 
         # Get cached rms
-        rstd_row = tl.load(RSTD_ptr)
+        rstd_row = tl.load(rstd_base)
 
         X_row = X_row.to(tl.float32)
 
         # Different backward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
 
         dX_row = rstd_row * m
 
         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
-        # calculate the gradient of W
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)
+
+        tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
+
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+
+
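
The loop body above is a row-at-a-time version of the closed form in the docstring. Restated in eager PyTorch for reference (fp32 throughout; a hypothetical helper, not part of the package):

import torch

def rms_norm_dx_reference(dY, X, W=None, eps=1e-6, offset=0.0):
    # dx = rstd * [m - (1/N) * rstd^2 * (m . x) * x], with m = dy * (w + offset),
    # or m = dy when there is no elementwise affine weight
    n_cols = X.shape[-1]
    rstd = torch.rsqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
    m = dY * (W + offset) if W is not None else dY
    return rstd * m - rstd * (1.0 / n_cols) * rstd * rstd * (m * X).sum(-1, keepdim=True) * X
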
+@triton.jit
+def _block_rms_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_rows,
+    n_cols,
+    eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
+
+    Reference:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
+    """
+
+    row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    row_mask = row_idx < n_rows
+    col_mask = col_offsets < n_cols
+
+    X_row = tl.load(
+        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+        mask=row_mask[:, None] & col_mask[None, :],
+        other=0,
+    )
+    X_row_dtype = X_row.dtype
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+    # On Llama, only rstd is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
+    if casting_mode == _CASTING_MODE_NONE:
+        eps = eps.to(X_row_dtype)
+        offset = offset.to(X_row_dtype)
+
+    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+    rstd = rsqrt(mean_square + eps)
+
+    # We can save time by caching rms with minimal memory overhead
+    # because rms is much smaller compared to X_row, as rms is for each row.
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+    X_row = X_row * rstd[:, None]
+
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(X_row_dtype)
+
+    tl.store(
+        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+        Y_row,
+        mask=row_mask[:, None] & col_mask[None, :],
+    )
+
+
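
The new block kernel amortizes launch overhead by normalizing BLOCK_ROW rows per program; the 2D offset construction is the key trick. An eager illustration of the addressing with toy sizes:

import torch

# row_idx[:, None] * row_stride + col_offsets[None, :] broadcasts to a
# (BLOCK_ROW, BLOCK_SIZE) tile of flat element offsets, so one program can
# load and normalize several rows at once.
BLOCK_ROW, BLOCK_SIZE, row_stride = 4, 8, 8
row_idx = torch.arange(BLOCK_ROW)
col_offsets = torch.arange(BLOCK_SIZE)
offsets = row_idx[:, None] * row_stride + col_offsets[None, :]
assert offsets.shape == (BLOCK_ROW, BLOCK_SIZE)
assert offsets[2, 3] == 2 * row_stride + 3
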
+@triton.jit
+def _block_rms_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
+    """
+
+    pid = tl.program_id(0).cast(tl.int64)
+    NUM_SMS = tl.num_programs(0)
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    col_mask = col_offsets < n_cols
+
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset
+
+    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+        row_idx = start + tl.arange(0, BLOCK_ROW)
+        row_mask = row_idx < n_rows
+        dY_row = tl.load(
+            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+        X_row = tl.load(
+            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+
+        # Get cached rms
+        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+        X_row = X_row.to(tl.float32)
+
+        # Different backward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
+
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
         else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += dY_row * (X_row * rstd_row)
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
 
-        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+        dX_row = rstd_row[:, None] * m
 
-        dY_ptr += dY_row_stride
-        dX_ptr += dX_row_stride
-        X_ptr += X_row_stride
-        RSTD_ptr += RSTD_row_stride
+        dX_row += (rstd_row[:, None]) * (
+            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+        )
+
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+        tl.store(
+            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+            dX_row,
+            mask=row_mask[:, None] & col_mask[None, :],
+        )
 
-    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
 
 
 _str_to_casting_mode = {
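
Unlike the forward kernel, the block backward kernel is persistent: the grid is sized to the device's SM count and each program strides over the rows, accumulating one partial dW row. A sketch of the row assignment (a hypothetical helper mirroring the kernel's loop bounds):

def rows_for_program(pid, num_programs, n_rows, block_row):
    # Mirrors: for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW)
    blocks = []
    for start in range(pid * block_row, n_rows, num_programs * block_row):
        blocks.append(range(start, min(start + block_row, n_rows)))
    return blocks
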
@@ -201,7 +405,7 @@ _str_to_casting_mode = {
 }
 
 
-def rms_norm_forward(X, W, eps, offset, casting_mode):
+def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     if not isinstance(casting_mode, int):
         assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
@@ -220,29 +424,64 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
-    # Check constraints.
-    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
-
-    _rms_norm_forward_kernel[(n_rows,)](
-        Y,
-        Y.stride(0),
-        X,
-        X.stride(0),
-        W,
-        W.stride(0),
-        RSTD,
-        RSTD.stride(0),
-        n_cols,
-        eps,
-        offset,
-        casting_mode,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-    )
+    if W is not None:
+        # Check constraints.
+        assert X.shape[1] == W.shape[0], (
+            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+        )
+        elementwise_affine = True
+    else:
+        elementwise_affine = False
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_forward_kernel[(n_rows,)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0) if elementwise_affine else 0,
+            RSTD,
+            RSTD.stride(0),
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            elementwise_affine=elementwise_affine,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0) if elementwise_affine else 0,
+            RSTD,
+            RSTD.stride(0),
+            n_rows,
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            elementwise_affine=elementwise_affine,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
-def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
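
The forward pass now dispatches between the two kernels with a size heuristic; restated as a standalone predicate (hypothetical name, same condition as above):

def use_row_per_program_kernel(block_size: int, n_rows: int, row_mode) -> bool:
    # Wide rows or small row counts stay on the one-row-per-program kernel;
    # row_mode forces it regardless of shape.
    return block_size > 256 or n_rows < 4096 * 8 or bool(row_mode)
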
@@ -252,10 +491,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     if X.device.type == "cuda":
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
-
-    # fp32 for numerical stability especially.
-    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()
+
+    if W is not None:
+        # fp32 for numerical stability especially.
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        elementwise_affine = True
+    else:
+        _dW = None
+        elementwise_affine = False
 
     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
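
The partial-dW buffer is sized by the number of concurrently resident programs, queried per backend. The branches above, collected into one helper; the final raise is an assumption for illustration, the original code simply falls through:

import torch

def query_sm_count(device: torch.device) -> int:
    if device.type == "cuda":
        return torch.cuda.get_device_properties(device).multi_processor_count
    if device.type == "xpu":
        return torch.xpu.get_device_properties(device).gpu_eu_count
    if device.type == "npu":
        from liger_kernel.utils import get_npu_multi_processor_count

        return get_npu_multi_processor_count()
    raise ValueError(f"unsupported device type: {device.type}")
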
@@ -267,30 +513,68 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     else:
         dX = torch.zeros_like(dY)
 
-    _rms_norm_backward_kernel[grid](
-        dY,
-        dY.stride(0),
-        dX,
-        dX.stride(0),
-        X,
-        X.stride(0),
-        torch_to_triton_dtype[X.dtype],
-        W,
-        W.stride(0),
-        RSTD,
-        RSTD.stride(0),
-        _dW,
-        _dW.stride(0),
-        n_rows,
-        n_cols,
-        offset,
-        rows_per_program,
-        casting_mode,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-    )
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0) if elementwise_affine else 0,
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0) if elementwise_affine else 0,
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            elementwise_affine=elementwise_affine,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0) if elementwise_affine else 0,
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0) if elementwise_affine else 0,
+            n_rows,
+            n_cols,
+            offset,
+            casting_mode,
+            elementwise_affine=elementwise_affine,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     dX = dX.view(*shape)
-    dW = _dW.sum(dim=0).to(W.dtype)
+
+    if elementwise_affine:
+        dW = _dW.sum(dim=0).to(W.dtype)
+    else:
+        dW = None
 
     return dX, dW
 
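
The weight gradient is thus produced in two stages: each program accumulates one fp32 partial row of dW, and the host sums the rows and casts back to W.dtype. In eager terms, with illustrative sizes:

import torch

sm_count, n_cols = 108, 4096             # illustrative values
_dW = torch.randn(sm_count, n_cols)      # stand-in for the per-program fp32 partials
dW = _dW.sum(dim=0).to(torch.bfloat16)   # host-side reduction, cast back to W.dtype
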
@@ -319,18 +603,30 @@ class LigerRMSNormFunction(torch.autograd.Function):
 
     @staticmethod
     @ensure_contiguous
-    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
         """
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
+        if isinstance(X, torch.distributed.tensor.DTensor):
+            # The input tensor is the output of a tensor parallel module and
+            # needs to be gathered to a local tensor to compute
+            # RMS norm on each TP worker.
+            # TODO: support CP.
+            X = X.full_tensor()
+
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
         ctx.in_place = in_place
+        ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-        ctx.save_for_backward(X, W, RSTD)
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
         return Y
 
     @staticmethod
@@ -339,16 +635,19 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         Y: (B, T, H) or (BxT, H)
         """
-        X, W, RSTD = ctx.saved_tensors
+        if ctx.elementwise_affine:
+            X, W, RSTD = ctx.saved_tensors
+        else:
+            X, RSTD = ctx.saved_tensors
+            W = None
+
+        if isinstance(dY, torch.distributed.tensor.DTensor):
+            # Gradients are the output of a tensor parallel module and
+            # need to be gathered to a local tensor for the RMS norm backward.
+            # TODO: support CP.
+            dY = dY.full_tensor()
+
         dX, dW = rms_norm_backward(
-            dY,
-            X,
-            W,
-            RSTD,
-            ctx.offset,
-            ctx.casting_mode,
-            ctx.BLOCK_SIZE,
-            ctx.num_warps,
-            ctx.in_place,
+            dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
-        return dX, dW, None, None, None, None
+        return dX, dW, None, None, None, None, None
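
With the extended signature the function can be driven directly; a hedged usage sketch (CUDA device assumed; passing W=None exercises the new no-affine path and yields dW=None):

import torch

from liger_kernel.ops.rms_norm import LigerRMSNormFunction

x = torch.randn(4, 16, 64, device="cuda", requires_grad=True)
w = torch.randn(64, device="cuda", requires_grad=True)

# forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None)
y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False, None)
y.sum().backward()  # populates x.grad and w.grad via the fused backward
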
liger_kernel/ops/rope.py CHANGED
@@ -32,7 +32,7 @@ def _triton_rope(
 
     # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
     # stride: (seq_len * head_dim, head_dim, 1)
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
     # locate start address
     q_ptr = q_ptr + pid * q_row_stride
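
This is the same int64 cast applied throughout rms_norm.py: with 32-bit program ids, pid * q_row_stride wraps once a tensor exceeds 2^31 elements. The arithmetic, with illustrative sizes:

# With int32 offsets, the largest safe flat index is 2**31 - 1 = 2147483647.
n_rows, row_stride = 1_048_576, 4096    # about 4.3e9 elements in total
max_offset = (n_rows - 1) * row_stride  # 4294963200 > 2**31 - 1
assert max_offset > 2**31 - 1           # would overflow without the int64 cast
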