liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
- liger_kernel/chunked_loss/grpo_loss.py +134 -60
- liger_kernel/chunked_loss/jsd_loss.py +12 -7
- liger_kernel/ops/cross_entropy.py +3 -2
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/ops/fused_linear_jsd.py +2 -1
- liger_kernel/ops/jsd.py +32 -12
- liger_kernel/ops/kl_div.py +15 -8
- liger_kernel/ops/layer_norm.py +14 -1
- liger_kernel/ops/rms_norm.py +12 -1
- liger_kernel/transformers/__init__.py +133 -15
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/gema3_rms.py +8 -0
- liger_kernel/transformers/model/gemma.py +17 -20
- liger_kernel/transformers/model/gemma2.py +17 -21
- liger_kernel/transformers/model/gemma3.py +335 -0
- liger_kernel/transformers/model/llama.py +17 -19
- liger_kernel/transformers/model/llava.py +369 -0
- liger_kernel/transformers/model/loss_utils.py +64 -0
- liger_kernel/transformers/model/mistral.py +28 -25
- liger_kernel/transformers/model/mixtral.py +20 -26
- liger_kernel/transformers/model/mllama.py +17 -19
- liger_kernel/transformers/model/olmo2.py +17 -20
- liger_kernel/transformers/model/paligemma.py +397 -0
- liger_kernel/transformers/model/phi3.py +17 -19
- liger_kernel/transformers/model/qwen2.py +17 -19
- liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
- liger_kernel/transformers/model/qwen2_vl.py +9 -10
- liger_kernel/transformers/monkey_patch.py +392 -13
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0

liger_kernel/transformers/monkey_patch.py
@@ -19,6 +19,8 @@ from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_for
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
+from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
+from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
@@ -52,13 +54,26 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+    module.__class__.__name__ = LigerRMSNorm.__name__
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    module.hidden_size = module
+    module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+    module.__class__.__name__ = LigerLayerNorm.__name__
+
+
+def _patch_swiglu_module(module, liger_module):
+    _bind_method_to_module(module, "forward", liger_module.forward)
+    module.__class__.__name__ = liger_module.__name__
+
+
+def _patch_geglu_module(module):
+    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+    module.__class__.__name__ = LigerGEGLUMLP.__name__
 
 
 def apply_liger_kernel_to_granite(
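The new `_patch_swiglu_module` and `_patch_geglu_module` helpers centralize the instance-level patching that the per-model functions below now call, and they additionally rewrite `__class__.__name__` so patched modules display as Liger modules when the model is printed. The following is a minimal, self-contained sketch of the underlying binding trick; `bind_method_to_module`, `ToyMLP`, and `doubled_forward` are illustrative stand-ins, not code from the package.

```python
import types

import torch
import torch.nn as nn


def bind_method_to_module(module: nn.Module, method_name: str, new_method) -> None:
    # Stand-in for liger_kernel's _bind_method_to_module: attach new_method to this
    # *instance* so it is called with the module as `self`, leaving the class untouched.
    setattr(module, method_name, types.MethodType(new_method, module))


class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


def doubled_forward(self, x):
    # Toy replacement; a real patch would install e.g. LigerSwiGLUMLP.forward here.
    return self.proj(x) * 2


mlp = ToyMLP()
bind_method_to_module(mlp, "forward", doubled_forward)
mlp.__class__.__name__ = "LigerSwiGLUMLP"  # mirrors the __name__ rewrite added in this diff

print(type(mlp).__name__, mlp(torch.zeros(1, 4)).shape)
```

Binding on the instance matters because these patchers run against models that are already constructed, so swapping only the class attribute would not affect layers created before the patch.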
@@ -134,7 +149,7 @@ def apply_liger_kernel_to_granite(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -206,12 +221,91 @@ def apply_liger_kernel_to_llama(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_llava(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    model: PreTrainedModel = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llava models.
+    Due to the characteristics of LlaVa, the model must be passed to apply Liger-Kernel's patch to other models connected to LLaVa.
+    However, if an LM not supported by Liger-Kernel is connected to LLaVa, unexpected side effects may occur.
+    NOTE: Llava is not available in transformers<4.36.0
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llava import modeling_llava
+
+    if cross_entropy:
+        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+        modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse("4.49.0"):
+            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        else:  # if version < 4.49.0
+            logger.warning(
+                "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+            )
+            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+
+    if model is not None:
+        text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = model.language_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        if vision_liger_fn:
+            accept_params = inspect.signature(vision_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+                )
+            vision_kwargs["model"] = model.vision_tower
+            vision_liger_fn(**vision_kwargs)
+        elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
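Because `apply_liger_kernel_to_llava` filters its keyword arguments through `inspect.signature` and forwards them to whatever patchers are registered for the attached language model and vision tower, it is intended to be called with an already-loaded model. A hedged usage sketch follows; the checkpoint name and dtype are illustrative, and the import assumes the helper is re-exported from `liger_kernel.transformers` like the other `apply_liger_kernel_to_*` functions (otherwise import it from `liger_kernel.transformers.monkey_patch`).

```python
import torch
from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

# Load the multimodal model first: the patcher inspects model.config.text_config /
# vision_config to find the registered per-submodel Liger patch functions.
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)

# Installs the fused-linear-cross-entropy forward on the LLaVA wrapper, then
# forwards the remaining kwargs to the language-model and vision-tower patchers.
apply_liger_kernel_to_llava(fused_linear_cross_entropy=True, model=model)
```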
@@ -296,7 +390,7 @@ def apply_liger_kernel_to_mllama(
             _patch_rms_norm_module(text_model.norm)
         for decoder_layer in text_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -370,7 +464,7 @@ def apply_liger_kernel_to_mistral(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -442,7 +536,7 @@ def apply_liger_kernel_to_mixtral(
         for decoder_layer in base_model.layers:
             if swiglu:
                 for expert in decoder_layer.block_sparse_moe.experts:
-
+                    _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -516,7 +610,7 @@ def apply_liger_kernel_to_gemma(
 
         for decoder_layer in base_model.layers:
            if geglu:
-
+                _patch_geglu_module(decoder_layer.mlp)
            if rms_norm:
                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -592,7 +686,7 @@ def apply_liger_kernel_to_gemma2(
 
         for decoder_layer in base_model.layers:
             if geglu:
-
+                _patch_geglu_module(decoder_layer.mlp)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
@@ -600,6 +694,287 @@ def apply_liger_kernel_to_gemma2(
                 _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
 
 
+def apply_liger_kernel_to_gemma3_text(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    geglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma3
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gemma3 import modeling_gemma3
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+
+    from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
+    from liger_kernel.transformers.model.gemma3 import causal_forward
+
+    _patch_rms_norm_module_for_gemma3 = partial(
+        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+    )
+
+    if rope:
+        modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3
+
+    if geglu:
+        modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
+
+    # Handle loss function
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if isinstance(model, Gemma3ForCausalLM):
+            # get the base model from the model instance
+            base_model = model.model
+
+            if rms_norm:
+                _patch_rms_norm_module_for_gemma3(base_model.norm)
+
+            for decoder_layer in base_model.layers:
+                decoder_layer: Gemma3DecoderLayer
+                if geglu:
+                    _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                if rms_norm:
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)
+
+        else:
+            raise TypeError("The model must be Gemma3ForCausalLM.")
+
+
+def apply_liger_kernel_to_gemma3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
+    rms_norm: bool = True,
+    geglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma3
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gemma3 import modeling_gemma3
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
+    from transformers.models.siglip import modeling_siglip
+    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
+    from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+
+    from liger_kernel.transformers.model.gemma3 import multimodal_forward
+
+    _patch_rms_norm_module_for_gemma3 = partial(
+        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+    )
+
+    if layer_norm:
+        modeling_siglip.nn.LayerNorm = LigerLayerNorm
+
+    apply_liger_kernel_to_gemma3_text(
+        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+    )
+
+    if cross_entropy:
+        modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if isinstance(model, Gemma3ForConditionalGeneration):
+            if isinstance(model.vision_tower, SiglipVisionModel):
+                vision_tower = model.vision_tower
+
+                _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
+
+                for layer in vision_tower.vision_model.encoder.layers:
+                    layer: SiglipEncoderLayer
+                    if layer_norm:
+                        _patch_layer_norm_module(layer.layer_norm1)
+                        _patch_layer_norm_module(layer.layer_norm2)
+            else:
+                raise TypeError("The vision tower must be SiglipVisionModel")
+
+            if rms_norm:
+                _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+
+            apply_liger_kernel_to_gemma3_text(
+                rope=rope,
+                cross_entropy=False,
+                fused_linear_cross_entropy=False,
+                rms_norm=rms_norm,
+                geglu=geglu,
+                model=model.language_model,
+            )
+
+        else:
+            raise TypeError("The model must be Gemma3ForConditionalGeneration.")
+
+
+def apply_liger_kernel_to_paligemma(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
+    rms_norm: bool = True,
+    geglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
+
+    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+    from transformers.models.paligemma import modeling_paligemma
+    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+    from transformers.models.siglip import modeling_siglip
+    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
+    from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+
+    from liger_kernel.transformers.model.paligemma import lce_forward
+    from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
+
+    # The vision_tower is a SiglipVisionModel
+    if layer_norm:
+        modeling_siglip.nn.LayerNorm = LigerLayerNorm
+
+    # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
+    # The multi_modal_projector is Linear, nothing to do
+
+    # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
+    apply_liger_kernel_to_gemma(
+        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+    )
+    apply_liger_kernel_to_gemma2(
+        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+    )
+    # Handle loss function
+    if cross_entropy:
+        modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if not isinstance(model, PaliGemmaForConditionalGeneration):
+            raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
+
+        vision_tower: SiglipVisionModel = model.vision_tower
+
+        _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
+
+        for layer in vision_tower.vision_model.encoder.layers:
+            layer: SiglipEncoderLayer
+            if layer_norm:
+                _patch_layer_norm_module(layer.layer_norm1)
+                _patch_layer_norm_module(layer.layer_norm2)
+
+        language_model = model.language_model
+
+        if isinstance(language_model, GemmaForCausalLM):
+            apply_liger_kernel_to_gemma(
+                rope=rope,
+                cross_entropy=False,
+                fused_linear_cross_entropy=False,
+                rms_norm=rms_norm,
+                geglu=geglu,
+                model=language_model,
+            )
+
+        elif isinstance(language_model, Gemma2ForCausalLM):
+            apply_liger_kernel_to_gemma2(
+                rope=rope,
+                cross_entropy=False,
+                fused_linear_cross_entropy=False,
+                rms_norm=rms_norm,
+                geglu=geglu,
+                model=language_model,
+            )
+        else:
+            raise TypeError(
+                "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
+            )
+
+
 def apply_liger_kernel_to_qwen2(
     rope: bool = True,
     cross_entropy: bool = False,
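The multimodal Gemma3 patcher follows the same pattern as the LLaVA one: swap the SigLIP LayerNorms and the projector's soft-embedding RMSNorm on the instance, then delegate the language model to `apply_liger_kernel_to_gemma3_text`. Below is a hedged usage sketch, assuming a transformers release that ships `Gemma3ForConditionalGeneration`; the checkpoint name is illustrative.

```python
import torch
from transformers import Gemma3ForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_gemma3

model = Gemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-4b-it",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)

# LayerNorm patches land on the SigLIP vision tower, RMSNorm on the multimodal
# projector's mm_soft_emb_norm, and the language model is handled by
# apply_liger_kernel_to_gemma3_text(model=model.language_model).
apply_liger_kernel_to_gemma3(fused_linear_cross_entropy=True, model=model)
```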
@@ -666,7 +1041,7 @@ def apply_liger_kernel_to_qwen2(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -739,7 +1114,7 @@ def apply_liger_kernel_to_qwen2_vl(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -806,7 +1181,7 @@ def apply_liger_kernel_to_qwen2_5_vl(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -875,7 +1250,7 @@ def apply_liger_kernel_to_phi3(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -938,7 +1313,7 @@ def apply_liger_kernel_to_olmo2(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                 _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
@@ -948,7 +1323,10 @@ def apply_liger_kernel_to_olmo2(
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
+    "gemma3_text": apply_liger_kernel_to_gemma3_text,
+    "gemma3": apply_liger_kernel_to_gemma3,
     "llama": apply_liger_kernel_to_llama,
+    "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
     "mllama_text_model": apply_liger_kernel_to_mllama,
@@ -959,6 +1337,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "phi3": apply_liger_kernel_to_phi3,
+    "paligemma": apply_liger_kernel_to_paligemma,
 }
 
 
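With gemma3, gemma3_text, llava, and paligemma registered, `MODEL_TYPE_TO_APPLY_LIGER_FN` lets callers dispatch purely on `config.model_type`. The registry itself is part of this release; the lookup wrapper below is an illustrative sketch, not a function shipped in the package.

```python
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN


def apply_liger_by_model_type(model, **kwargs) -> None:
    # Hypothetical helper: pick the registered patcher for this architecture and
    # apply it to the already-instantiated model.
    patch_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model.config.model_type)
    if patch_fn is None:
        raise ValueError(f"model_type '{model.config.model_type}' is not supported by Liger Kernel")
    patch_fn(model=model, **kwargs)
```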
{liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: liger_kernel
-Version: 0.5.5
+Version: 0.5.7
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
         Copyright 2024 LinkedIn Corporation
@@ -45,6 +45,7 @@ Requires-Dist: datasets>=2.19.2; extra == "dev"
 Requires-Dist: seaborn; extra == "dev"
 Requires-Dist: mkdocs; extra == "dev"
 Requires-Dist: mkdocs-material; extra == "dev"
+Dynamic: license-file
 Dynamic: provides-extra
 Dynamic: requires-dist
 
@@ -115,6 +116,7 @@ Dynamic: requires-dist
 <details>
 <summary>Latest News 🔥</summary>
 
+- [2025/03/06] We release a joint blog post on TorchTune × Liger - [Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel](https://pytorch.org/blog/peak-performance-minimized-memory/)
 - [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
 - [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
 - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
@@ -177,7 +179,7 @@ y = orpo_loss(lm_head.weight, x, target)
 - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
 - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
 - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
-- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
+- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)
 
 ## Installation
 
@@ -312,6 +314,9 @@ loss.backward()
 | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -386,8 +391,8 @@ loss.backward()
 ## Contact
 
 - For issues, create a Github ticket in this repository
-- For open discussion, join [our discord channel](https://discord.
-- For formal collaboration, send an email to yannchen@linkedin.com
+- For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
+- For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
 
 ## Cite this work
 
@@ -406,7 +411,7 @@ Biblatex entry:
 ```
 
 ## Star History
-[](https://star-history.com/#linkedin/Liger-Kernel&Date)
+[](https://www.star-history.com/#linkedin/Liger-Kernel&Date)
 
 <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
 <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">