liger-kernel-nightly 0.5.10.dev20250601024230__py3-none-any.whl → 0.5.10.dev20250602134913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/transformers/monkey_patch.py +11 -7
- liger_kernel/transformers/rms_norm.py +35 -0
- {liger_kernel_nightly-0.5.10.dev20250601024230.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250601024230.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/RECORD +8 -9
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel_nightly-0.5.10.dev20250601024230.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250601024230.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250601024230.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250601024230.dist-info → liger_kernel_nightly-0.5.10.dev20250602134913.dist-info}/top_level.txt +0 -0
@@ -428,13 +428,14 @@ def apply_liger_kernel_to_mllama(
|
|
428
428
|
if isinstance(model, MllamaForConditionalGeneration):
|
429
429
|
language_model: MllamaForCausalLM = model.language_model
|
430
430
|
vision_model: MllamaVisionModel = model.vision_model
|
431
|
-
text_model: MllamaTextModel = language_model
|
431
|
+
text_model: MllamaTextModel = language_model
|
432
432
|
elif isinstance(model, MllamaForCausalLM):
|
433
433
|
text_model = model.model
|
434
434
|
vision_model = None
|
435
435
|
elif isinstance(model, MllamaTextModel):
|
436
436
|
text_model = model
|
437
437
|
vision_model = None
|
438
|
+
|
438
439
|
else:
|
439
440
|
raise ValueError(f"Unsupported Mllama model type: {type(model)}")
|
440
441
|
|
@@ -626,8 +627,8 @@ def apply_liger_kernel_to_gemma(
|
|
626
627
|
from transformers.models.gemma import modeling_gemma
|
627
628
|
from transformers.models.gemma.modeling_gemma import GemmaModel
|
628
629
|
|
629
|
-
|
630
|
-
|
630
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
|
631
|
+
|
631
632
|
_patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
|
632
633
|
|
633
634
|
if rope:
|
@@ -700,7 +701,8 @@ def apply_liger_kernel_to_gemma2(
|
|
700
701
|
from transformers.models.gemma2 import modeling_gemma2
|
701
702
|
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
|
702
703
|
|
703
|
-
|
704
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
|
705
|
+
|
704
706
|
_patch_rms_norm_module_for_gemma2 = partial(
|
705
707
|
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
|
706
708
|
)
|
@@ -779,8 +781,8 @@ def apply_liger_kernel_to_gemma3_text(
|
|
779
781
|
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
|
780
782
|
from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
|
781
783
|
|
782
|
-
from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
|
783
784
|
from liger_kernel.transformers.model.gemma3 import causal_forward
|
785
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
|
784
786
|
|
785
787
|
_patch_rms_norm_module_for_gemma3 = partial(
|
786
788
|
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
|
@@ -1450,11 +1452,12 @@ def apply_liger_kernel_to_olmo2(
|
|
1450
1452
|
from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
|
1451
1453
|
|
1452
1454
|
from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
|
1455
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
|
1453
1456
|
|
1454
1457
|
if rope:
|
1455
1458
|
modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
1456
1459
|
if rms_norm:
|
1457
|
-
modeling_olmo2.Olmo2RMSNorm =
|
1460
|
+
modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
|
1458
1461
|
if swiglu:
|
1459
1462
|
modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
|
1460
1463
|
if cross_entropy:
|
@@ -1513,11 +1516,12 @@ def apply_liger_kernel_to_glm4(
|
|
1513
1516
|
from transformers.models.glm4.modeling_glm4 import Glm4Model
|
1514
1517
|
|
1515
1518
|
from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
|
1519
|
+
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
|
1516
1520
|
|
1517
1521
|
if rope:
|
1518
1522
|
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
|
1519
1523
|
if rms_norm:
|
1520
|
-
modeling_glm4.Glm4RMSNorm =
|
1524
|
+
modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
|
1521
1525
|
if swiglu:
|
1522
1526
|
modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
|
1523
1527
|
if cross_entropy:
|
@@ -44,3 +44,38 @@ class LigerRMSNorm(nn.Module):
|
|
44
44
|
return (
|
45
45
|
f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
|
46
46
|
)
|
47
|
+
|
48
|
+
|
49
|
+
class LigerRMSNormForGemma(LigerRMSNorm):
|
50
|
+
def __init__(
|
51
|
+
self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None
|
52
|
+
):
|
53
|
+
super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
|
54
|
+
|
55
|
+
|
56
|
+
class LigerRMSNormForGemma2(LigerRMSNorm):
|
57
|
+
def __init__(
|
58
|
+
self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
|
59
|
+
):
|
60
|
+
super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
|
61
|
+
|
62
|
+
|
63
|
+
class LigerRMSNormForGemma3(LigerRMSNorm):
|
64
|
+
"""Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
|
65
|
+
|
66
|
+
def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
|
67
|
+
super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
|
68
|
+
|
69
|
+
|
70
|
+
class LigerRMSNormForOlmo2(LigerRMSNorm):
|
71
|
+
def __init__(
|
72
|
+
self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
|
73
|
+
):
|
74
|
+
super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
|
75
|
+
|
76
|
+
|
77
|
+
class LigerRMSNormForGlm4(LigerRMSNorm):
|
78
|
+
def __init__(
|
79
|
+
self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
|
80
|
+
):
|
81
|
+
super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
|
@@ -46,16 +46,15 @@ liger_kernel/transformers/functional.py,sha256=QmnAFpRgIbp9Rzlfp8QibwiEbf5BUcANx
|
|
46
46
|
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
|
47
47
|
liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
|
48
48
|
liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
|
49
|
-
liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
|
50
49
|
liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
|
51
50
|
liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
|
52
51
|
liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
|
53
52
|
liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
|
54
53
|
liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
|
55
|
-
liger_kernel/transformers/monkey_patch.py,sha256=
|
54
|
+
liger_kernel/transformers/monkey_patch.py,sha256=drSPROAsiphnLPl2lFPBUvG_u4oIufDKEklqnDyy0vY,74584
|
56
55
|
liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
|
57
56
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
|
58
|
-
liger_kernel/transformers/rms_norm.py,sha256=
|
57
|
+
liger_kernel/transformers/rms_norm.py,sha256=eErIr1n-13oVrc1VJY07lqazYelw_vlu9Az__RmXPSE,2717
|
59
58
|
liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
|
60
59
|
liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
|
61
60
|
liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
|
@@ -86,9 +85,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
86
85
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
87
86
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
88
87
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
89
|
-
liger_kernel_nightly-0.5.10.
|
90
|
-
liger_kernel_nightly-0.5.10.
|
91
|
-
liger_kernel_nightly-0.5.10.
|
92
|
-
liger_kernel_nightly-0.5.10.
|
93
|
-
liger_kernel_nightly-0.5.10.
|
94
|
-
liger_kernel_nightly-0.5.10.
|
88
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
89
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/METADATA,sha256=TewpbE_T3k_gTii2lgoBpvrzywyhd7f-xZl2kfbEYTc,24309
|
90
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
91
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
92
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
93
|
+
liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
from .rms_norm import LigerRMSNorm
|
2
|
-
|
3
|
-
|
4
|
-
class LigerRMSNormForGemma3(LigerRMSNorm):
|
5
|
-
"""Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
|
6
|
-
|
7
|
-
def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
|
8
|
-
super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
|
File without changes
|
File without changes
|
File without changes
|