liger-kernel 0.5.5-py3-none-any.whl → 0.5.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +103 -61
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +30 -11
  10. liger_kernel/ops/kl_div.py +2 -2
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/dyt.py +20 -0
  13. liger_kernel/transformers/functional.py +5 -0
  14. liger_kernel/transformers/model/gemma.py +8 -16
  15. liger_kernel/transformers/model/gemma2.py +7 -16
  16. liger_kernel/transformers/model/llama.py +8 -15
  17. liger_kernel/transformers/model/llava.py +369 -0
  18. liger_kernel/transformers/model/loss_utils.py +57 -0
  19. liger_kernel/transformers/model/mistral.py +9 -10
  20. liger_kernel/transformers/model/mixtral.py +8 -15
  21. liger_kernel/transformers/model/mllama.py +8 -15
  22. liger_kernel/transformers/model/olmo2.py +8 -16
  23. liger_kernel/transformers/model/paligemma.py +397 -0
  24. liger_kernel/transformers/model/phi3.py +8 -15
  25. liger_kernel/transformers/model/qwen2.py +8 -15
  26. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  27. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  28. liger_kernel/transformers/monkey_patch.py +219 -13
  29. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +9 -6
  30. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/RECORD +34 -29
  31. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
  32. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
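Two of the new files above, liger_kernel/ops/dyt.py and liger_kernel/transformers/dyt.py, add a DyT (Dynamic Tanh) module, a normalization-free alternative to LayerNorm. As a reference for what such a kernel computes, here is a plain PyTorch sketch of the standard DyT formulation, `gamma * tanh(alpha * x) + beta`; the class and parameter names below are illustrative and are not the LigerDyT API.

```python
import torch
import torch.nn as nn


class DyTReference(nn.Module):
    """Plain-PyTorch reference for Dynamic Tanh: y = gamma * tanh(alpha * x) + beta."""

    def __init__(self, hidden_size: int, init_alpha: float = 0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), init_alpha))  # learnable scalar
        self.gamma = nn.Parameter(torch.ones(hidden_size))       # per-channel scale
        self.beta = nn.Parameter(torch.zeros(hidden_size))       # per-channel shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * torch.tanh(self.alpha * x) + self.beta
```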
liger_kernel/transformers/monkey_patch.py
@@ -19,6 +19,8 @@ from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_for
  from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
+ from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
+ from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
  from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
@@ -52,13 +54,26 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
      module.in_place = in_place
      _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
      _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+     module.__class__.__name__ = LigerRMSNorm.__name__


  def _patch_layer_norm_module(module, eps=1e-6):
      module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-     module.hidden_size = module.normalized_shape
+     module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+
      _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
      _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+     module.__class__.__name__ = LigerLayerNorm.__name__
+
+
+ def _patch_swiglu_module(module, liger_module):
+     _bind_method_to_module(module, "forward", liger_module.forward)
+     module.__class__.__name__ = liger_module.__name__
+
+
+ def _patch_geglu_module(module):
+     _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+     module.__class__.__name__ = LigerGEGLUMLP.__name__


  def apply_liger_kernel_to_granite(
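The helpers above rebind a Liger `forward` onto each already-instantiated HF module and, new in 0.5.6, also relabel the class name so a printed model reveals which submodules were patched. A minimal, self-contained sketch of that in-place patching pattern; the two module classes below are stand-ins, not Liger or Hugging Face classes.

```python
from types import MethodType

import torch
import torch.nn as nn


class HFStyleMLP(nn.Module):
    """Stand-in for a model-specific HF MLP that we want to patch in place."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class LigerStyleMLP(nn.Module):
    """Stand-in for a Liger module whose forward is a drop-in replacement."""

    def forward(self, x):
        # Reuses the weights already stored on the patched instance (`self`).
        return self.proj(torch.relu(x))


def patch_mlp(module: nn.Module) -> None:
    # Rebind `forward` on the existing instance (weights stay untouched) and
    # relabel the class name so repr(model) shows the patched module type.
    module.forward = MethodType(LigerStyleMLP.forward, module)
    module.__class__.__name__ = LigerStyleMLP.__name__


mlp = HFStyleMLP(8)
patch_mlp(mlp)
print(mlp)  # now prints as LigerStyleMLP(...), still holding the original weights
```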
@@ -134,7 +149,7 @@ def apply_liger_kernel_to_granite(

          for decoder_layer in base_model.layers:
              if swiglu:
-                 _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -206,12 +221,91 @@ def apply_liger_kernel_to_llama(

          for decoder_layer in base_model.layers:
              if swiglu:
-                 _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_llava(
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     model: PreTrainedModel = None,
+     **kwargs,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Llava models.
+     Due to the characteristics of LlaVa, the model must be passed to apply Liger-Kernel's patch to other models connected to LLaVa.
+     However, if an LM not supported by Liger-Kernel is connected to LLaVa, unexpected side effects may occur.
+     NOTE: Llava is not available in transformers<4.36.0
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.llava import modeling_llava
+
+     if cross_entropy:
+         logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+         modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         if transformer_version >= version.parse("4.49.0"):
+             modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+         else: # if version < 4.49.0
+             logger.warning(
+                 "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+             )
+             modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+
+     if model is not None:
+         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+         text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+         vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+         kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
+         if text_liger_fn:
+             accept_params = inspect.signature(text_liger_fn).parameters
+             remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+             text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+             if remain_params:
+                 logger.warning(
+                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                 )
+             text_kwargs["model"] = model.language_model
+             text_liger_fn(**text_kwargs)
+         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+         if vision_liger_fn:
+             accept_params = inspect.signature(vision_liger_fn).parameters
+             remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+             vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+             if remain_params:
+                 logger.warning(
+                     f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
+                     f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+                 )
+             vision_kwargs["model"] = model.vision_tower
+             vision_liger_fn(**vision_kwargs)
+         elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
  def apply_liger_kernel_to_mllama(
      rope: bool = True,
      cross_entropy: bool = False,
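Because the new `apply_liger_kernel_to_llava` dispatches to the Liger patch functions for whatever text and vision models the checkpoint wires together, the loaded instance should be passed in. A hedged usage sketch, assuming the helper is re-exported from `liger_kernel.transformers` like the other `apply_liger_kernel_to_*` entry points; the checkpoint name is illustrative.

```python
from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

# Load the multimodal model first, then patch the loaded instance so the
# connected language model and vision tower are patched in place as well.
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
apply_liger_kernel_to_llava(model=model)
```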
@@ -296,7 +390,7 @@ def apply_liger_kernel_to_mllama(
                  _patch_rms_norm_module(text_model.norm)
              for decoder_layer in text_model.layers:
                  if swiglu:
-                     _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                     _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
                  if rms_norm:
                      _patch_rms_norm_module(decoder_layer.input_layernorm)
                      _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -370,7 +464,7 @@ def apply_liger_kernel_to_mistral(

          for decoder_layer in base_model.layers:
              if swiglu:
-                 _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -442,7 +536,7 @@ def apply_liger_kernel_to_mixtral(
          for decoder_layer in base_model.layers:
              if swiglu:
                  for expert in decoder_layer.block_sparse_moe.experts:
-                     _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
+                     _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -516,7 +610,7 @@ def apply_liger_kernel_to_gemma(

          for decoder_layer in base_model.layers:
              if geglu:
-                 _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                 _patch_geglu_module(decoder_layer.mlp)
              if rms_norm:
                  _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                  _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -592,7 +686,7 @@ def apply_liger_kernel_to_gemma2(

          for decoder_layer in base_model.layers:
              if geglu:
-                 _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                 _patch_geglu_module(decoder_layer.mlp)
              if rms_norm:
                  _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
                  _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
@@ -600,6 +694,116 @@ def apply_liger_kernel_to_gemma2(
                  _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)


+ def apply_liger_kernel_to_paligemma(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     layer_norm: bool = True,
+     rms_norm: bool = True,
+     geglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
+
+     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+     from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+     from transformers.models.paligemma import modeling_paligemma
+     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+     from transformers.models.siglip import modeling_siglip
+     from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
+     from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+
+     from liger_kernel.transformers.model.paligemma import lce_forward
+     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
+
+     # The vision_tower is a SiglipVisionModel
+     if layer_norm:
+         modeling_siglip.nn.LayerNorm = LigerLayerNorm
+
+     # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
+     # The multi_modal_projector is Linear, nothing to do
+
+     # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
+     apply_liger_kernel_to_gemma(
+         rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+     )
+     apply_liger_kernel_to_gemma2(
+         rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+     )
+     # Handle loss function
+     if cross_entropy:
+         modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+         else: # if version < 4.46.1
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if not isinstance(model, PaliGemmaForConditionalGeneration):
+             raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
+
+         vision_tower: SiglipVisionModel = model.vision_tower
+
+         _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
+
+         for layer in vision_tower.vision_model.encoder.layers:
+             layer: SiglipEncoderLayer
+             if layer_norm:
+                 _patch_layer_norm_module(layer.layer_norm1)
+                 _patch_layer_norm_module(layer.layer_norm2)
+
+         language_model = model.language_model
+
+         if isinstance(language_model, GemmaForCausalLM):
+             apply_liger_kernel_to_gemma(
+                 rope=rope,
+                 cross_entropy=False,
+                 fused_linear_cross_entropy=False,
+                 rms_norm=rms_norm,
+                 geglu=geglu,
+                 model=language_model,
+             )
+
+         elif isinstance(language_model, Gemma2ForCausalLM):
+             apply_liger_kernel_to_gemma2(
+                 rope=rope,
+                 cross_entropy=False,
+                 fused_linear_cross_entropy=False,
+                 rms_norm=rms_norm,
+                 geglu=geglu,
+                 model=language_model,
+             )
+         else:
+             raise TypeError(
+                 "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
+             )
+
+
  def apply_liger_kernel_to_qwen2(
      rope: bool = True,
      cross_entropy: bool = False,
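As with LLaVA, the class-level PaliGemma patches (GeGLU, RMSNorm, the fused-linear-cross-entropy forward) apply without a model instance, but passing a loaded model additionally patches the SigLIP vision tower's LayerNorms and the Gemma/Gemma2 language model in place. A hedged usage sketch under the same assumption that the helper is re-exported from `liger_kernel.transformers`; the checkpoint name is illustrative.

```python
from transformers import PaliGemmaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_paligemma

# Load the model first so the SigLIP tower and Gemma/Gemma2 LM get patched too.
model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-pt-224")
apply_liger_kernel_to_paligemma(model=model)
```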
@@ -666,7 +870,7 @@ def apply_liger_kernel_to_qwen2(

          for decoder_layer in base_model.layers:
              if swiglu:
-                 _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -739,7 +943,7 @@ def apply_liger_kernel_to_qwen2_vl(
              _patch_rms_norm_module(base_model.norm)
          for decoder_layer in base_model.layers:
              if swiglu:
-                 _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -806,7 +1010,7 @@ def apply_liger_kernel_to_qwen2_5_vl(
              _patch_rms_norm_module(base_model.norm)
          for decoder_layer in base_model.layers:
              if swiglu:
-                 _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -875,7 +1079,7 @@ def apply_liger_kernel_to_phi3(

          for decoder_layer in base_model.layers:
              if swiglu:
-                 _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
+                 _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -938,7 +1142,7 @@ def apply_liger_kernel_to_olmo2(

          for decoder_layer in base_model.layers:
              if swiglu:
-                 _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
@@ -949,6 +1153,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "gemma": apply_liger_kernel_to_gemma,
      "gemma2": apply_liger_kernel_to_gemma2,
      "llama": apply_liger_kernel_to_llama,
+     "llava": apply_liger_kernel_to_llava,
      "granite": apply_liger_kernel_to_granite,
      "mllama": apply_liger_kernel_to_mllama,
      "mllama_text_model": apply_liger_kernel_to_mllama,
@@ -959,6 +1164,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
      "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
      "phi3": apply_liger_kernel_to_phi3,
+     "paligemma": apply_liger_kernel_to_paligemma,
  }

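The `MODEL_TYPE_TO_APPLY_LIGER_FN` registry above is what lets the new multimodal wrappers look up the right patch function from a submodel's `model_type`. A minimal sketch of that dispatch pattern; the helper function below is illustrative and not part of the library.

```python
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN


def patch_by_model_type(model):
    """Illustrative helper: dispatch a loaded HF model to its Liger patch function."""
    model_type = getattr(model.config, "model_type", None)
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        print(f"{model_type} is not supported by Liger kernel.")
        return
    # Patch the already-instantiated modules of this model in place.
    apply_fn(model=model)
```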
{liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/METADATA
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: liger_kernel
- Version: 0.5.5
+ Version: 0.5.6
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -45,6 +45,7 @@ Requires-Dist: datasets>=2.19.2; extra == "dev"
  Requires-Dist: seaborn; extra == "dev"
  Requires-Dist: mkdocs; extra == "dev"
  Requires-Dist: mkdocs-material; extra == "dev"
+ Dynamic: license-file
  Dynamic: provides-extra
  Dynamic: requires-dist

@@ -115,6 +116,7 @@ Dynamic: requires-dist
  <details>
  <summary>Latest News 🔥</summary>

+ - [2025/03/06] We release a joint blog post on TorchTune × Liger - [Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel](https://pytorch.org/blog/peak-performance-minimized-memory/)
  - [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
  - [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
  - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
@@ -177,7 +179,7 @@ y = orpo_loss(lm_head.weight, x, target)
  - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
  - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
  - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
- - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
+ - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)

  ## Installation

@@ -312,6 +314,7 @@ loss.backward()
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -386,8 +389,8 @@ loss.backward()
  ## Contact

  - For issues, create a Github ticket in this repository
- - For open discussion, join [our discord channel](https://discord.gg/gpumode)
- - For formal collaboration, send an email to yannchen@linkedin.com
+ - For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
+ - For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com

  ## Cite this work

@@ -406,7 +409,7 @@ Biblatex entry:
  ```

  ## Star History
- [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
+ [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://www.star-history.com/#linkedin/Liger-Kernel&Date)

  <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
  <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
{liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/RECORD
@@ -5,24 +5,25 @@ liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EB
  liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
  liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
  liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
- liger_kernel/chunked_loss/functional.py,sha256=THWWpCnRVhTVfnPnyvQjdBvo1JDtxhwLmtZE_yiBBqM,817
- liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=FJh7k3sry-fqnBApLSngf7h-lHQEiXtOY_tiRDVanPM,11022
+ liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
+ liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
+ liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=-E4AuWY-y2bMo_kAmEQBgQ92UJh3L5IiCRGVcfMJOCE,12731
  liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
- liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=wGujqwLz91mOE9MmdenhBIKvbmswhwtINMCpcP7D74c,9050
  liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
- liger_kernel/chunked_loss/grpo_loss.py,sha256=axED3628yKODu1v7PMAvSd08WZqwNQvJOTUYMgcihdQ,6665
- liger_kernel/chunked_loss/jsd_loss.py,sha256=j2_1AYLu0FW2VQJIEr1J1qHsWd5VUo6C3aedglHVH4Y,6771
+ liger_kernel/chunked_loss/grpo_loss.py,sha256=6Mb4ZT6MfnOr4Xo681rMR0LKkhzJhInvQp8wp2YVMK0,8913
+ liger_kernel/chunked_loss/jsd_loss.py,sha256=u2ahkuHsbhpNaKcpBCz5gCMDk9ou-P04DHji592dIBo,7067
  liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
  liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
  liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- liger_kernel/ops/cross_entropy.py,sha256=yKKhN63I7r9NxJye4wTLBvvKAyrXQt6jf4nBo3lJyVg,18860
+ liger_kernel/ops/cross_entropy.py,sha256=T5oSsqOS1y-Iea5o9v_BSU-_mIEXqWAT1oX_m59NcA4,18941
+ liger_kernel/ops/dyt.py,sha256=YD1-buHz9VmIX838VKzLc-lm5CeUQ4LAskGDWBUMQHA,6187
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1Y3Uk_TCSjqKgoG2eot1ptnWXJXXQESqGvOmqAW1gsM,10912
- liger_kernel/ops/fused_linear_jsd.py,sha256=Seshez2qaM6HiTQ8_HEqSwhaeVruNT1SvIM4ZrAPBEU,9602
+ liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
  liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
  liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
- liger_kernel/ops/jsd.py,sha256=0jNeRxpcNI5ckxCdoCNyO5GEedLIuzx3lz6KAiksc4o,6109
- liger_kernel/ops/kl_div.py,sha256=MnfuYqqQESON1X2Swy064x1urKtMFdgeSWd60VttBXI,8420
+ liger_kernel/ops/jsd.py,sha256=rkloGA7nDfVaa5nKY6-EYBw0E1p_MSsl4fr2xZGTp04,6961
+ liger_kernel/ops/kl_div.py,sha256=NkG7D6_DnPBzr-ohhYiQbRBnq_fbGmpn5UU7y0UBKQo,8420
  liger_kernel/ops/layer_norm.py,sha256=6roQjioyg-9O2qLPV8nL4U0-5UH80tdzOMTWwjvDnn8,7961
  liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
  liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
@@ -32,10 +33,11 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
- liger_kernel/transformers/__init__.py,sha256=4bwMPQhGHxmZ-WTFAMD9m-s0PYyfcvIRxhq_h3b0Wz0,2363
+ liger_kernel/transformers/__init__.py,sha256=t70gqygxH63iz-B0MOdZx4AEgA8MfqU1G7N6dvIneCY,2618
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
- liger_kernel/transformers/functional.py,sha256=ShLD3eb--XKNtllznCrOYTbo4f-1KVwzi0KLMICdrn4,4942
+ liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
+ liger_kernel/transformers/functional.py,sha256=4h9Pdx_iINBqfv2Zod_c27qOpYXDDwbdVgatQ9_XBmI,5089
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -43,7 +45,7 @@ liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
- liger_kernel/transformers/monkey_patch.py,sha256=9ud9tv1LI9WIa9UDu0abGIiusIIkayO1fjAUMWgwwT0,47096
+ liger_kernel/transformers/monkey_patch.py,sha256=95afvIrZA9xSWLNIJspBLbz8lxv2Y5gfZke7MyqoOX8,56965
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
  liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -52,24 +54,27 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
  liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- liger_kernel/transformers/model/gemma.py,sha256=ky89b3aWPaeTGRMC-745KgixtQIRXzNAiCORAMLn9yo,9654
- liger_kernel/transformers/model/gemma2.py,sha256=27NcoZjEqP9Lqb4Wf0EKqTbr2HTGiHPhrVyPCRGPz6s,10767
- liger_kernel/transformers/model/llama.py,sha256=3LJFXKFDKvEakaWPc_NicSFst4Y_hdSMrdl1UDK1EcA,10330
- liger_kernel/transformers/model/mistral.py,sha256=MVRksI5_j_8WJu8znOHKCdSI5jSu-S7cdFYzt9m_vIQ,5180
- liger_kernel/transformers/model/mixtral.py,sha256=jpZJkpl625Q-JHWarj2MqT5mRaSsiCtg0c9vVyvOdCY,11430
- liger_kernel/transformers/model/mllama.py,sha256=qWexBdskuN3gPJvPUwt4J0nU675tGD6W7wxgRZ9Bifg,11145
- liger_kernel/transformers/model/olmo2.py,sha256=yyksS6E4fuWd8asEW8rEDBKqZpFmP4ITCM_bjIDZaoY,5124
- liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UKU4uk8Up8pU,10292
- liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
- liger_kernel/transformers/model/qwen2_5_vl.py,sha256=l71WBfX0ptrisoURIRwXJH7MQ2vGKOvcRYMNsrydwlQ,9455
- liger_kernel/transformers/model/qwen2_vl.py,sha256=yMLqsfSYcvhClUpTUjGoADiOxfLB2B8240VdrPP0c8s,9851
+ liger_kernel/transformers/model/gemma.py,sha256=7cBTljzh-8_ACBhYl6NUfj5_ux92YRlmnAU5gfDAQAI,9312
+ liger_kernel/transformers/model/gemma2.py,sha256=X0FOIhvFlTrmWI7Ws06wUkutgHW3lWtLOnnHp1NgZ3A,10403
+ liger_kernel/transformers/model/llama.py,sha256=d9rBaK8e8RSMCFHdgom9ZHuXOlnh6U_o-GkAFGRNGOY,9989
+ liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
+ liger_kernel/transformers/model/loss_utils.py,sha256=Z-fUrf-cUDUjUIH7Tl9OL2hT8nmtx7ES3kg8syuWKy4,1476
+ liger_kernel/transformers/model/mistral.py,sha256=o7tyl1sPWPfZwwrBLRlryHlSI8I55viuJoMI5Bh5Nww,5014
+ liger_kernel/transformers/model/mixtral.py,sha256=T0ITv2-PkR8VErVOVUizoS4EzjmARyR7GFh0tXDB_i4,11089
+ liger_kernel/transformers/model/mllama.py,sha256=RCKtwnGOMFYIbtt1zUQ15Cyv4eNpHkTWcgkmG2EEs2I,10804
+ liger_kernel/transformers/model/olmo2.py,sha256=5M8kczp4D-jvbjcV7cKATIJGF34xd-Rs-PPdKZWSIlY,4685
+ liger_kernel/transformers/model/paligemma.py,sha256=GNReT6tVZt3ON6aaa9ovg8mnu1hYocSx9OhgC7b-_28,19191
+ liger_kernel/transformers/model/phi3.py,sha256=NmU2DuU1Huwha6K7YSsJCnvQfUovTTGlsfBZhbx0UoI,9951
+ liger_kernel/transformers/model/qwen2.py,sha256=t7NotBHoebsPqNSxwaf9DXTg8jxgB5BdunSGqYOE0hQ,9240
+ liger_kernel/transformers/model/qwen2_5_vl.py,sha256=70BnHZjx6eQWTwi3zc5SMwxTeOOA4Tbdkfy6IYRcTaM,9289
+ liger_kernel/transformers/model/qwen2_vl.py,sha256=zo4O9fShNHYqSLrzLGqQYWSMtJI6UHaSY7zvMCYWyD8,9685
  liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel-0.5.5.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel-0.5.5.dist-info/METADATA,sha256=PRpIrVa7cvCW-D7zMA6qpsQ1iJogiK6POWpYUbYHYr4,22411
- liger_kernel-0.5.5.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel-0.5.5.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
- liger_kernel-0.5.5.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel-0.5.5.dist-info/RECORD,,
+ liger_kernel-0.5.6.dist-info/licenses/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel-0.5.6.dist-info/licenses/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel-0.5.6.dist-info/METADATA,sha256=yam1-5oz74ok_T_rVfn3RLvCDXPxDfXZpChC1PVTFoY,23002
+ liger_kernel-0.5.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ liger_kernel-0.5.6.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel-0.5.6.dist-info/RECORD,,
{liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (76.0.0)
+ Generator: setuptools (78.1.0)
  Root-Is-Purelib: true
  Tag: py3-none-any