liger-kernel-nightly 0.5.10.dev20250601024230__py3-none-any.whl → 0.5.10.dev20250602134913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/liger_kernel/transformers/monkey_patch.py
+++ b/liger_kernel/transformers/monkey_patch.py
@@ -428,13 +428,14 @@ def apply_liger_kernel_to_mllama(
     if isinstance(model, MllamaForConditionalGeneration):
         language_model: MllamaForCausalLM = model.language_model
         vision_model: MllamaVisionModel = model.vision_model
-        text_model: MllamaTextModel = language_model.model
+        text_model: MllamaTextModel = language_model
     elif isinstance(model, MllamaForCausalLM):
         text_model = model.model
         vision_model = None
     elif isinstance(model, MllamaTextModel):
         text_model = model
         vision_model = None
+
     else:
         raise ValueError(f"Unsupported Mllama model type: {type(model)}")
 
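Note: the one-line change above appears to track a transformers-side multimodal refactor in which MllamaForConditionalGeneration.language_model returns the MllamaTextModel backbone directly instead of a MllamaForCausalLM wrapper. A minimal, version-tolerant sketch of the same resolution (the helper name is hypothetical, not part of either package):

from transformers import MllamaForConditionalGeneration
from transformers.models.mllama.modeling_mllama import MllamaTextModel

def resolve_text_model(model: MllamaForConditionalGeneration) -> MllamaTextModel:
    language_model = model.language_model
    # Newer transformers releases hand back the text backbone directly;
    # older ones return a CausalLM wrapper whose .model attribute holds it.
    if isinstance(language_model, MllamaTextModel):
        return language_model
    return language_model.model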
@@ -626,8 +627,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
-    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
     _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
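The deleted local partial and the new import are configured identically: both yield a Gemma-flavored LigerRMSNorm with offset=1.0, init_fn="zeros", and casting_mode="gemma" (see the rms_norm.py hunk further down). A rough equivalence sketch, with an arbitrary hidden size:

from functools import partial

from liger_kernel.transformers.rms_norm import LigerRMSNorm, LigerRMSNormForGemma

# Old style: a configured callable, but not a real class.
OldGemmaNorm = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")

old_norm = OldGemmaNorm(2048)
new_norm = LigerRMSNormForGemma(2048)

# Same configuration either way, but only the subclass is a proper
# nn.Module type, so isinstance checks and type dispatch keep working.
assert isinstance(new_norm, LigerRMSNorm)
assert not isinstance(OldGemmaNorm, type)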
@@ -700,7 +701,8 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -779,8 +781,8 @@ def apply_liger_kernel_to_gemma3_text(
     from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
     from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
 
-    from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
 
     _patch_rms_norm_module_for_gemma3 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
@@ -1450,11 +1452,12 @@ def apply_liger_kernel_to_olmo2(
     from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
 
     from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
 
     if rope:
         modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
+        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
     if swiglu:
         modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
     if cross_entropy:
@@ -1513,11 +1516,12 @@ def apply_liger_kernel_to_glm4(
     from transformers.models.glm4.modeling_glm4 import Glm4Model
 
     from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
 
     if rope:
         raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
     if rms_norm:
-        modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
+        modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
     if swiglu:
         modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
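Olmo2 and Glm4 get the same consolidation: the ad-hoc partial(LigerRMSNorm, in_place=False) is replaced by a named subclass whose llama-style defaults are pinned in rms_norm.py (next hunk). A hedged usage sketch; it assumes the patch function is re-exported from liger_kernel.transformers, and the checkpoint name is only illustrative:

from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_olmo2

# Patch the transformers module-level classes first, so that layers
# constructed afterwards pick up the Liger implementations.
apply_liger_kernel_to_olmo2(rms_norm=True, swiglu=True)
model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")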
--- a/liger_kernel/transformers/rms_norm.py
+++ b/liger_kernel/transformers/rms_norm.py
@@ -44,3 +44,38 @@ class LigerRMSNorm(nn.Module):
         return (
             f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
         )
+
+
+class LigerRMSNormForGemma(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma3(LigerRMSNorm):
+    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
+
+
+class LigerRMSNormForOlmo2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGlm4(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
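These subclasses pin each model's RMSNorm defaults in one importable place. Note that LigerRMSNormForGemma3 keeps Gemma3's dim parameter name because transformers' Gemma3RMSNorm uses it for q_norm/k_norm, whose width is the attention head dimension rather than the hidden size. A small instantiation sketch (sizes are illustrative):

from liger_kernel.transformers.rms_norm import (
    LigerRMSNormForGemma3,
    LigerRMSNormForOlmo2,
)

# Gemma3 q/k-norm: offset=1.0, zeros init, gemma casting mode.
qk_norm = LigerRMSNormForGemma3(dim=256)

# Olmo2: conventional llama-style norm (offset=0.0, ones init),
# out of place (in_place=False), matching the deleted partial.
olmo_norm = LigerRMSNormForOlmo2(hidden_size=4096)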
--- a/liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/METADATA
+++ b/liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.10.dev20250601024230
+Version: 0.5.10.dev20250602134913
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
--- a/liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/RECORD
+++ b/liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/RECORD
@@ -46,16 +46,15 @@ liger_kernel/transformers/functional.py,sha256=QmnAFpRgIbp9Rzlfp8QibwiEbf5BUcANx
 liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
-liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
 liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
 liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
 liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
-liger_kernel/transformers/monkey_patch.py,sha256=A91QWjMG7d7302lx-Djjxd_VgwBhYwxAYa1davBFCjU,74668
+liger_kernel/transformers/monkey_patch.py,sha256=drSPROAsiphnLPl2lFPBUvG_u4oIufDKEklqnDyy0vY,74584
 liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
-liger_kernel/transformers/rms_norm.py,sha256=QimExM27kYoAnaZqxb_8mBaUcd72-X01DviJ1dQd55I,1278
+liger_kernel/transformers/rms_norm.py,sha256=eErIr1n-13oVrc1VJY07lqazYelw_vlu9Az__RmXPSE,2717
 liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
 liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
 liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
@@ -86,9 +85,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/METADATA,sha256=p4YDg6nRS2Zh3pCFi_dj1Yl7DtEi5U3bciMTtrcY-1U,24309
-liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.10.dev20250601024230.dist-info/RECORD,,
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/METADATA,sha256=TewpbE_T3k_gTii2lgoBpvrzywyhd7f-xZl2kfbEYTc,24309
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.10.dev20250602134913.dist-info/RECORD,,
--- a/liger_kernel/transformers/gema3_rms.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .rms_norm import LigerRMSNorm
-
-
-class LigerRMSNormForGemma3(LigerRMSNorm):
-    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
-
-    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
-        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
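This removed module (note the "gema3_rms" misspelling) was the previous home of LigerRMSNormForGemma3, now consolidated into rms_norm.py above. Imports of the old path break against this nightly; a compatibility shim for code that must span both versions could look like:

# Prefer the consolidated location; fall back to the old (misspelled)
# module that this release deletes.
try:
    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
except ImportError:
    from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3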