liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
@@ -17,12 +17,10 @@ import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import (
-     calculate_settings,
-     compare_version,
-     ensure_contiguous,
-     torch_to_triton_dtype,
- )
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import torch_to_triton_dtype

  if compare_version("triton", operator.ge, "3.0.0"):
      try:
@@ -35,9 +33,9 @@ else:
      from triton.language.math import rsqrt


- _CASTING_MODE_NONE = tl.constexpr(-1)
- _CASTING_MODE_LLAMA = tl.constexpr(0)
- _CASTING_MODE_GEMMA = tl.constexpr(1)
+ _CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
+ _CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
+ _CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)


  @triton.jit
@@ -65,7 +63,7 @@ def _rms_norm_forward_kernel(
      3. https://arxiv.org/pdf/1910.07467
      """

-     row_idx = tl.program_id(0)
+     row_idx = tl.program_id(0).to(tl.int64)
      col_offsets = tl.arange(0, BLOCK_SIZE)
      mask = col_offsets < n_cols

@@ -116,6 +114,8 @@ def _rms_norm_forward_kernel(
  def _rms_norm_backward_kernel(
      dY_ptr,
      dY_row_stride,
+     dX_ptr,
+     dX_row_stride,
      X_ptr,
      X_row_stride,
      X_dtype: tl.constexpr,
@@ -137,7 +137,7 @@ def _rms_norm_backward_kernel(
      dw = sum(dy * (x / RMS)). summation over BxT dimension
      """

-     row_block_id = tl.program_id(0)
+     row_block_id = tl.program_id(0).to(tl.int64)
      row_start = row_block_id * rows_per_program
      row_end = min((row_block_id + 1) * rows_per_program, n_rows)
      col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -146,6 +146,8 @@ def _rms_norm_backward_kernel(
      dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

      dY_ptr += row_start * dY_row_stride
+     dX_ptr += row_start * dX_row_stride
+
      X_ptr += row_start * X_row_stride
      RSTD_ptr += row_start

@@ -173,9 +175,7 @@ def _rms_norm_backward_kernel(

          dX_row = rstd_row * m

-         dX_row += (rstd_row) * (
-             -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
-         )
+         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

          # calculate the gradient of W
          if casting_mode == _CASTING_MODE_LLAMA:
@@ -184,15 +184,185 @@ def _rms_norm_backward_kernel(
              # here X_row is already in fp32 (see previous if block)
              dW_row += dY_row * (X_row * rstd_row)

-         tl.store(dY_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+         tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)

          dY_ptr += dY_row_stride
+         dX_ptr += dX_row_stride
          X_ptr += X_row_stride
          RSTD_ptr += RSTD_row_stride

      tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


+ @triton.jit
+ def _block_rms_norm_forward_kernel(
+     Y_ptr,
+     Y_row_stride,
+     X_ptr,
+     X_row_stride,
+     W_ptr,
+     W_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
+     n_rows,
+     n_cols,
+     eps,
+     offset,
+     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+     BLOCK_SIZE: tl.constexpr,
+     BLOCK_ROW: tl.constexpr,
+ ):
+     """
+     y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
+
+     Reference:
+     1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+     2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+     3. https://arxiv.org/pdf/1910.07467
+     """
+
+     row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     row_mask = row_idx < n_rows
+     col_mask = col_offsets < n_cols
+
+     X_row = tl.load(
+         X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+         mask=row_mask[:, None] & col_mask[None, :],
+         other=0,
+     )
+     X_row_dtype = X_row.dtype
+     W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+     # On Llama, only rstd is computed on fp32
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(tl.float32)
+
+     # Gemma computes everything on fp32, and then casts back the output to the original dtype
+     if casting_mode == _CASTING_MODE_GEMMA:
+         W_row = W_row.to(tl.float32)
+         X_row = X_row.to(tl.float32)
+
+     if casting_mode == _CASTING_MODE_NONE:
+         eps = eps.to(X_row_dtype)
+         offset = offset.to(X_row_dtype)
+
+     mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+     rstd = rsqrt(mean_square + eps)
+
+     # We can save time by caching rms with minimal memory overhead
+     # because rms is much smaller compared to X_row, as rms is for each row.
+     # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
+     tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+     X_row = X_row * rstd[:, None]
+
+     # On Llama, the multiplication with the weight is done on the original dtype
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(X_row_dtype)
+
+     Y_row = X_row * (offset + W_row)[None, :]
+
+     if casting_mode == _CASTING_MODE_GEMMA:
+         Y_row = Y_row.to(X_row_dtype)
+
+     tl.store(
+         Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+         Y_row,
+         mask=row_mask[:, None] & col_mask[None, :],
+     )
+
+
+ @triton.jit
+ def _block_rms_norm_backward_kernel(
+     dY_ptr,
+     dY_row_stride,
+     dX_ptr,
+     dX_row_stride,
+     X_ptr,
+     X_row_stride,
+     X_dtype: tl.constexpr,
+     W_ptr,
+     W_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
+     dW_ptr,
+     dW_row_stride,
+     n_rows,
+     n_cols,
+     offset,
+     rows_per_program: tl.constexpr,
+     casting_mode: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     BLOCK_ROW: tl.constexpr,
+ ):
+     """
+     dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+     dw = sum(dy * (x / RMS)). summation over BxT dimension
+     """
+
+     pid = tl.program_id(0).cast(tl.int64)
+     NUM_SMS = tl.num_programs(0)
+
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     col_mask = col_offsets < n_cols
+
+     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+     W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+     W_row = W_row + offset
+
+     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+         row_idx = start + tl.arange(0, BLOCK_ROW)
+         row_mask = row_idx < n_rows
+         dY_row = tl.load(
+             dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+             mask=row_mask[:, None] & col_mask[None, :],
+             other=0.0,
+         )
+         X_row = tl.load(
+             X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+             mask=row_mask[:, None] & col_mask[None, :],
+             other=0.0,
+         )
+
+         # Get cached rms
+         rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+         X_row = X_row.to(tl.float32)
+
+         # Different backward graphs for different casting modes
+         if casting_mode == _CASTING_MODE_LLAMA:
+             m = (dY_row * W_row[None, :]).to(tl.float32)
+
+         elif casting_mode == _CASTING_MODE_GEMMA:
+             dY_row = dY_row.to(tl.float32)
+             m = dY_row * W_row[None, :]
+         else:
+             m = dY_row * W_row[None, :]
+
+         dX_row = rstd_row[:, None] * m
+
+         dX_row += (rstd_row[:, None]) * (
+             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+         )
+
+         # calculate the gradient of W
+         if casting_mode == _CASTING_MODE_LLAMA:
+             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
+         else:
+             # here X_row is already in fp32 (see previous if block)
+             dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+         tl.store(
+             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+             dX_row,
+             mask=row_mask[:, None] & col_mask[None, :],
+         )
+
+     tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
+
  _str_to_casting_mode = {
      "llama": _CASTING_MODE_LLAMA.value,
      "gemma": _CASTING_MODE_GEMMA.value,
@@ -200,16 +370,12 @@ _str_to_casting_mode = {
  }


- def rms_norm_forward(X, W, eps, offset, casting_mode):
+ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
      if not isinstance(casting_mode, int):
-         assert (
-             casting_mode in _str_to_casting_mode
-         ), f"Invalid casting mode: {casting_mode}"
+         assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
          casting_mode = _str_to_casting_mode[casting_mode]
      else:
-         assert (
-             casting_mode in _str_to_casting_mode.values()
-         ), f"Invalid casting mode: {casting_mode}"
+         assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"

      shape = X.shape
      dim = shape[-1]
@@ -220,44 +386,70 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
      Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
      # RSTD is to cache rstd for each row
      # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
-     rstd_dtype = (
-         torch.float32
-         if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
-         else X.dtype
-     )
+     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
      RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

      # Check constraints.
-     assert (
-         X.shape[1] == W.shape[0]
-     ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
-
-     _rms_norm_forward_kernel[(n_rows,)](
-         Y,
-         Y.stride(0),
-         X,
-         X.stride(0),
-         W,
-         W.stride(0),
-         RSTD,
-         RSTD.stride(0),
-         n_cols,
-         eps,
-         offset,
-         casting_mode,
-         BLOCK_SIZE=BLOCK_SIZE,
-         num_warps=num_warps,
-     )
+     assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+         _rms_norm_forward_kernel[(n_rows,)](
+             Y,
+             Y.stride(0),
+             X,
+             X.stride(0),
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             n_cols,
+             eps,
+             offset,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
+     else:
+         BLOCK_ROW = 16
+         kernel_args["BLOCK_ROW"] = BLOCK_ROW
+         _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+             Y,
+             Y.stride(0),
+             X,
+             X.stride(0),
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             n_rows,
+             n_cols,
+             eps,
+             offset,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
      return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


- def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
      shape = dY.shape
      dim = shape[-1]
      dY = dY.view(-1, dim)
      n_rows, n_cols = dY.shape

-     sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+     sm_count = 1
+     if X.device.type == "cuda":
+         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+     elif X.device.type == "xpu":
+         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
      # fp32 for numerical stability especially.
      _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

@@ -265,29 +457,70 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
          raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
      rows_per_program = math.ceil(n_rows / sm_count)
      grid = (sm_count,)
-     # Here we use dY to store the value of dX to save memory
-     _rms_norm_backward_kernel[grid](
-         dY,
-         dY.stride(0),
-         X,
-         X.stride(0),
-         torch_to_triton_dtype[X.dtype],
-         W,
-         W.stride(0),
-         RSTD,
-         RSTD.stride(0),
-         _dW,
-         _dW.stride(0),
-         n_rows,
-         n_cols,
-         offset,
-         rows_per_program,
-         casting_mode,
-         BLOCK_SIZE=BLOCK_SIZE,
-         num_warps=num_warps,
-     )
-     dX = dY.view(*shape)
+
+     if in_place is True:
+         dX = dY
+     else:
+         dX = torch.zeros_like(dY)
+
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+
+     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+         _rms_norm_backward_kernel[grid](
+             dY,
+             dY.stride(0),
+             dX,
+             dX.stride(0),
+             X,
+             X.stride(0),
+             torch_to_triton_dtype[X.dtype],
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             _dW,
+             _dW.stride(0),
+             n_rows,
+             n_cols,
+             offset,
+             rows_per_program,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
+     else:
+         BLOCK_ROW = 16
+         kernel_args["BLOCK_ROW"] = BLOCK_ROW
+         _block_rms_norm_backward_kernel[grid](
+             dY,
+             dY.stride(0),
+             dX,
+             dX.stride(0),
+             X,
+             X.stride(0),
+             torch_to_triton_dtype[X.dtype],
+             W,
+             W.stride(0),
+             RSTD,
+             RSTD.stride(0),
+             _dW,
+             _dW.stride(0),
+             n_rows,
+             n_cols,
+             offset,
+             rows_per_program,
+             casting_mode,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+             **kernel_args,  # XPU-specific optimization
+         )
+     dX = dX.view(*shape)
      dW = _dW.sum(dim=0).to(W.dtype)
+
      return dX, dW


@@ -307,20 +540,24 @@ class LigerRMSNormFunction(torch.autograd.Function):
      - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
      - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
      - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+
+     `in_place` option means whether to modify dY in place to store dX. This defaults to `True` to save memory. However, in certain cases, it can produce incorrect inputs.
+     For example, gemma2 uses two rmsnorm sequentially with residual in between. The residual part needs dY, so it cannot be modified in place.
+     Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`.
      """

      @staticmethod
      @ensure_contiguous
-     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
+     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
          """
          X: (B, T, H) or (BxT, H)
          W: (H,)
          """
-         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
-             X, W, eps, offset, casting_mode
-         )
+         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
          ctx.offset = offset
          ctx.casting_mode = casting_mode
+         ctx.in_place = in_place
+         ctx.row_mode = row_mode
          ctx.BLOCK_SIZE = BLOCK_SIZE
          ctx.num_warps = num_warps
          ctx.save_for_backward(X, W, RSTD)
@@ -334,13 +571,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
          """
          X, W, RSTD = ctx.saved_tensors
          dX, dW = rms_norm_backward(
-             dY,
-             X,
-             W,
-             RSTD,
-             ctx.offset,
-             ctx.casting_mode,
-             ctx.BLOCK_SIZE,
-             ctx.num_warps,
+             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
          )
-         return dX, dW, None, None, None
+         return dX, dW, None, None, None, None, None
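
The rms_norm.py changes above add `in_place` and `row_mode` arguments to `LigerRMSNormFunction.forward` and dispatch to the new `_block_rms_norm_*` kernels only when `BLOCK_SIZE <= 256`, `n_rows >= 4096 * 8`, and `row_mode` is falsy; otherwise the original per-row kernels are used. A minimal sketch of how the updated autograd function could be exercised directly, assuming a CUDA (or XPU) device with Triton installed; the shapes and hyperparameters below are illustrative, not taken from the package:

import torch

from liger_kernel.ops.rms_norm import LigerRMSNormFunction

# Illustrative (B, T, H) hidden states and (H,) RMSNorm weight.
X = torch.randn(2, 128, 4096, dtype=torch.bfloat16, device="cuda", requires_grad=True)
W = torch.ones(4096, dtype=torch.bfloat16, device="cuda")

# Positional args mirror the signature shown in the diff:
# forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None)
Y = LigerRMSNormFunction.apply(X, W, 1e-6, 0.0, "llama", True, None)

# With in_place=True the backward pass reuses dY's storage for dX. Per the new
# docstring, callers whose residual branch still needs dY (e.g. gemma2's paired
# RMSNorms) should pass in_place=False so dX is written to a fresh buffer.
Y_safe = LigerRMSNormFunction.apply(X, W, 1e-6, 0.0, "llama", False, None)
Y_safe.sum().backward()

Transformer integrations normally reach this code through the module wrapper in liger_kernel/transformers/rms_norm.py rather than calling the function class directly.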
liger_kernel/ops/rope.py CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
      sin_row_stride,
      sl,
      bs: tl.constexpr,
+     cos_bs: tl.constexpr,
      n_qh: tl.constexpr,
      n_kh: tl.constexpr,
      hd: tl.constexpr,
@@ -29,9 +30,9 @@ def _triton_rope(
      # k size: (bsz, seq_len, num_kv_heads, head_dim)
      # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)

-     # cos size: (1, seq_len, head_dim)
+     # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
      # stride: (seq_len * head_dim, head_dim, 1)
-     pid = tl.program_id(0)
+     pid = tl.program_id(0).to(tl.int64)

      # locate start address
      q_ptr = q_ptr + pid * q_row_stride
@@ -48,9 +49,19 @@ def _triton_rope(
      # and pid % sl to get the sequence index.
      # 2. We only need the left half of cos and sin matrix because the right half is just
      # a clone of the left half.
-     cos_row_idx = pid % (sl)
-     cos = cos + cos_row_idx * cos_row_stride
-     sin = sin + cos_row_idx * sin_row_stride
+     batch_idx = pid // sl
+     cos_row_idx = pid % sl
+     cos = cos + tl.where(
+         cos_bs == 1,
+         cos_row_idx * cos_row_stride,
+         batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
+     )
+     sin = sin + tl.where(
+         cos_bs == 1,
+         cos_row_idx * sin_row_stride,
+         batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
+     )
+
      cos_offsets = tl.arange(0, pad_hd // 2)
      cos_mask = cos_offsets < hd // 2
      cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -61,36 +72,20 @@ def _triton_rope(
      # program instance (i.e. for the current token) separately
      # ####################################################################
      # left half of the head
-     first_half_q_offsets = (
-         tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-     )
-     first_half_k_offsets = (
-         tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-     )
-     first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-         tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-     )
-     first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-         tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-     )
-     q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-         sin_row.dtype
-     )
-     k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-         sin_row.dtype
-     )
+     first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+     first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+     first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+     first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+     q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+     k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

      # right half of the head
      second_half_q_offsets = first_half_q_offsets + (hd // 2)
      second_half_k_offsets = first_half_k_offsets + (hd // 2)
      second_q_mask = first_q_mask
      second_k_mask = first_k_mask
-     q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-         sin_row.dtype
-     )
-     k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-         sin_row.dtype
-     )
+     q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+     k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

      if not BACKWARD_PASS:
          # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -118,7 +113,6 @@ def _triton_rope(


  def rope_forward(q, k, cos, sin):
-
      # transpose it back to the physical shape because Triton looks at the physical storage
      # note: q and k are incontiguous before the transformation and will become contiguous after transpose
      q = q.transpose(1, 2)
@@ -138,6 +132,7 @@ def rope_forward(q, k, cos, sin):
      k = k.contiguous()
      cos = cos.contiguous()
      sin = sin.contiguous()
+     cos_batch_size = cos.shape[0]

      _triton_rope[(n_row,)](
          q,
@@ -150,6 +145,7 @@ def rope_forward(q, k, cos, sin):
          sin.stride(-2),
          seq_len,
          batch_size,
+         cos_batch_size,
          n_q_head,
          n_kv_head,
          head_dim,
@@ -167,6 +163,7 @@ def rope_backward(dq, dk, cos, sin):
      dk = dk.transpose(1, 2)

      batch_size, seq_len, n_q_head, head_dim = dq.shape
+     cos_batch_size = cos.shape[0]
      n_kv_head = dk.shape[2]
      pad_hd = triton.next_power_of_2(head_dim)
      pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +188,7 @@ def rope_backward(dq, dk, cos, sin):
          sin.stride(-2),
          seq_len,
          batch_size,
+         cos_batch_size,
          n_q_head,
          n_kv_head,
          head_dim,
@@ -221,8 +219,8 @@ class LigerRopeFunction(torch.autograd.Function):
          """
          q size: (bsz, n_q_head, seq_len, head_dim)
          k size: (bsz, n_kv_head, seq_len, head_dim)
-         cos size: (1, seq_len, head_dim)
-         sin size: (1, seq_len, head_dim)
+         cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
          """
          q, k, cos, sin = rope_forward(q, k, cos, sin)
          ctx.save_for_backward(cos, sin)
@@ -232,8 +230,8 @@ class LigerRopeFunction(torch.autograd.Function):
          """
          dq size: (bsz, n_q_head, seq_len, head_dim)
          dk size: (bsz, n_kv_head, seq_len, head_dim)
-         cos size: (1, seq_len, head_dim)
-         sin size: (1, seq_len, head_dim)
+         cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+         sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
          """

          cos, sin = ctx.saved_tensors
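
The rope.py changes thread a `cos_bs` batch size through `_triton_rope` so that `cos`/`sin` may be either broadcast with shape `(1, seq_len, head_dim)` or per-batch with shape `(bsz, seq_len, head_dim)`. The sketch below mirrors the kernel's new row-selection arithmetic in plain PyTorch; `select_cos_row` is a hypothetical helper for illustration only, not part of liger_kernel:

import torch

def select_cos_row(cos: torch.Tensor, pid: int, seq_len: int) -> torch.Tensor:
    # Each Triton program handles one (batch, seq) position; pid runs over
    # bsz * seq_len rows, exactly as in _triton_rope.
    batch_idx = pid // seq_len
    cos_row_idx = pid % seq_len
    if cos.shape[0] == 1:
        # Broadcast case (cos_bs == 1): every sample shares one table,
        # i.e. offset = cos_row_idx * cos_row_stride.
        return cos[0, cos_row_idx]
    # Batched case: offset = batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride.
    return cos[batch_idx, cos_row_idx]

bsz, seq_len, head_dim = 2, 8, 64
shared = torch.randn(1, seq_len, head_dim)       # previously the only supported layout
per_batch = torch.randn(bsz, seq_len, head_dim)  # layout enabled by the cos_bs argument

pid = 1 * seq_len + 3  # program for batch index 1, sequence position 3
assert torch.equal(select_cos_row(shared, pid, seq_len), shared[0, 3])
assert torch.equal(select_cos_row(per_batch, pid, seq_len), per_batch[1, 3])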