liger-kernel-nightly 0.5.10.dev20250528223524__py3-none-any.whl → 0.5.10.dev20250601024230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -193,6 +193,7 @@ def _rms_norm_backward_kernel(
193
193
 
194
194
  tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
195
195
 
196
+
196
197
  @triton.jit
197
198
  def _block_rms_norm_forward_kernel(
198
199
  Y_ptr,
@@ -225,8 +226,11 @@ def _block_rms_norm_forward_kernel(
225
226
  row_mask = row_idx < n_rows
226
227
  col_mask = col_offsets < n_cols
227
228
 
228
-
229
- X_row = tl.load(X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], mask=row_mask[:, None] & col_mask[None, :] , other=0)
229
+ X_row = tl.load(
230
+ X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
231
+ mask=row_mask[:, None] & col_mask[None, :],
232
+ other=0,
233
+ )
230
234
  X_row_dtype = X_row.dtype
231
235
  W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
232
236
 
@@ -262,7 +266,12 @@ def _block_rms_norm_forward_kernel(
262
266
  if casting_mode == _CASTING_MODE_GEMMA:
263
267
  Y_row = Y_row.to(X_row_dtype)
264
268
 
265
- tl.store(Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :], Y_row, mask=row_mask[:, None] & col_mask[None, :])
269
+ tl.store(
270
+ Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
271
+ Y_row,
272
+ mask=row_mask[:, None] & col_mask[None, :],
273
+ )
274
+
266
275
 
267
276
  @triton.jit
268
277
  def _block_rms_norm_backward_kernel(
@@ -306,8 +315,16 @@ def _block_rms_norm_backward_kernel(
306
315
  for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
307
316
  row_idx = start + tl.arange(0, BLOCK_ROW)
308
317
  row_mask = row_idx < n_rows
309
- dY_row = tl.load(dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :], mask=row_mask[:, None] & col_mask[None, :], other=0.0)
310
- X_row = tl.load(X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :], mask=row_mask[:, None] & col_mask[None, :], other=0.0)
318
+ dY_row = tl.load(
319
+ dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
320
+ mask=row_mask[:, None] & col_mask[None, :],
321
+ other=0.0,
322
+ )
323
+ X_row = tl.load(
324
+ X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
325
+ mask=row_mask[:, None] & col_mask[None, :],
326
+ other=0.0,
327
+ )
311
328
 
312
329
  # Get cached rms
313
330
  rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
@@ -326,7 +343,9 @@ def _block_rms_norm_backward_kernel(
326
343
 
327
344
  dX_row = rstd_row[:, None] * m
328
345
 
329
- dX_row += (rstd_row[:, None]) * (-(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row)
346
+ dX_row += (rstd_row[:, None]) * (
347
+ -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
348
+ )
330
349
 
331
350
  # calculate the gradient of W
332
351
  if casting_mode == _CASTING_MODE_LLAMA:
@@ -335,8 +354,11 @@ def _block_rms_norm_backward_kernel(
335
354
  # here X_row is already in fp32 (see previous if block)
336
355
  dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
337
356
 
338
- tl.store(dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :], dX_row, mask=row_mask[:, None] & col_mask[None, :])
339
-
357
+ tl.store(
358
+ dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
359
+ dX_row,
360
+ mask=row_mask[:, None] & col_mask[None, :],
361
+ )
340
362
 
341
363
  tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
342
364
 
@@ -549,15 +571,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
549
571
  """
550
572
  X, W, RSTD = ctx.saved_tensors
551
573
  dX, dW = rms_norm_backward(
552
- dY,
553
- X,
554
- W,
555
- RSTD,
556
- ctx.offset,
557
- ctx.casting_mode,
558
- ctx.BLOCK_SIZE,
559
- ctx.num_warps,
560
- ctx.in_place,
561
- ctx.row_mode
574
+ dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
562
575
  )
563
576
  return dX, dW, None, None, None, None, None
@@ -776,7 +776,8 @@ def apply_liger_kernel_to_gemma3_text(
776
776
 
777
777
  from transformers.models.gemma3 import modeling_gemma3
778
778
  from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
779
- from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM, Gemma3TextModel
779
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
780
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
780
781
 
781
782
  from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
782
783
  from liger_kernel.transformers.model.gemma3 import causal_forward
@@ -37,7 +37,7 @@ class LigerRMSNorm(nn.Module):
37
37
  self.offset,
38
38
  self.casting_mode,
39
39
  self.in_place,
40
- self.row_mode
40
+ self.row_mode,
41
41
  )
42
42
 
43
43
  def extra_repr(self):
@@ -4,7 +4,7 @@ import torch.nn as nn
4
4
  from liger_kernel.ops.softmax import LigerSoftmaxFunction
5
5
 
6
6
 
7
- class LigerKernelSoftmax(nn.Module):
7
+ class LigerSoftmax(nn.Module):
8
8
  def __init__(self):
9
9
  super().__init__()
10
10
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.10.dev20250528223524
3
+ Version: 0.5.10.dev20250601024230
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -325,6 +325,8 @@ loss.backward()
325
325
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
326
326
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
327
327
  | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
328
+ | Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
329
+ | Softmax | `liger_kernel.transformers.LigerSoftmax` |
328
330
  | Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
329
331
 
330
332
 
@@ -28,7 +28,7 @@ liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,87
28
28
  liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
29
29
  liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
30
30
  liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
31
- liger_kernel/ops/rms_norm.py,sha256=IDj_V3hwo6tm3FijVbRh6ebUj2A3591MNkMer_gncdM,18749
31
+ liger_kernel/ops/rms_norm.py,sha256=-rcgHwWCxlA-Syec2XhdW4jfOeCDt2r7qwjslgXFYDU,18865
32
32
  liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
33
33
  liger_kernel/ops/softmax.py,sha256=tgORx6MK1IDDtZKqGarj0IPIVjqAIEUXXYPiinhRdtI,5864
34
34
  liger_kernel/ops/sparsemax.py,sha256=AeWe1xgkHJFEKWTj2vu_0hj7LztGvjqXAps-QTpCY0U,5087
@@ -52,12 +52,12 @@ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-
52
52
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
53
53
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
54
54
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
55
- liger_kernel/transformers/monkey_patch.py,sha256=a0CXSC8BwZg3vok-ns0udZLUOBkegGQgPDod3H8ilP4,74610
55
+ liger_kernel/transformers/monkey_patch.py,sha256=A91QWjMG7d7302lx-Djjxd_VgwBhYwxAYa1davBFCjU,74668
56
56
  liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
57
57
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
58
- liger_kernel/transformers/rms_norm.py,sha256=srMS4jdkMCjY4Yqj9jjsy_IkY8KlHdTPLOx4069ZACA,1277
58
+ liger_kernel/transformers/rms_norm.py,sha256=QimExM27kYoAnaZqxb_8mBaUcd72-X01DviJ1dQd55I,1278
59
59
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
60
- liger_kernel/transformers/softmax.py,sha256=u7bFo35-cjaAm9of6-DLzmkaNFELOM-9AgyrcvUPifw,270
60
+ liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
61
61
  liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
62
62
  liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
63
63
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
@@ -86,9 +86,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
86
86
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
87
87
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
88
88
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
89
- liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
90
- liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/METADATA,sha256=XqzBAk8PxwjhEYwf_3Xw0sssbGSM3IWW9z3NWlsZ7ZU,24113
91
- liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
92
- liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
93
- liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
94
- liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/RECORD,,
89
+ liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
90
+ liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/METADATA,sha256=p4YDg6nRS2Zh3pCFi_dj1Yl7DtEi5U3bciMTtrcY-1U,24309
91
+ liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
92
+ liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
93
+ liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
94
+ liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/RECORD,,