liger-kernel 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/ops/dyt.py +0 -2
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +265 -54
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +62 -50
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/functional.py +62 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +25 -8
- liger_kernel/transformers/model/gemma2.py +27 -8
- liger_kernel/transformers/model/gemma3.py +62 -98
- liger_kernel/transformers/model/glm4.py +16 -7
- liger_kernel/transformers/model/llama.py +25 -7
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -124
- liger_kernel/transformers/model/mistral.py +13 -8
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +16 -7
- liger_kernel/transformers/model/olmo2.py +16 -7
- liger_kernel/transformers/model/paligemma.py +8 -1
- liger_kernel/transformers/model/phi3.py +25 -8
- liger_kernel/transformers/model/qwen2.py +24 -7
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
- liger_kernel/transformers/model/qwen2_vl.py +38 -100
- liger_kernel/transformers/model/qwen3.py +11 -3
- liger_kernel/transformers/model/qwen3_moe.py +10 -6
- liger_kernel/transformers/monkey_patch.py +304 -70
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
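Most of the monkey_patch.py changes below add an instance-level patching path: when an already-loaded model is passed, the Liger forward is bound to that instance with `types.MethodType` instead of overwriting the class attribute, and class-level LayerNorm swaps are skipped when a model instance is given. A minimal sketch of exercising that path (the checkpoint name is illustrative, and the keyword arguments shown are the common Liger flags rather than an exhaustive list):

    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers import apply_liger_kernel_to_llama

    # Load first, then patch the existing instance; with a recent transformers
    # release the fused-linear-cross-entropy forward is bound to this object
    # via MethodType rather than patched onto LlamaForCausalLM.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # illustrative checkpoint
    apply_liger_kernel_to_llama(
        rms_norm=True,
        swiglu=True,
        fused_linear_cross_entropy=True,
        model=model,
    )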
liger_kernel/transformers/monkey_patch.py
@@ -2,6 +2,7 @@ import inspect
 import logging
 
 from functools import partial
+from types import MethodType
 from typing import Callable
 
 import transformers
@@ -54,7 +55,7 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
     # Check if the module is a PEFT ModulesToSaveWrapper
     # If it is, we need to patch the modules_to_save.default and original_modules
     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
@@ -64,12 +65,14 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         )
         module.modules_to_save.default.in_place = in_place
+        module.modules_to_save.default.row_mode = row_mode
         module.original_module.offset = offset
         module.original_module.casting_mode = casting_mode
         module.original_module.variance_epsilon = (
             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         )
         module.original_module.in_place = in_place
+        module.original_module.row_mode = row_mode
         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
@@ -81,6 +84,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
         module.casting_mode = casting_mode
         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         module.in_place = in_place
+        module.row_mode = row_mode
         _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
         module.__class__.__name__ = LigerRMSNorm.__name__
@@ -257,10 +261,16 @@ def apply_liger_kernel_to_llama(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -314,13 +324,20 @@ def apply_liger_kernel_to_llava(
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
             modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse("4.49.0"):
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        if transformer_version >= version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else:  # if version < 4.49.0
             logger.warning(
-                "
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
             )
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
 
     if model is not None:
         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
@@ -359,6 +376,92 @@ def apply_liger_kernel_to_llava(
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
 
 
+def apply_liger_kernel_to_llama4(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Llama4 models.")
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -400,7 +503,7 @@ def apply_liger_kernel_to_mllama(
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -416,10 +519,16 @@ def apply_liger_kernel_to_mllama(
         modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -428,13 +537,17 @@ def apply_liger_kernel_to_mllama(
         if isinstance(model, MllamaForConditionalGeneration):
            language_model: MllamaForCausalLM = model.language_model
            vision_model: MllamaVisionModel = model.vision_model
-            text_model: MllamaTextModel = language_model.model
+            if isinstance(language_model, MllamaForCausalLM):
+                text_model: MllamaTextModel = language_model.model
+            else:
+                text_model = language_model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
             vision_model = None
         elif isinstance(model, MllamaTextModel):
             text_model = model
             vision_model = None
+
         else:
             raise ValueError(f"Unsupported Mllama model type: {type(model)}")
 
@@ -501,7 +614,17 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        if transformer_version >= version.parse("4.49.0"):
+            if model is not None:
+                model.forward = MethodType(mistral_lce_forward, model)
+            else:
+                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        else:
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP
 
@@ -569,10 +692,16 @@ def apply_liger_kernel_to_mixtral(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
@@ -626,8 +755,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
-
-
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
     _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
@@ -646,10 +775,16 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -700,7 +835,8 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -720,10 +856,16 @@ def apply_liger_kernel_to_gemma2(
         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
         else:
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward_deprected, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
 
@@ -777,9 +919,10 @@ def apply_liger_kernel_to_gemma3_text(
     from transformers.models.gemma3 import modeling_gemma3
     from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
     from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
 
-    from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
 
     _patch_rms_norm_module_for_gemma3 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
@@ -801,15 +944,18 @@ def apply_liger_kernel_to_gemma3_text(
         nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if isinstance(model, Gemma3ForCausalLM):
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
             # get the base model from the model instance
-            base_model = model.model
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
 
             if rms_norm:
                 _patch_rms_norm_module_for_gemma3(base_model.norm)
@@ -871,7 +1017,7 @@ def apply_liger_kernel_to_gemma3(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
 
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
     apply_liger_kernel_to_gemma3_text(
@@ -882,7 +1028,10 @@ def apply_liger_kernel_to_gemma3(
         modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
 
     if fused_linear_cross_entropy:
-        modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+        if model is not None:
+            model.forward = MethodType(multimodal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -950,7 +1099,9 @@ def apply_liger_kernel_to_paligemma(
     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
 
     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+    from transformers.models.gemma.modeling_gemma import GemmaModel
     from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
     from transformers.models.paligemma import modeling_paligemma
     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
     from transformers.models.siglip import modeling_siglip
@@ -961,7 +1112,7 @@ def apply_liger_kernel_to_paligemma(
     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
 
     # The vision_tower is a SiglipVisionModel
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
     # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -979,10 +1130,16 @@ def apply_liger_kernel_to_paligemma(
         modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+            if model is not None:
+                model.forward = MethodType(lce_forward, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(lce_forward_deprecated, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1003,7 +1160,7 @@ def apply_liger_kernel_to_paligemma(
 
         language_model = model.language_model
 
-        if isinstance(language_model, GemmaForCausalLM):
+        if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(
                 rope=rope,
                 cross_entropy=False,
@@ -1013,7 +1170,7 @@ def apply_liger_kernel_to_paligemma(
                 model=language_model,
             )
 
-        elif isinstance(language_model, Gemma2ForCausalLM):
+        elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
             apply_liger_kernel_to_gemma2(
                 rope=rope,
                 cross_entropy=False,
@@ -1074,10 +1231,16 @@ def apply_liger_kernel_to_qwen2(
 
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
 
     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1133,7 +1296,10 @@ def apply_liger_kernel_to_qwen3(
         nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
 
     if swiglu:
         modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1188,7 +1354,10 @@ def apply_liger_kernel_to_qwen3_moe(
         nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
 
     if swiglu:
         modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
@@ -1204,7 +1373,8 @@ def apply_liger_kernel_to_qwen3_moe(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -1221,7 +1391,7 @@ def apply_liger_kernel_to_qwen2_vl(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-    NOTE: Qwen2-VL is not
+    NOTE: Qwen2-VL is not supported in transformers<4.52.4
 
     Args:
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -1235,12 +1405,19 @@ def apply_liger_kernel_to_qwen2_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
     from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
 
     from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
 
@@ -1249,12 +1426,15 @@ def apply_liger_kernel_to_qwen2_vl(
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
     if cross_entropy:
         modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen2_vl_lce_forward, model)
+        else:
+            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
     if swiglu:
         modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
 
@@ -1262,24 +1442,38 @@ def apply_liger_kernel_to_qwen2_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-
-
+        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+            text_model: Qwen2VLTextModel = model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2VLTextModel):
+            text_model: Qwen2VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
 
-
-
-        for vision_block in
+        # Patch Qwen2VisionTransformerPretrainedModel
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
                 if layer_norm:
                     _patch_layer_norm_module(vision_block.norm1)
                     _patch_layer_norm_module(vision_block.norm2)
 
-
-
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        # Patch Qwen2VisionTextModel
+        if text_model is not None:
             if rms_norm:
-                _patch_rms_norm_module(
-
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_qwen2_5_vl(
@@ -1305,12 +1499,19 @@ def apply_liger_kernel_to_qwen2_5_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
     from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
 
     from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
 
@@ -1321,7 +1522,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if cross_entropy:
         modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+        else:
+            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
     if swiglu:
         modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
 
@@ -1329,24 +1533,37 @@ def apply_liger_kernel_to_qwen2_5_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-
-
+        if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+            text_model: Qwen2_5_VLTextModel = model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2_5_VLTextModel):
+            text_model: Qwen2_5_VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
 
-        if
+        if vision_model is not None:
             # Patch Qwen2_5_VisionTransformerPretrainedModel
             for vision_block in model.visual.blocks:
                 if rms_norm:
                     _patch_rms_norm_module(vision_block.norm1)
                     _patch_rms_norm_module(vision_block.norm2)
 
-        if
-            _patch_rms_norm_module(base_model.norm)
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        if text_model is not None:
             if rms_norm:
-                _patch_rms_norm_module(
-
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_phi3(
@@ -1395,10 +1612,16 @@ def apply_liger_kernel_to_phi3(
         modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+            if model is not None:
+                model.forward = MethodType(phi3_lce_forward, model)
+            else:
+                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(phi3_lce_forward_deprecated, model)
+            else:
+                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1449,11 +1672,12 @@ def apply_liger_kernel_to_olmo2(
     from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
 
     from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
 
     if rope:
         modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_olmo2.Olmo2RMSNorm =
+        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
     if swiglu:
         modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
     if cross_entropy:
@@ -1461,7 +1685,10 @@ def apply_liger_kernel_to_olmo2(
 
         nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-        modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+        if model is not None:
+            model.forward = MethodType(olmo2_lce_forward, model)
+        else:
+            modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1512,11 +1739,12 @@ def apply_liger_kernel_to_glm4(
     from transformers.models.glm4.modeling_glm4 import Glm4Model
 
     from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
 
     if rope:
         raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
     if rms_norm:
-        modeling_glm4.Glm4RMSNorm =
+        modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
     if swiglu:
         modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
@@ -1524,7 +1752,10 @@ def apply_liger_kernel_to_glm4(
 
         nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-        modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+        if model is not None:
+            model.forward = MethodType(glm4_lce_forward, model)
+        else:
+            modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1554,6 +1785,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma3": apply_liger_kernel_to_gemma3,
     "glm4": apply_liger_kernel_to_glm4,
     "llama": apply_liger_kernel_to_llama,
+    "llama4_text": apply_liger_kernel_to_llama4,
+    "llama4": apply_liger_kernel_to_llama4,
     "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
@@ -1565,7 +1798,9 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen3": apply_liger_kernel_to_qwen3,
     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
 }
@@ -1625,7 +1860,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
    apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function