liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +134 -60
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +32 -12
  10. liger_kernel/ops/kl_div.py +15 -8
  11. liger_kernel/ops/layer_norm.py +14 -1
  12. liger_kernel/ops/rms_norm.py +12 -1
  13. liger_kernel/transformers/__init__.py +133 -15
  14. liger_kernel/transformers/dyt.py +20 -0
  15. liger_kernel/transformers/functional.py +5 -0
  16. liger_kernel/transformers/gema3_rms.py +8 -0
  17. liger_kernel/transformers/model/gemma.py +17 -20
  18. liger_kernel/transformers/model/gemma2.py +17 -21
  19. liger_kernel/transformers/model/gemma3.py +335 -0
  20. liger_kernel/transformers/model/llama.py +17 -19
  21. liger_kernel/transformers/model/llava.py +369 -0
  22. liger_kernel/transformers/model/loss_utils.py +64 -0
  23. liger_kernel/transformers/model/mistral.py +28 -25
  24. liger_kernel/transformers/model/mixtral.py +20 -26
  25. liger_kernel/transformers/model/mllama.py +17 -19
  26. liger_kernel/transformers/model/olmo2.py +17 -20
  27. liger_kernel/transformers/model/paligemma.py +397 -0
  28. liger_kernel/transformers/model/phi3.py +17 -19
  29. liger_kernel/transformers/model/qwen2.py +17 -19
  30. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  31. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  32. liger_kernel/transformers/monkey_patch.py +392 -13
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
  36. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  37. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
  38. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
  39. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0
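Before diving into the hunks, a quick way to confirm an upgrade actually picked up the new entry points added in this release (a minimal sketch; the `hasattr` checks assume the new `apply_liger_kernel_to_*` functions are re-exported from `liger_kernel.transformers`, as the existing ones are):

```python
# Sanity check after installing the 0.5.7 wheel; distribution name taken from the METADATA diff below.
import importlib.metadata

import liger_kernel.transformers as lkt

print(importlib.metadata.version("liger_kernel"))  # expected: 0.5.7
# New entry points added in this release, per the monkey_patch.py diff below
# (assumed to be re-exported at the package level like the existing apply_* functions).
print(hasattr(lkt, "apply_liger_kernel_to_llava"))
print(hasattr(lkt, "apply_liger_kernel_to_gemma3"))
print(hasattr(lkt, "apply_liger_kernel_to_paligemma"))
```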
@@ -19,6 +19,8 @@ from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_for
  from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
+ from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
+ from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
  from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
@@ -52,13 +54,26 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
  module.in_place = in_place
  _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
  _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+ module.__class__.__name__ = LigerRMSNorm.__name__
 
 
  def _patch_layer_norm_module(module, eps=1e-6):
  module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
- module.hidden_size = module.normalized_shape
+ module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+
  _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
  _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+ module.__class__.__name__ = LigerLayerNorm.__name__
+
+
+ def _patch_swiglu_module(module, liger_module):
+ _bind_method_to_module(module, "forward", liger_module.forward)
+ module.__class__.__name__ = liger_module.__name__
+
+
+ def _patch_geglu_module(module):
+ _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+ module.__class__.__name__ = LigerGEGLUMLP.__name__
 
 
  def apply_liger_kernel_to_granite(
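The new `_patch_swiglu_module` and `_patch_geglu_module` helpers rebind `forward` on an already-instantiated MLP module and rename its class so the patched module is visible in `print(model)`. A minimal sketch of that instance-level patching idea (the `_bind_method_to_module` helper is defined elsewhere in `monkey_patch.py`; the version below is an assumption based on how it is used here, and the module names are illustrative):

```python
from types import MethodType

import torch
import torch.nn as nn


class TanhGELU(nn.Module):
    """Stand-in replacement whose forward we graft onto an existing instance."""

    def forward(self, x):
        return nn.functional.gelu(x, approximate="tanh")


def bind_method_to_module(module, method_name, new_method):
    # Bind new_method to this single instance only; other instances keep their original forward.
    setattr(module, method_name, MethodType(new_method, module))


mlp = nn.GELU()  # stands in for decoder_layer.mlp
bind_method_to_module(mlp, "forward", TanhGELU.forward)
mlp.__class__.__name__ = TanhGELU.__name__  # cosmetic: repr/print now shows the patched name

print(mlp(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```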
@@ -134,7 +149,7 @@ def apply_liger_kernel_to_granite(
 
  for decoder_layer in base_model.layers:
  if swiglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -206,12 +221,91 @@ def apply_liger_kernel_to_llama(
 
  for decoder_layer in base_model.layers:
  if swiglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+ def apply_liger_kernel_to_llava(
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ model: PreTrainedModel = None,
+ **kwargs,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Llava models.
+ Due to the characteristics of LlaVa, the model must be passed to apply Liger-Kernel's patch to other models connected to LLaVa.
+ However, if an LM not supported by Liger-Kernel is connected to LLaVa, unexpected side effects may occur.
+ NOTE: Llava is not available in transformers<4.36.0
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.llava import modeling_llava
+
+ if cross_entropy:
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+ modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+ if fused_linear_cross_entropy:
+ if transformer_version >= version.parse("4.49.0"):
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+ else: # if version < 4.49.0
+ logger.warning(
+ "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+ )
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+
+ if model is not None:
+ text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+ vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
+ if text_liger_fn:
+ accept_params = inspect.signature(text_liger_fn).parameters
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+ if remain_params:
+ logger.warning(
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+ )
+ text_kwargs["model"] = model.language_model
+ text_liger_fn(**text_kwargs)
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+ if vision_liger_fn:
+ accept_params = inspect.signature(vision_liger_fn).parameters
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+ vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+ if remain_params:
+ logger.warning(
+ f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
+ f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+ )
+ vision_kwargs["model"] = model.vision_tower
+ vision_liger_fn(**vision_kwargs)
+ elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
  def apply_liger_kernel_to_mllama(
  rope: bool = True,
  cross_entropy: bool = False,
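For context, the new Llava entry point is meant to be called with the loaded model instance, so the attached text backbone and vision tower can be dispatched through `MODEL_TYPE_TO_APPLY_LIGER_FN`. A minimal usage sketch (the checkpoint name is illustrative, and it assumes `apply_liger_kernel_to_llava` is re-exported from `liger_kernel.transformers`):

```python
import torch
from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)

# Patch the Llava wrapper's forward (fused linear cross entropy) and, because the
# instance is passed, also patch the attached language model and vision tower in place.
apply_liger_kernel_to_llava(fused_linear_cross_entropy=True, model=model)
```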
@@ -296,7 +390,7 @@ def apply_liger_kernel_to_mllama(
  _patch_rms_norm_module(text_model.norm)
  for decoder_layer in text_model.layers:
  if swiglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -370,7 +464,7 @@ def apply_liger_kernel_to_mistral(
 
  for decoder_layer in base_model.layers:
  if swiglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -442,7 +536,7 @@ def apply_liger_kernel_to_mixtral(
  for decoder_layer in base_model.layers:
  if swiglu:
  for expert in decoder_layer.block_sparse_moe.experts:
- _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
+ _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -516,7 +610,7 @@ def apply_liger_kernel_to_gemma(
 
  for decoder_layer in base_model.layers:
  if geglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+ _patch_geglu_module(decoder_layer.mlp)
  if rms_norm:
  _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
  _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -592,7 +686,7 @@ def apply_liger_kernel_to_gemma2(
 
  for decoder_layer in base_model.layers:
  if geglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+ _patch_geglu_module(decoder_layer.mlp)
  if rms_norm:
  _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
  _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
@@ -600,6 +694,287 @@ def apply_liger_kernel_to_gemma2(
  _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
 
 
+ def apply_liger_kernel_to_gemma3_text(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ geglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Gemma3
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.gemma3 import modeling_gemma3
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+
+ from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
+ from liger_kernel.transformers.model.gemma3 import causal_forward
+
+ _patch_rms_norm_module_for_gemma3 = partial(
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+ )
+
+ if rope:
+ modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3
+
+ if geglu:
+ modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
+
+ # Handle loss function
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ if isinstance(model, Gemma3ForCausalLM):
+ # get the base model from the model instance
+ base_model = model.model
+
+ if rms_norm:
+ _patch_rms_norm_module_for_gemma3(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ decoder_layer: Gemma3DecoderLayer
+ if geglu:
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+ if rms_norm:
+ _patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
+ _patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
+ _patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
+ _patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
+ _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
+ _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)
+
+ else:
+ raise TypeError("The model must be Gemma3ForCausalLM.")
+
+
+ def apply_liger_kernel_to_gemma3(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ layer_norm: bool = True,
+ rms_norm: bool = True,
+ geglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Gemma3
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.gemma3 import modeling_gemma3
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
+ from transformers.models.siglip import modeling_siglip
+ from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+
+ from liger_kernel.transformers.model.gemma3 import multimodal_forward
+
+ _patch_rms_norm_module_for_gemma3 = partial(
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+ )
+
+ if layer_norm:
+ modeling_siglip.nn.LayerNorm = LigerLayerNorm
+
+ apply_liger_kernel_to_gemma3_text(
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+ )
+
+ if cross_entropy:
+ modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+
+ if fused_linear_cross_entropy:
+ modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ if isinstance(model, Gemma3ForConditionalGeneration):
+ if isinstance(model.vision_tower, SiglipVisionModel):
+ vision_tower = model.vision_tower
+
+ _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
+
+ for layer in vision_tower.vision_model.encoder.layers:
+ layer: SiglipEncoderLayer
+ if layer_norm:
+ _patch_layer_norm_module(layer.layer_norm1)
+ _patch_layer_norm_module(layer.layer_norm2)
+ else:
+ raise TypeError("The vision tower must be SiglipVisionModel")
+
+ if rms_norm:
+ _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+
+ apply_liger_kernel_to_gemma3_text(
+ rope=rope,
+ cross_entropy=False,
+ fused_linear_cross_entropy=False,
+ rms_norm=rms_norm,
+ geglu=geglu,
+ model=model.language_model,
+ )
+
+ else:
+ raise TypeError("The model must be Gemma3ForConditionalGeneration.")
+
+
+ def apply_liger_kernel_to_paligemma(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ layer_norm: bool = True,
+ rms_norm: bool = True,
+ geglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
+
+ from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+ from transformers.models.paligemma import modeling_paligemma
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+ from transformers.models.siglip import modeling_siglip
+ from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+
+ from liger_kernel.transformers.model.paligemma import lce_forward
+ from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
+
+ # The vision_tower is a SiglipVisionModel
+ if layer_norm:
+ modeling_siglip.nn.LayerNorm = LigerLayerNorm
+
+ # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
+ # The multi_modal_projector is Linear, nothing to do
+
+ # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
+ apply_liger_kernel_to_gemma(
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+ )
+ apply_liger_kernel_to_gemma2(
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+ )
+ # Handle loss function
+ if cross_entropy:
+ modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+ if fused_linear_cross_entropy:
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+ else: # if version < 4.46.1
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ if not isinstance(model, PaliGemmaForConditionalGeneration):
+ raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
+
+ vision_tower: SiglipVisionModel = model.vision_tower
+
+ _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
+
+ for layer in vision_tower.vision_model.encoder.layers:
+ layer: SiglipEncoderLayer
+ if layer_norm:
+ _patch_layer_norm_module(layer.layer_norm1)
+ _patch_layer_norm_module(layer.layer_norm2)
+
+ language_model = model.language_model
+
+ if isinstance(language_model, GemmaForCausalLM):
+ apply_liger_kernel_to_gemma(
+ rope=rope,
+ cross_entropy=False,
+ fused_linear_cross_entropy=False,
+ rms_norm=rms_norm,
+ geglu=geglu,
+ model=language_model,
+ )
+
+ elif isinstance(language_model, Gemma2ForCausalLM):
+ apply_liger_kernel_to_gemma2(
+ rope=rope,
+ cross_entropy=False,
+ fused_linear_cross_entropy=False,
+ rms_norm=rms_norm,
+ geglu=geglu,
+ model=language_model,
+ )
+ else:
+ raise TypeError(
+ "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
+ )
+
+
  def apply_liger_kernel_to_qwen2(
  rope: bool = True,
  cross_entropy: bool = False,
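Both new multimodal entry points follow the same pattern as the existing ones: patch the transformers module classes up front, and additionally patch already-instantiated submodules (SigLIP LayerNorms, Gemma RMSNorms) when a loaded model is passed in. A minimal usage sketch (checkpoint names are illustrative; it assumes these functions are re-exported from `liger_kernel.transformers` and that the installed transformers release ships Gemma3):

```python
from transformers import Gemma3ForConditionalGeneration, PaliGemmaForConditionalGeneration

from liger_kernel.transformers import (
    apply_liger_kernel_to_gemma3,
    apply_liger_kernel_to_paligemma,
)

gemma3 = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
# Patches RoPE/RMSNorm/GeGLU and the fused-linear-cross-entropy forward on the text stack,
# plus LayerNorm inside the SigLIP vision tower of this loaded instance.
apply_liger_kernel_to_gemma3(model=gemma3)

paligemma = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-pt-224")
# Dispatches to the Gemma or Gemma2 patch for the language model and patches the SigLIP tower.
apply_liger_kernel_to_paligemma(model=paligemma)
```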
@@ -666,7 +1041,7 @@ def apply_liger_kernel_to_qwen2(
 
  for decoder_layer in base_model.layers:
  if swiglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -739,7 +1114,7 @@ def apply_liger_kernel_to_qwen2_vl(
  _patch_rms_norm_module(base_model.norm)
  for decoder_layer in base_model.layers:
  if swiglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -806,7 +1181,7 @@ def apply_liger_kernel_to_qwen2_5_vl(
  _patch_rms_norm_module(base_model.norm)
  for decoder_layer in base_model.layers:
  if swiglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -875,7 +1250,7 @@ def apply_liger_kernel_to_phi3(
 
  for decoder_layer in base_model.layers:
  if swiglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -938,7 +1313,7 @@ def apply_liger_kernel_to_olmo2(
 
  for decoder_layer in base_model.layers:
  if swiglu:
- _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
@@ -948,7 +1323,10 @@ def apply_liger_kernel_to_olmo2(
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "gemma": apply_liger_kernel_to_gemma,
  "gemma2": apply_liger_kernel_to_gemma2,
+ "gemma3_text": apply_liger_kernel_to_gemma3_text,
+ "gemma3": apply_liger_kernel_to_gemma3,
  "llama": apply_liger_kernel_to_llama,
+ "llava": apply_liger_kernel_to_llava,
  "granite": apply_liger_kernel_to_granite,
  "mllama": apply_liger_kernel_to_mllama,
  "mllama_text_model": apply_liger_kernel_to_mllama,
@@ -959,6 +1337,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
  "phi3": apply_liger_kernel_to_phi3,
+ "paligemma": apply_liger_kernel_to_paligemma,
  }
 
 
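The additions to `MODEL_TYPE_TO_APPLY_LIGER_FN` are what let the Llava patch above resolve its text and vision submodels, and they also make the new architectures reachable by `config.model_type`. A rough sketch of dispatching off that mapping (the helper name below is hypothetical, not the packaged implementation):

```python
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN


def apply_liger_kernel_by_model_type(model, **kwargs):
    """Hypothetical dispatcher built on MODEL_TYPE_TO_APPLY_LIGER_FN; shown for illustration only."""
    model_type = getattr(model.config, "model_type", None)
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        raise ValueError(f"No Liger kernel patch registered for model_type={model_type!r}")
    # The multimodal entries (gemma3, llava, paligemma) expect the loaded instance so that
    # the attached language model and vision tower can be patched as well.
    apply_fn(model=model, **kwargs)
```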
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: liger_kernel
- Version: 0.5.5
+ Version: 0.5.7
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -45,6 +45,7 @@ Requires-Dist: datasets>=2.19.2; extra == "dev"
  Requires-Dist: seaborn; extra == "dev"
  Requires-Dist: mkdocs; extra == "dev"
  Requires-Dist: mkdocs-material; extra == "dev"
+ Dynamic: license-file
  Dynamic: provides-extra
  Dynamic: requires-dist
 
@@ -115,6 +116,7 @@ Dynamic: requires-dist
  <details>
  <summary>Latest News 🔥</summary>
 
+ - [2025/03/06] We release a joint blog post on TorchTune × Liger - [Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel](https://pytorch.org/blog/peak-performance-minimized-memory/)
  - [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
  - [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
  - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
@@ -177,7 +179,7 @@ y = orpo_loss(lm_head.weight, x, target)
  - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
  - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
  - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
- - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
+ - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)
 
  ## Installation
 
@@ -312,6 +314,9 @@ loss.backward()
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -386,8 +391,8 @@ loss.backward()
  ## Contact
 
  - For issues, create a Github ticket in this repository
- - For open discussion, join [our discord channel](https://discord.gg/gpumode)
- - For formal collaboration, send an email to yannchen@linkedin.com
+ - For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
+ - For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
 
  ## Cite this work
 
@@ -406,7 +411,7 @@ Biblatex entry:
  ```
 
  ## Star History
- [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
+ [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://www.star-history.com/#linkedin/Liger-Kernel&Date)
 
  <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
  <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">