liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl

Files changed (55)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  6. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  7. liger_kernel/ops/dyt.py +111 -179
  8. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  9. liger_kernel/ops/geglu.py +1 -1
  10. liger_kernel/ops/grpo_loss.py +310 -0
  11. liger_kernel/ops/multi_token_attention.py +207 -0
  12. liger_kernel/ops/rms_norm.py +265 -54
  13. liger_kernel/ops/softmax.py +201 -0
  14. liger_kernel/ops/sparsemax.py +179 -0
  15. liger_kernel/ops/swiglu.py +1 -1
  16. liger_kernel/transformers/__init__.py +8 -0
  17. liger_kernel/transformers/dyt.py +5 -3
  18. liger_kernel/transformers/fsdp.py +55 -0
  19. liger_kernel/transformers/functional.py +70 -0
  20. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  21. liger_kernel/transformers/grpo_loss.py +98 -0
  22. liger_kernel/transformers/model/gemma.py +25 -16
  23. liger_kernel/transformers/model/gemma2.py +27 -14
  24. liger_kernel/transformers/model/gemma3.py +62 -106
  25. liger_kernel/transformers/model/glm4.py +16 -13
  26. liger_kernel/transformers/model/llama.py +81 -18
  27. liger_kernel/transformers/model/llama4.py +108 -0
  28. liger_kernel/transformers/model/llava.py +95 -132
  29. liger_kernel/transformers/model/mistral.py +13 -14
  30. liger_kernel/transformers/model/mixtral.py +16 -15
  31. liger_kernel/transformers/model/mllama.py +16 -14
  32. liger_kernel/transformers/model/olmo2.py +16 -13
  33. liger_kernel/transformers/model/paligemma.py +8 -9
  34. liger_kernel/transformers/model/phi3.py +25 -16
  35. liger_kernel/transformers/model/qwen2.py +24 -15
  36. liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
  37. liger_kernel/transformers/model/qwen2_vl.py +38 -106
  38. liger_kernel/transformers/model/qwen3.py +11 -9
  39. liger_kernel/transformers/model/qwen3_moe.py +132 -0
  40. liger_kernel/transformers/monkey_patch.py +424 -81
  41. liger_kernel/transformers/multi_token_attention.py +64 -0
  42. liger_kernel/transformers/rms_norm.py +40 -4
  43. liger_kernel/transformers/softmax.py +12 -0
  44. liger_kernel/transformers/sparsemax.py +16 -0
  45. liger_kernel/transformers/swiglu.py +21 -0
  46. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  47. liger_kernel/utils.py +11 -0
  48. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
  49. liger_kernel-0.6.0.dist-info/RECORD +97 -0
  50. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  51. liger_kernel/transformers/gema3_rms.py +0 -8
  52. liger_kernel-0.5.9.dist-info/RECORD +0 -84
  53. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  54. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  55. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ import inspect
  import logging

  from functools import partial
+ from types import MethodType
  from typing import Callable

  import transformers
@@ -35,6 +36,13 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

+ try:
+     import peft
+
+     PEFT_AVAILABLE = True
+ except ImportError:
+     PEFT_AVAILABLE = False
+
  transformer_version = version.parse(transformers.__version__)

  logger = logging.getLogger(__name__)
@@ -47,23 +55,72 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
      module.__dict__[method_name] = new_method.__get__(module, module.__class__)


- def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
-     module.offset = offset
-     module.casting_mode = casting_mode
-     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-     module.in_place = in_place
-     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
-     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-     module.__class__.__name__ = LigerRMSNorm.__name__
+ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
+     # Check if the module is a PEFT ModulesToSaveWrapper
+     # If it is, we need to patch the modules_to_save.default and original_modules
+     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+         module.modules_to_save.default.offset = offset
+         module.modules_to_save.default.casting_mode = casting_mode
+         module.modules_to_save.default.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.modules_to_save.default.in_place = in_place
+         module.modules_to_save.default.row_mode = row_mode
+         module.original_module.offset = offset
+         module.original_module.casting_mode = casting_mode
+         module.original_module.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.in_place = in_place
+         module.original_module.row_mode = row_mode
+         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+         module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
+         module.original_module.__class__.__name__ = LigerRMSNorm.__name__
+     else:
+         module.offset = offset
+         module.casting_mode = casting_mode
+         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         module.in_place = in_place
+         module.row_mode = row_mode
+         _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+         module.__class__.__name__ = LigerRMSNorm.__name__


  def _patch_layer_norm_module(module, eps=1e-6):
-     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-     module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
-
-     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
-     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-     module.__class__.__name__ = LigerLayerNorm.__name__
+     # Check if the module is a PEFT ModulesToSaveWrapper
+     # If it is, we need to patch the modules_to_save.default and original_modules
+     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+         module.hidden_size = module.normalized_shape
+         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+         module.modules_to_save.default.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+             module, "normalized_shape", None
+         )
+         module.original_module.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+             module, "normalized_shape", None
+         )
+         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+         module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
+         module.original_module.__class__.__name__ = LigerLayerNorm.__name__
+     else:
+         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+         module.__class__.__name__ = LigerLayerNorm.__name__


  def _patch_swiglu_module(module, liger_module):
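Note on the PEFT-aware branch above: when a PEFT config lists a norm layer in `modules_to_save`, PEFT wraps it in a `ModulesToSaveWrapper` that keeps both an `original_module` and a trainable copy, and 0.6.0 patches both. A minimal, hedged sketch of the scenario this targets; the checkpoint path and config values below are illustrative placeholders, not part of the package:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

base = AutoModelForCausalLM.from_pretrained("path/to/llama")  # placeholder checkpoint path

# With modules_to_save=["norm"], the matching norm layer ends up inside a
# peft ModulesToSaveWrapper (original_module + modules_to_save["default"]).
peft_model = get_peft_model(
    base,
    LoraConfig(r=8, target_modules=["q_proj", "v_proj"], modules_to_save=["norm"]),
)

# Instance-level patching walks the existing modules, so both copies inside the
# wrapper receive LigerRMSNorm.forward, not just unwrapped layers.
apply_liger_kernel_to_llama(model=peft_model.get_base_model())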
@@ -204,10 +261,16 @@ def apply_liger_kernel_to_llama(

      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+             if model is not None:
+                 model.forward = MethodType(llama_lce_forward, model)
+             else:
+                 modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
          else:  # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(llama_lce_forward_deprecated, model)
+             else:
+                 modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
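The pattern above, repeated for most models in this release, swaps a class-level forward override for a `MethodType` binding when an instance is passed. A hedged usage sketch (the checkpoint path is a placeholder):

from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Class-level patching: call before the model is constructed so new instances
# pick up the Liger forward on LlamaForCausalLM itself.
apply_liger_kernel_to_llama(fused_linear_cross_entropy=True)

# Instance-level patching (new in 0.6.0): only this model's forward is rebound
# via MethodType; other LlamaForCausalLM instances in the process are untouched.
model = AutoModelForCausalLM.from_pretrained("path/to/llama")  # placeholder checkpoint path
apply_liger_kernel_to_llama(model=model)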
@@ -261,13 +324,20 @@ def apply_liger_kernel_to_llava(
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
              modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         if transformer_version >= version.parse("4.49.0"):
-             modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+         if transformer_version >= version.parse("4.52.0"):
+             if model is not None:
+                 model.forward = MethodType(llava_lce_forward, model)
+             else:
+                 modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+         elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+             if model is not None:
+                 model.forward = MethodType(llava_lce_forward_deprecated, model)
+             else:
+                 modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
          else:  # if version < 4.49.0
              logger.warning(
-                 "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+                 "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
              )
-             modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated

      if model is not None:
          text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
@@ -306,6 +376,92 @@ def apply_liger_kernel_to_llava(
              logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


+ def apply_liger_kernel_to_llama4(
+     rope: bool = False,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+     layer_norm: bool = True,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.llama4 import modeling_llama4
+     from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+     from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+     from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+     from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+     from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+     if rope:
+         raise NotImplementedError("liger_rotary_pos_emb is not available for Llama4 models.")
+     if rms_norm:
+         modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+     if swiglu:
+         modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+     if cross_entropy:
+         modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+     if fused_linear_cross_entropy:
+         modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+         if isinstance(model, Llama4ForConditionalGeneration):
+             language_model: Llama4ForCausalLM = model.language_model
+             vision_model: Llama4VisionModel = model.vision_model
+             text_model: Llama4TextModel = language_model.model
+         elif isinstance(model, Llama4ForCausalLM):
+             text_model = model.model
+             vision_model = None
+         elif isinstance(model, Llama4TextModel):
+             text_model = model
+             vision_model = None
+
+         else:
+             raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+         if text_model:
+             if rms_norm:
+                 _patch_rms_norm_module(text_model.norm)
+             for decoder_layer in text_model.layers:
+                 if swiglu:
+                     _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                 if rms_norm:
+                     _patch_rms_norm_module(decoder_layer.input_layernorm)
+                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+         if vision_model:
+             _patch_layer_norm_module(vision_model.layernorm_pre)
+             _patch_layer_norm_module(vision_model.layernorm_post)
+
+             for layer in vision_model.model.layers:
+                 if layer_norm:
+                     _patch_layer_norm_module(layer.input_layernorm)
+                     _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_mllama(
      rope: bool = True,
      cross_entropy: bool = False,
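A hedged usage sketch for the new `apply_liger_kernel_to_llama4` entry point added above; the checkpoint path is a placeholder, and the import goes through `monkey_patch` rather than assuming a package-level re-export:

from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4

# Patch the Llama4 classes before instantiation. Passing rope=True raises
# NotImplementedError, since liger_rotary_pos_emb is not available for Llama4.
apply_liger_kernel_to_llama4(rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)

model = AutoModelForCausalLM.from_pretrained("path/to/llama4")  # placeholder checkpoint path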
@@ -347,7 +503,7 @@ def apply_liger_kernel_to_mllama(

      if rope:
          modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-     if layer_norm:
+     if layer_norm and model is None:
          modeling_mllama.nn.LayerNorm = LigerLayerNorm
      if rms_norm:
          modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -363,10 +519,16 @@ def apply_liger_kernel_to_mllama(
          modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+             if model is not None:
+                 model.forward = MethodType(mllama_lce_forward, model)
+             else:
+                 modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
          else:  # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(mllama_lce_forward_deprecated, model)
+             else:
+                 modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -375,13 +537,17 @@ def apply_liger_kernel_to_mllama(
          if isinstance(model, MllamaForConditionalGeneration):
              language_model: MllamaForCausalLM = model.language_model
              vision_model: MllamaVisionModel = model.vision_model
-             text_model: MllamaTextModel = language_model.model
+             if isinstance(language_model, MllamaForCausalLM):
+                 text_model: MllamaTextModel = language_model.model
+             else:
+                 text_model = language_model
          elif isinstance(model, MllamaForCausalLM):
              text_model = model.model
              vision_model = None
          elif isinstance(model, MllamaTextModel):
              text_model = model
              vision_model = None
+
          else:
              raise ValueError(f"Unsupported Mllama model type: {type(model)}")

@@ -448,7 +614,17 @@ def apply_liger_kernel_to_mistral(
      if cross_entropy:
          modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+         if transformer_version >= version.parse("4.49.0"):
+             if model is not None:
+                 model.forward = MethodType(mistral_lce_forward, model)
+             else:
+                 modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+         else:
+             logger.warning(
+                 "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+             )
+             logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
      if swiglu:
          modeling_mistral.MistralMLP = LigerSwiGLUMLP

@@ -516,10 +692,16 @@ def apply_liger_kernel_to_mixtral(

      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+             if model is not None:
+                 model.forward = MethodType(mixtral_lce_forward, model)
+             else:
+                 modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
          else:  # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+             else:
+                 modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
      if swiglu:
          modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -573,8 +755,8 @@ def apply_liger_kernel_to_gemma(
      from transformers.models.gemma import modeling_gemma
      from transformers.models.gemma.modeling_gemma import GemmaModel

-     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-     LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
      _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

      if rope:
@@ -593,10 +775,16 @@ def apply_liger_kernel_to_gemma(
          modeling_gemma.GemmaMLP = LigerGEGLUMLP
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+             if model is not None:
+                 model.forward = MethodType(gemma_lce_forward, model)
+             else:
+                 modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
          else:  # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(gemma_lce_forward_deprecated, model)
+             else:
+                 modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -647,7 +835,8 @@ def apply_liger_kernel_to_gemma2(
      from transformers.models.gemma2 import modeling_gemma2
      from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

-     LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
      _patch_rms_norm_module_for_gemma2 = partial(
          _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
      )
@@ -667,10 +856,16 @@ def apply_liger_kernel_to_gemma2(
          modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+             if model is not None:
+                 model.forward = MethodType(gemma2_lce_forward, model)
+             else:
+                 modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
          else:
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+             if model is not None:
+                 model.forward = MethodType(gemma2_lce_forward_deprected, model)
+             else:
+                 modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
      if geglu:
          modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -724,9 +919,10 @@ def apply_liger_kernel_to_gemma3_text(
      from transformers.models.gemma3 import modeling_gemma3
      from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
      from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+     from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel

-     from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
      from liger_kernel.transformers.model.gemma3 import causal_forward
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3

      _patch_rms_norm_module_for_gemma3 = partial(
          _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
@@ -748,15 +944,18 @@ def apply_liger_kernel_to_gemma3_text(
          nn.functional.cross_entropy = liger_cross_entropy

      if fused_linear_cross_entropy:
-         modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+         if model is not None:
+             model.forward = MethodType(causal_forward, model)
+         else:
+             modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         if isinstance(model, Gemma3ForCausalLM):
+         if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
              # get the base model from the model instance
-             base_model = model.model
+             base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model

              if rms_norm:
                  _patch_rms_norm_module_for_gemma3(base_model.norm)
@@ -818,7 +1017,7 @@ def apply_liger_kernel_to_gemma3(
          _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
      )

-     if layer_norm:
+     if layer_norm and model is None:
          modeling_siglip.nn.LayerNorm = LigerLayerNorm

      apply_liger_kernel_to_gemma3_text(
@@ -829,7 +1028,10 @@ def apply_liger_kernel_to_gemma3(
          modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss

      if fused_linear_cross_entropy:
-         modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+         if model is not None:
+             model.forward = MethodType(multimodal_forward, model)
+         else:
+             modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -897,7 +1099,9 @@ def apply_liger_kernel_to_paligemma(
      # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']

      from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+     from transformers.models.gemma.modeling_gemma import GemmaModel
      from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
      from transformers.models.paligemma import modeling_paligemma
      from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
      from transformers.models.siglip import modeling_siglip
@@ -908,7 +1112,7 @@ def apply_liger_kernel_to_paligemma(
      from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

      # The vision_tower is a SiglipVisionModel
-     if layer_norm:
+     if layer_norm and model is None:
          modeling_siglip.nn.LayerNorm = LigerLayerNorm

      # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -926,10 +1130,16 @@ def apply_liger_kernel_to_paligemma(
          modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+             if model is not None:
+                 model.forward = MethodType(lce_forward, model)
+             else:
+                 modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
          else:  # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(lce_forward_deprecated, model)
+             else:
+                 modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -950,7 +1160,7 @@ def apply_liger_kernel_to_paligemma(

          language_model = model.language_model

-         if isinstance(language_model, GemmaForCausalLM):
+         if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
              apply_liger_kernel_to_gemma(
                  rope=rope,
                  cross_entropy=False,
@@ -960,7 +1170,7 @@ def apply_liger_kernel_to_paligemma(
                  model=language_model,
              )

-         elif isinstance(language_model, Gemma2ForCausalLM):
+         elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
              apply_liger_kernel_to_gemma2(
                  rope=rope,
                  cross_entropy=False,
@@ -1021,10 +1231,16 @@ def apply_liger_kernel_to_qwen2(

      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+             if model is not None:
+                 model.forward = MethodType(qwen2_lce_forward, model)
+             else:
+                 modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
          else:  # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+             else:
+                 modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

      if swiglu:
          modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1080,7 +1296,10 @@ def apply_liger_kernel_to_qwen3(
          nn.functional.cross_entropy = liger_cross_entropy

      if fused_linear_cross_entropy:
-         modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+         if model is not None:
+             model.forward = MethodType(qwen3_lce_forward, model)
+         else:
+             modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward

      if swiglu:
          modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1102,6 +1321,65 @@ def apply_liger_kernel_to_qwen3(
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_qwen3_moe(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.qwen3_moe import modeling_qwen3_moe
+     from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
+
+     from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
+     from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+     if rope:
+         modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+     if rms_norm:
+         modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
+
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+
+     if fused_linear_cross_entropy:
+         if model is not None:
+             model.forward = MethodType(qwen3_lce_forward, model)
+         else:
+             modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+
+     if swiglu:
+         modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         # get the base model from the model instance
+         base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 for mlp_expert in decoder_layer.mlp.experts:
+                     _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_qwen2_vl(
      rope: bool = True,
      cross_entropy: bool = False,
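A hedged sketch of the new `apply_liger_kernel_to_qwen3_moe` function defined above, applied to an already-loaded model; the checkpoint path is a placeholder:

from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe

model = AutoModelForCausalLM.from_pretrained("path/to/qwen3-moe")  # placeholder checkpoint path

# With a model instance, RMSNorm layers are patched in place and every expert in
# decoder_layer.mlp.experts is swapped to LigerQwen3MoeSwiGLUMLP.
apply_liger_kernel_to_qwen3_moe(model=model, rms_norm=True, swiglu=True)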
@@ -1113,7 +1391,7 @@ def apply_liger_kernel_to_qwen2_vl(
  ) -> None:
      """
      Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-     NOTE: Qwen2-VL is not available in transformers<4.45.0
+     NOTE: Qwen2-VL is not supported in transformers<4.52.4

      Args:
          cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -1127,12 +1405,19 @@ def apply_liger_kernel_to_qwen2_vl(
          model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
              loaded. Default is None.
      """
+     if transformer_version < version.parse("4.52.4"):
+         logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
+         return
+
      assert not (cross_entropy and fused_linear_cross_entropy), (
          "cross_entropy and fused_linear_cross_entropy cannot both be True."
      )

      from transformers.models.qwen2_vl import modeling_qwen2_vl
+     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
      from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel

      from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward

@@ -1141,12 +1426,15 @@ def apply_liger_kernel_to_qwen2_vl(
      if rms_norm:
          # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
          modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
-     if layer_norm:
+     if layer_norm and model is None:
          modeling_qwen2_vl.LayerNorm = LigerLayerNorm
      if cross_entropy:
          modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+         if model is not None:
+             model.forward = MethodType(qwen2_vl_lce_forward, model)
+         else:
+             modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
      if swiglu:
          modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

@@ -1154,24 +1442,38 @@ def apply_liger_kernel_to_qwen2_vl(
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         # get the base model from the model instance
-         base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)
+         if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
+             # Note: language_model and visual properties can be accessed throught conditional class for BC.
+             # Not sure if it is subject to changes in the future.
+             # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+             text_model: Qwen2VLTextModel = model.language_model
+             vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
+         elif isinstance(model, Qwen2VLTextModel):
+             text_model: Qwen2VLTextModel = model
+             vision_model = None
+         else:
+             # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+             raise TypeError(
+                 f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+             )

-         if hasattr(model, "visual"):
-             # Patch Qwen2VisionTransformerPretrainedModel
-             for vision_block in model.visual.blocks:
+         # Patch Qwen2VisionTransformerPretrainedModel
+         if vision_model is not None:
+             for vision_block in vision_model.blocks:
                  if layer_norm:
                      _patch_layer_norm_module(vision_block.norm1)
                      _patch_layer_norm_module(vision_block.norm2)

-         if rms_norm:
-             _patch_rms_norm_module(base_model.norm)
-         for decoder_layer in base_model.layers:
-             if swiglu:
-                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+         # Patch Qwen2VisionTextModel
+         if text_model is not None:
              if rms_norm:
-                 _patch_rms_norm_module(decoder_layer.input_layernorm)
-                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                 _patch_rms_norm_module(text_model.norm)
+             for decoder_layer in text_model.layers:
+                 if swiglu:
+                     _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                 if rms_norm:
+                     _patch_rms_norm_module(decoder_layer.input_layernorm)
+                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


  def apply_liger_kernel_to_qwen2_5_vl(
@@ -1197,12 +1499,19 @@ def apply_liger_kernel_to_qwen2_5_vl(
          model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
              loaded. Default is None.
      """
+     if transformer_version < version.parse("4.52.4"):
+         logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+         return
+
      assert not (cross_entropy and fused_linear_cross_entropy), (
          "cross_entropy and fused_linear_cross_entropy cannot both be True."
      )

      from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
      from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel

      from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward

@@ -1213,7 +1522,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
      if cross_entropy:
          modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+         if model is not None:
+             model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+         else:
+             modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
      if swiglu:
          modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP

@@ -1221,24 +1533,37 @@ def apply_liger_kernel_to_qwen2_5_vl(
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         # get the base model from the model instance
-         base_model: Qwen2_5_VLModel = getattr(model, model.base_model_prefix, model)
+         if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
+             # Note: language_model and visual properties can be accessed throught conditional class for BC.
+             # Not sure if it is subject to changes in the future.
+             # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+             text_model: Qwen2_5_VLTextModel = model.language_model
+             vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+         elif isinstance(model, Qwen2_5_VLTextModel):
+             text_model: Qwen2_5_VLTextModel = model
+             vision_model = None
+         else:
+             # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+             raise TypeError(
+                 f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+             )

-         if hasattr(model, "visual"):
+         if vision_model is not None:
              # Patch Qwen2_5_VisionTransformerPretrainedModel
              for vision_block in model.visual.blocks:
                  if rms_norm:
                      _patch_rms_norm_module(vision_block.norm1)
                      _patch_rms_norm_module(vision_block.norm2)

-         if rms_norm:
-             _patch_rms_norm_module(base_model.norm)
-         for decoder_layer in base_model.layers:
-             if swiglu:
-                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+         if text_model is not None:
              if rms_norm:
-                 _patch_rms_norm_module(decoder_layer.input_layernorm)
-                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                 _patch_rms_norm_module(text_model.norm)
+             for decoder_layer in text_model.layers:
+                 if swiglu:
+                     _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                 if rms_norm:
+                     _patch_rms_norm_module(decoder_layer.input_layernorm)
+                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


  def apply_liger_kernel_to_phi3(
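A hedged sketch of the reworked Qwen2-VL / Qwen2.5-VL instance patching above, which now requires transformers >= 4.52.4 and resolves the text and vision sub-models via model.language_model and model.visual; the checkpoint path is a placeholder:

from transformers import Qwen2_5_VLForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl

model = Qwen2_5_VLForConditionalGeneration.from_pretrained("path/to/qwen2.5-vl")  # placeholder path

# On transformers < 4.52.4 this logs a warning and returns without patching anything;
# otherwise the text decoder (RMSNorm, SwiGLU) and the vision blocks are patched in place.
apply_liger_kernel_to_qwen2_5_vl(model=model)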
@@ -1287,10 +1612,16 @@ def apply_liger_kernel_to_phi3(
          modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+             if model is not None:
+                 model.forward = MethodType(phi3_lce_forward, model)
+             else:
+                 modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
          else:  # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(phi3_lce_forward_deprecated, model)
+             else:
+                 modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -1341,11 +1672,12 @@ def apply_liger_kernel_to_olmo2(
      from transformers.models.olmo2.modeling_olmo2 import Olmo2Model

      from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2

      if rope:
          modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
      if rms_norm:
-         modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
+         modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
      if swiglu:
          modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
      if cross_entropy:
@@ -1353,7 +1685,10 @@ def apply_liger_kernel_to_olmo2(

          nn.functional.cross_entropy = liger_cross_entropy
      if fused_linear_cross_entropy:
-         modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+         if model is not None:
+             model.forward = MethodType(olmo2_lce_forward, model)
+         else:
+             modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -1404,11 +1739,12 @@ def apply_liger_kernel_to_glm4(
      from transformers.models.glm4.modeling_glm4 import Glm4Model

      from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4

      if rope:
          raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
      if rms_norm:
-         modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
+         modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
      if swiglu:
          modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
      if cross_entropy:
@@ -1416,7 +1752,10 @@ def apply_liger_kernel_to_glm4(

          nn.functional.cross_entropy = liger_cross_entropy
      if fused_linear_cross_entropy:
-         modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+         if model is not None:
+             model.forward = MethodType(glm4_lce_forward, model)
+         else:
+             modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -1446,6 +1785,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "gemma3": apply_liger_kernel_to_gemma3,
      "glm4": apply_liger_kernel_to_glm4,
      "llama": apply_liger_kernel_to_llama,
+     "llama4_text": apply_liger_kernel_to_llama4,
+     "llama4": apply_liger_kernel_to_llama4,
      "llava": apply_liger_kernel_to_llava,
      "granite": apply_liger_kernel_to_granite,
      "mllama": apply_liger_kernel_to_mllama,
@@ -1455,8 +1796,11 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "olmo2": apply_liger_kernel_to_olmo2,
      "qwen2": apply_liger_kernel_to_qwen2,
      "qwen3": apply_liger_kernel_to_qwen3,
+     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
      "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+     "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
      "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+     "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
      "phi3": apply_liger_kernel_to_phi3,
      "paligemma": apply_liger_kernel_to_paligemma,
  }
@@ -1516,7 +1860,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
          return

      apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
      apply_fn_signature = inspect.signature(apply_fn)

      # Filter out the keyword arguments that are not supported by the apply function
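A hedged sketch of the dispatch path that uses the expanded MODEL_TYPE_TO_APPLY_LIGER_FN map above; the checkpoint path is a placeholder:

from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance

model = AutoModelForCausalLM.from_pretrained("path/to/any-supported-model")  # placeholder path

# Looks up model.config.model_type (e.g. "llama4", "qwen3_moe", "qwen2_vl_text") in
# MODEL_TYPE_TO_APPLY_LIGER_FN and forwards only the kwargs the matching apply_* accepts.
_apply_liger_kernel_to_instance(model, rms_norm=True, fused_linear_cross_entropy=True)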