liger-kernel-nightly 0.5.10.dev20250602014906__py3-none-any.whl → 0.5.10.dev20250602134913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/transformers/monkey_patch.py +9 -6
- liger_kernel/transformers/rms_norm.py +35 -0
- {liger_kernel_nightly-0.5.10.dev20250602014906.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250602014906.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/RECORD +8 -9
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel_nightly-0.5.10.dev20250602014906.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250602014906.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250602014906.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250602014906.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/top_level.txt +0 -0
@@ -627,8 +627,8 @@ def apply_liger_kernel_to_gemma(
|
|
627
627
|
from transformers.models.gemma import modeling_gemma
|
628
628
|
from transformers.models.gemma.modeling_gemma import GemmaModel
|
629
629
|
|
630
|
-
|
631
|
-
|
630
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
|
631
|
+
|
632
632
|
_patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
|
633
633
|
|
634
634
|
if rope:
|
@@ -701,7 +701,8 @@ def apply_liger_kernel_to_gemma2(
|
|
701
701
|
from transformers.models.gemma2 import modeling_gemma2
|
702
702
|
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
|
703
703
|
|
704
|
-
|
704
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
|
705
|
+
|
705
706
|
_patch_rms_norm_module_for_gemma2 = partial(
|
706
707
|
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
|
707
708
|
)
|
@@ -780,8 +781,8 @@ def apply_liger_kernel_to_gemma3_text(
|
|
780
781
|
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
|
781
782
|
from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
|
782
783
|
|
783
|
-
from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
|
784
784
|
from liger_kernel.transformers.model.gemma3 import causal_forward
|
785
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
|
785
786
|
|
786
787
|
_patch_rms_norm_module_for_gemma3 = partial(
|
787
788
|
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
|
@@ -1451,11 +1452,12 @@ def apply_liger_kernel_to_olmo2(
|
|
1451
1452
|
from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
|
1452
1453
|
|
1453
1454
|
from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
|
1455
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
|
1454
1456
|
|
1455
1457
|
if rope:
|
1456
1458
|
modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
1457
1459
|
if rms_norm:
|
1458
|
-
modeling_olmo2.Olmo2RMSNorm =
|
1460
|
+
modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
|
1459
1461
|
if swiglu:
|
1460
1462
|
modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
|
1461
1463
|
if cross_entropy:
|
@@ -1514,11 +1516,12 @@ def apply_liger_kernel_to_glm4(
|
|
1514
1516
|
from transformers.models.glm4.modeling_glm4 import Glm4Model
|
1515
1517
|
|
1516
1518
|
from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
|
1519
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
|
1517
1520
|
|
1518
1521
|
if rope:
|
1519
1522
|
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
|
1520
1523
|
if rms_norm:
|
1521
|
-
modeling_glm4.Glm4RMSNorm =
|
1524
|
+
modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
|
1522
1525
|
if swiglu:
|
1523
1526
|
modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
|
1524
1527
|
if cross_entropy:
|
@@ -44,3 +44,38 @@ class LigerRMSNorm(nn.Module):
|
|
44
44
|
return (
|
45
45
|
f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
|
46
46
|
)
|
47
|
+
|
48
|
+
|
49
|
+
class LigerRMSNormForGemma(LigerRMSNorm):
    """LigerRMSNorm with Gemma-specific default arguments.

    Defaults: ``eps=1e-6``, ``offset=1.0``, ``casting_mode="gemma"``,
    ``init_fn="zeros"``, ``in_place=True``, ``row_mode=None``.
    Every argument is forwarded unchanged, in order, to ``LigerRMSNorm``.
    """

    def __init__(
        self,
        hidden_size,
        eps=1e-6,
        offset=1.0,
        casting_mode="gemma",
        init_fn="zeros",
        in_place=True,
        row_mode=None,
    ):
        # The order of this tuple must mirror LigerRMSNorm.__init__'s positional parameters.
        base_args = (hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
        super().__init__(*base_args)
|
54
|
+
|
55
|
+
|
56
|
+
class LigerRMSNormForGemma2(LigerRMSNorm):
    """LigerRMSNorm with Gemma2-specific default arguments.

    Identical to the Gemma defaults except ``in_place`` defaults to False:
    ``eps=1e-6``, ``offset=1.0``, ``casting_mode="gemma"``,
    ``init_fn="zeros"``, ``in_place=False``, ``row_mode=None``.
    Every argument is forwarded unchanged, in order, to ``LigerRMSNorm``.
    """

    def __init__(
        self,
        hidden_size,
        eps=1e-6,
        offset=1.0,
        casting_mode="gemma",
        init_fn="zeros",
        in_place=False,
        row_mode=None,
    ):
        # The order of this tuple must mirror LigerRMSNorm.__init__'s positional parameters.
        base_args = (hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
        super().__init__(*base_args)
|
61
|
+
|
62
|
+
|
63
|
+
class LigerRMSNormForGemma3(LigerRMSNorm):
    """LigerRMSNorm with Gemma3-specific default arguments.

    Gemma3RMSNorm names its size argument ``dim`` rather than ``hidden_size``.
    It is also used for the ``q_norm`` and ``k_norm`` layers. This subclass
    keeps the ``dim`` name so it can be swapped in by keyword-compatible callers.

    Args:
        dim: Passed to LigerRMSNorm as its size argument (``hidden_size`` in
            the sibling subclasses).
        eps: Numerical-stability epsilon (default ``1e-6``).
        offset: Forwarded to LigerRMSNorm (default ``1.0``).
        casting_mode: Forwarded to LigerRMSNorm (default ``"gemma"``).
        init_fn: Forwarded to LigerRMSNorm (default ``"zeros"``).
        in_place: Forwarded to LigerRMSNorm (default ``False``).
        row_mode: Forwarded to LigerRMSNorm (default ``None``). This parameter
            was previously missing, unlike every other model-specific subclass.
    """

    def __init__(
        self,
        dim,
        eps=1e-6,
        offset=1.0,
        casting_mode="gemma",
        init_fn="zeros",
        in_place=False,
        row_mode=None,
    ):
        # Positional order mirrors LigerRMSNorm.__init__, as in the sibling subclasses.
        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place, row_mode)
|
68
|
+
|
69
|
+
|
70
|
+
class LigerRMSNormForOlmo2(LigerRMSNorm):
    """LigerRMSNorm with OLMo2-specific default arguments.

    Defaults: ``eps=1e-6``, ``offset=0.0``, ``casting_mode="llama"``,
    ``init_fn="ones"``, ``in_place=False``, ``row_mode=None``.
    Every argument is forwarded unchanged, in order, to ``LigerRMSNorm``.
    """

    def __init__(
        self,
        hidden_size,
        eps=1e-6,
        offset=0.0,
        casting_mode="llama",
        init_fn="ones",
        in_place=False,
        row_mode=None,
    ):
        # The order of this tuple must mirror LigerRMSNorm.__init__'s positional parameters.
        base_args = (hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
        super().__init__(*base_args)
|
75
|
+
|
76
|
+
|
77
|
+
class LigerRMSNormForGlm4(LigerRMSNorm):
    """LigerRMSNorm with GLM-4-specific default arguments.

    Defaults: ``eps=1e-6``, ``offset=0.0``, ``casting_mode="llama"``,
    ``init_fn="ones"``, ``in_place=False``, ``row_mode=None``.
    Every argument is forwarded unchanged, in order, to ``LigerRMSNorm``.
    """

    def __init__(
        self,
        hidden_size,
        eps=1e-6,
        offset=0.0,
        casting_mode="llama",
        init_fn="ones",
        in_place=False,
        row_mode=None,
    ):
        # The order of this tuple must mirror LigerRMSNorm.__init__'s positional parameters.
        base_args = (hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
        super().__init__(*base_args)
|
@@ -46,16 +46,15 @@ liger_kernel/transformers/functional.py,sha256=QmnAFpRgIbp9Rzlfp8QibwiEbf5BUcANx
|
|
46
46
|
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
|
47
47
|
liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
|
48
48
|
liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
|
49
|
-
liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
|
50
49
|
liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
|
51
50
|
liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
|
52
51
|
liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
|
53
52
|
liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
|
54
53
|
liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
|
55
|
-
liger_kernel/transformers/monkey_patch.py,sha256=
|
54
|
+
liger_kernel/transformers/monkey_patch.py,sha256=drSPROAsiphnLPl2lFPBUvG_u4oIufDKEklqnDyy0vY,74584
|
56
55
|
liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
|
57
56
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
|
58
|
-
liger_kernel/transformers/rms_norm.py,sha256=
|
57
|
+
liger_kernel/transformers/rms_norm.py,sha256=eErIr1n-13oVrc1VJY07lqazYelw_vlu9Az__RmXPSE,2717
|
59
58
|
liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
|
60
59
|
liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
|
61
60
|
liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
|
@@ -86,9 +85,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
86
85
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
87
86
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
88
87
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
89
|
-
liger_kernel_nightly-0.5.10.
|
90
|
-
liger_kernel_nightly-0.5.10.
|
91
|
-
liger_kernel_nightly-0.5.10.
|
92
|
-
liger_kernel_nightly-0.5.10.
|
93
|
-
liger_kernel_nightly-0.5.10.
|
94
|
-
liger_kernel_nightly-0.5.10.
|
88
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
89
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/METADATA,sha256=TewpbE_T3k_gTii2lgoBpvrzywyhd7f-xZl2kfbEYTc,24309
|
90
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
91
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
92
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
93
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
from .rms_norm import LigerRMSNorm
|
2
|
-
|
3
|
-
|
4
|
-
class LigerRMSNormForGemma3(LigerRMSNorm):
    """LigerRMSNorm with Gemma3-specific default arguments.

    Gemma3RMSNorm names its size argument ``dim`` rather than ``hidden_size``.
    It is also used for the ``q_norm`` and ``k_norm`` layers. This subclass
    keeps the ``dim`` name so it can be swapped in by keyword-compatible callers.

    Args:
        dim: Passed to LigerRMSNorm as its size argument.
        eps: Numerical-stability epsilon (default ``1e-6``).
        offset: Forwarded to LigerRMSNorm (default ``1.0``).
        casting_mode: Forwarded to LigerRMSNorm (default ``"gemma"``).
        init_fn: Forwarded to LigerRMSNorm (default ``"zeros"``).
        in_place: Forwarded to LigerRMSNorm (default ``False``).
        row_mode: Forwarded to LigerRMSNorm (default ``None``), so this class
            accepts the same arguments as the other model-specific subclasses.
    """

    def __init__(
        self,
        dim,
        eps=1e-6,
        offset=1.0,
        casting_mode="gemma",
        init_fn="zeros",
        in_place=False,
        row_mode=None,
    ):
        # Positional order mirrors LigerRMSNorm.__init__.
        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place, row_mode)
|
File without changes
|
File without changes
|
File without changes
|