liger-kernel-nightly 0.5.10.dev20250602014906__py3-none-any.whl → 0.5.10.dev20250602134913__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
--- liger_kernel/transformers/monkey_patch.py
+++ liger_kernel/transformers/monkey_patch.py
@@ -627,8 +627,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
-    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
     _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
@@ -701,7 +701,8 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -780,8 +781,8 @@ def apply_liger_kernel_to_gemma3_text(
     from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
     from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
 
-    from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
 
     _patch_rms_norm_module_for_gemma3 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
@@ -1451,11 +1452,12 @@ def apply_liger_kernel_to_olmo2(
     from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
 
     from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
 
     if rope:
         modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
+        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
     if swiglu:
         modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
     if cross_entropy:
@@ -1514,11 +1516,12 @@ def apply_liger_kernel_to_glm4(
     from transformers.models.glm4.modeling_glm4 import Glm4Model
 
     from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
 
     if rope:
         raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
     if rms_norm:
-        modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
+        modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
     if swiglu:
         modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
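The monkey patches above no longer construct per-model norm classes inline with functools.partial; they import dedicated LigerRMSNorm subclasses and rebind the modeling module's attribute (e.g. modeling_olmo2.Olmo2RMSNorm) to them directly. The sketch below is a minimal illustration of that difference using stand-in names and simplified signatures, not the library's actual code: a partial object is a callable but not a class, whereas a subclass with baked-in defaults remains a real nn.Module class.

    # Minimal sketch (stand-in names, simplified signatures): how rebinding a
    # modeling-module class attribute to a subclass differs from using partial.
    from functools import partial

    import torch
    import torch.nn as nn


    class ToyRMSNorm(nn.Module):
        # Stand-in for LigerRMSNorm with a reduced constructor.
        def __init__(self, hidden_size, eps=1e-6, offset=0.0):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps
            self.offset = offset


    # A partial is callable but not a class: isinstance/issubclass checks and
    # code that inspects __name__ or __mro__ on the patched attribute break.
    ToyRMSNormWithPartial = partial(ToyRMSNorm, offset=1.0)
    print(isinstance(ToyRMSNormWithPartial, type))  # False


    class ToyRMSNormForGemma(ToyRMSNorm):
        # Baking the per-model defaults into a subclass keeps a real class.
        def __init__(self, hidden_size, eps=1e-6, offset=1.0):
            super().__init__(hidden_size, eps, offset)


    print(issubclass(ToyRMSNormForGemma, nn.Module))  # True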
--- liger_kernel/transformers/rms_norm.py
+++ liger_kernel/transformers/rms_norm.py
@@ -44,3 +44,38 @@ class LigerRMSNorm(nn.Module):
         return (
             f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
         )
+
+
+class LigerRMSNormForGemma(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma3(LigerRMSNorm):
+    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
+
+
+class LigerRMSNormForOlmo2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGlm4(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
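The new subclasses added to rms_norm.py above only freeze per-model defaults (offset, casting_mode, init_fn, in_place) on top of LigerRMSNorm. A hypothetical usage sketch follows; it assumes torch and liger_kernel are installed and a CUDA device is available (the kernels are Triton-based), and the hidden size and tensor shape are illustrative only.

    import torch

    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma

    # Gemma defaults baked in: offset=1.0, casting_mode="gemma", init_fn="zeros".
    norm = LigerRMSNormForGemma(hidden_size=2048).to("cuda")
    x = torch.randn(2, 16, 2048, device="cuda", dtype=torch.bfloat16)
    y = norm(x)        # output has the same shape as the input
    print(y.shape)     # torch.Size([2, 16, 2048])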
--- liger_kernel_nightly-0.5.10.dev20250602014906.dist-info/METADATA
+++ liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.10.dev20250602014906
+Version: 0.5.10.dev20250602134913
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
--- liger_kernel_nightly-0.5.10.dev20250602014906.dist-info/RECORD
+++ liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/RECORD
@@ -46,16 +46,15 @@ liger_kernel/transformers/functional.py,sha256=QmnAFpRgIbp9Rzlfp8QibwiEbf5BUcANx
 liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
-liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
 liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
 liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
 liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
-liger_kernel/transformers/monkey_patch.py,sha256=8FC4hfIfIF2I8-t-kz6ObXXVlvzIJ4_WT2o3z61Ibcg,74663
+liger_kernel/transformers/monkey_patch.py,sha256=drSPROAsiphnLPl2lFPBUvG_u4oIufDKEklqnDyy0vY,74584
 liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
-liger_kernel/transformers/rms_norm.py,sha256=QimExM27kYoAnaZqxb_8mBaUcd72-X01DviJ1dQd55I,1278
+liger_kernel/transformers/rms_norm.py,sha256=eErIr1n-13oVrc1VJY07lqazYelw_vlu9Az__RmXPSE,2717
 liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
 liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
 liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
@@ -86,9 +85,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.10.dev20250602014906.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.10.dev20250602014906.dist-info/METADATA,sha256=bKd1gwkECC62Zf6H8YYs0SviZ5vn9wZ-f72dSF94wg8,24309
-liger_kernel_nightly-0.5.10.dev20250602014906.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.10.dev20250602014906.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-liger_kernel_nightly-0.5.10.dev20250602014906.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.10.dev20250602014906.dist-info/RECORD,,
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/METADATA,sha256=TewpbE_T3k_gTii2lgoBpvrzywyhd7f-xZl2kfbEYTc,24309
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/RECORD,,
--- liger_kernel/transformers/gema3_rms.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .rms_norm import LigerRMSNorm
-
-
-class LigerRMSNormForGemma3(LigerRMSNorm):
-    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
-
-    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
-        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)