liger-kernel 0.5.10__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/ops/dyt.py +0 -2
  5. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  6. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  7. liger_kernel/ops/geglu.py +1 -1
  8. liger_kernel/ops/layer_norm.py +126 -89
  9. liger_kernel/ops/multi_token_attention.py +207 -0
  10. liger_kernel/ops/rms_norm.py +267 -56
  11. liger_kernel/ops/rope.py +1 -1
  12. liger_kernel/ops/softmax.py +201 -0
  13. liger_kernel/ops/sparsemax.py +62 -50
  14. liger_kernel/ops/swiglu.py +1 -1
  15. liger_kernel/transformers/__init__.py +8 -0
  16. liger_kernel/transformers/functional.py +67 -0
  17. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  18. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  19. liger_kernel/transformers/model/gemma.py +25 -8
  20. liger_kernel/transformers/model/gemma2.py +27 -8
  21. liger_kernel/transformers/model/gemma3.py +63 -99
  22. liger_kernel/transformers/model/glm4.py +16 -7
  23. liger_kernel/transformers/model/llama.py +25 -7
  24. liger_kernel/transformers/model/llama4.py +108 -0
  25. liger_kernel/transformers/model/llava.py +95 -124
  26. liger_kernel/transformers/model/mistral.py +13 -8
  27. liger_kernel/transformers/model/mixtral.py +16 -7
  28. liger_kernel/transformers/model/mllama.py +16 -7
  29. liger_kernel/transformers/model/olmo2.py +16 -7
  30. liger_kernel/transformers/model/paligemma.py +8 -1
  31. liger_kernel/transformers/model/phi3.py +25 -8
  32. liger_kernel/transformers/model/qwen2.py +24 -7
  33. liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
  34. liger_kernel/transformers/model/qwen2_vl.py +38 -100
  35. liger_kernel/transformers/model/qwen3.py +11 -3
  36. liger_kernel/transformers/model/qwen3_moe.py +10 -6
  37. liger_kernel/transformers/model/smollm3.py +189 -0
  38. liger_kernel/transformers/monkey_patch.py +389 -82
  39. liger_kernel/transformers/multi_token_attention.py +64 -0
  40. liger_kernel/transformers/rms_norm.py +40 -4
  41. liger_kernel/transformers/softmax.py +12 -0
  42. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/METADATA +18 -14
  43. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/RECORD +47 -37
  44. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/WHEEL +1 -1
  45. liger_kernel/transformers/gema3_rms.py +0 -8
  46. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/LICENSE +0 -0
  47. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/NOTICE +0 -0
  48. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ import inspect
  import logging
 
  from functools import partial
+ from types import MethodType
  from typing import Callable
 
  import transformers
@@ -28,6 +29,7 @@ from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
  from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
  from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
  from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -54,7 +56,7 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
  module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
- def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
+ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
  # Check if the module is a PEFT ModulesToSaveWrapper
  # If it is, we need to patch the modules_to_save.default and original_modules
  if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
@@ -64,26 +66,29 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
  getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
  )
  module.modules_to_save.default.in_place = in_place
+ module.modules_to_save.default.row_mode = row_mode
  module.original_module.offset = offset
  module.original_module.casting_mode = casting_mode
  module.original_module.variance_epsilon = (
  getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
  )
  module.original_module.in_place = in_place
+ module.original_module.row_mode = row_mode
  _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
  _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
  _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
  _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
- module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
- module.original_module.__class__.__name__ = LigerRMSNorm.__name__
+ _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+ _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
  else:
  module.offset = offset
  module.casting_mode = casting_mode
  module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
  module.in_place = in_place
+ module.row_mode = row_mode
  _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
  _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
- module.__class__.__name__ = LigerRMSNorm.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)
 
 
  def _patch_layer_norm_module(module, eps=1e-6):
@@ -105,28 +110,28 @@ def _patch_layer_norm_module(module, eps=1e-6):
  module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
  module, "normalized_shape", None
  )
- _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
- _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
- _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
- _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
- module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
- module.original_module.__class__.__name__ = LigerLayerNorm.__name__
+ _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+ _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+ _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+ _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+ _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+ _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
  else:
  module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
  module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
  _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
  _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
- module.__class__.__name__ = LigerLayerNorm.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
 
 
  def _patch_swiglu_module(module, liger_module):
  _bind_method_to_module(module, "forward", liger_module.forward)
- module.__class__.__name__ = liger_module.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)
 
 
  def _patch_geglu_module(module):
  _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
- module.__class__.__name__ = LigerGEGLUMLP.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
 
 
  def apply_liger_kernel_to_granite(
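Note on the `_get_name` changes above: `torch.nn.Module.__repr__` prints the value returned by `self._get_name()`, which defaults to `self.__class__.__name__`. The old approach of overwriting `__class__.__name__` renamed the class object itself and therefore relabeled every instance of that class, patched or not; binding `_get_name` per instance keeps the relabeling local. A minimal illustration of the mechanism, not code from this diff:

```python
# Minimal illustration (not from the diff) of per-instance _get_name rebinding.
# nn.Module.__repr__ calls self._get_name(), so overriding it on one instance
# relabels only that instance, unlike mutating __class__.__name__.
import torch.nn as nn

norm_a = nn.LayerNorm(8)
norm_b = nn.LayerNorm(8)

# Same binding trick as _bind_method_to_module: create a bound method for norm_a only.
norm_a.__dict__["_get_name"] = (lambda self: "LigerLayerNorm").__get__(norm_a, nn.LayerNorm)

print(norm_a)  # prints as LigerLayerNorm(...)
print(norm_b)  # still prints as LayerNorm(...)
```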
@@ -257,10 +262,16 @@ def apply_liger_kernel_to_llama(
 
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+ if model is not None:
+ model.forward = MethodType(llama_lce_forward, model)
+ else:
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(llama_lce_forward_deprecated, model)
+ else:
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
 
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
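The `MethodType` branches above (repeated for the other model families below) make the `model is not None` path instance-local: the Liger forward is bound to the already-constructed model object instead of being installed on the class for every present and future instance. A small self-contained sketch of the difference; the class and function names here are placeholders, not liger-kernel APIs:

```python
# Sketch (placeholder names): instance-level vs class-level forward patching.
from types import MethodType

class TinyModel:
    def forward(self, x):
        return f"original({x})"

def liger_forward(self, x):
    # Stands in for a *_lce_forward replacement.
    return f"liger({x})"

m1, m2 = TinyModel(), TinyModel()

# Instance-level (what the patch does when `model is not None`):
m1.forward = MethodType(liger_forward, m1)
print(m1.forward(1), m2.forward(1))  # liger(1) original(1)

# Class-level (the fallback when no instance is passed): affects every instance.
TinyModel.forward = liger_forward
print(m2.forward(1))  # liger(1)
```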
@@ -280,6 +291,77 @@ def apply_liger_kernel_to_llama(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+ def apply_liger_kernel_to_smollm3(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.smollm3 import modeling_smollm3
+ from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+ if rope:
+ modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if rms_norm:
+ modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+ if swiglu:
+ modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+ if cross_entropy:
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ else:
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+ modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(smollm3_lce_forward, model)
+ else:
+ modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+ # get the base model from the model instance
+ base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_llava(
  cross_entropy: bool = False,
  fused_linear_cross_entropy: bool = True,
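A usage sketch for the new SmolLM3 entry point (not part of the diff; the checkpoint name is a placeholder):

```python
# Usage sketch: patch an already-loaded SmolLM3 model in place.
# The checkpoint name below is a placeholder, not taken from this diff.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")

# RMSNorm/SwiGLU submodules are rewritten in place and forward is rebound to
# smollm3_lce_forward via MethodType, as shown in the hunk above.
apply_liger_kernel_to_smollm3(model=model)

# Calling apply_liger_kernel_to_smollm3() with no model instead patches the
# modeling_smollm3 classes, affecting models loaded afterwards.
```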
@@ -314,13 +396,20 @@ def apply_liger_kernel_to_llava(
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
  modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
- if transformer_version >= version.parse("4.49.0"):
- modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+ if transformer_version >= version.parse("4.52.0"):
+ if model is not None:
+ model.forward = MethodType(llava_lce_forward, model)
+ else:
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+ elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+ if model is not None:
+ model.forward = MethodType(llava_lce_forward_deprecated, model)
+ else:
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
  else: # if version < 4.49.0
  logger.warning(
- "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+ "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
  )
- modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
 
  if model is not None:
  text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
@@ -359,6 +448,92 @@ def apply_liger_kernel_to_llava(
  logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
 
 
+ def apply_liger_kernel_to_llama4(
+ rope: bool = False,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ layer_norm: bool = True,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.llama4 import modeling_llama4
+ from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+ from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+ from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+ from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+ from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+ if rope:
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Llama4 models.")
+ if rms_norm:
+ modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+ if swiglu:
+ modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+ if cross_entropy:
+ modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+ if fused_linear_cross_entropy:
+ modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, Llama4ForConditionalGeneration):
+ language_model: Llama4ForCausalLM = model.language_model
+ vision_model: Llama4VisionModel = model.vision_model
+ text_model: Llama4TextModel = language_model.model
+ elif isinstance(model, Llama4ForCausalLM):
+ text_model = model.model
+ vision_model = None
+ elif isinstance(model, Llama4TextModel):
+ text_model = model
+ vision_model = None
+
+ else:
+ raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+ if text_model:
+ if rms_norm:
+ _patch_rms_norm_module(text_model.norm)
+ for decoder_layer in text_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+ if vision_model:
+ _patch_layer_norm_module(vision_model.layernorm_pre)
+ _patch_layer_norm_module(vision_model.layernorm_post)
+
+ for layer in vision_model.model.layers:
+ if layer_norm:
+ _patch_layer_norm_module(layer.input_layernorm)
+ _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_mllama(
  rope: bool = True,
  cross_entropy: bool = False,
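A usage sketch for the new Llama4 entry point (not part of the diff). Note that `rope` defaults to False here and enabling it raises `NotImplementedError`, since `liger_rotary_pos_emb` is not available for Llama4:

```python
# Usage sketch: class-level patching for Llama4 before the model is instantiated.
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4

apply_liger_kernel_to_llama4(
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,
)

# apply_liger_kernel_to_llama4(rope=True)  # would raise NotImplementedError
```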
@@ -400,7 +575,7 @@ def apply_liger_kernel_to_mllama(
 
  if rope:
  modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
- if layer_norm:
+ if layer_norm and model is None:
  modeling_mllama.nn.LayerNorm = LigerLayerNorm
  if rms_norm:
  modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -416,10 +591,16 @@ def apply_liger_kernel_to_mllama(
  modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+ if model is not None:
+ model.forward = MethodType(mllama_lce_forward, model)
+ else:
+ modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(mllama_lce_forward_deprecated, model)
+ else:
+ modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
 
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -428,13 +609,17 @@ def apply_liger_kernel_to_mllama(
  if isinstance(model, MllamaForConditionalGeneration):
  language_model: MllamaForCausalLM = model.language_model
  vision_model: MllamaVisionModel = model.vision_model
- text_model: MllamaTextModel = language_model.model
+ if isinstance(language_model, MllamaForCausalLM):
+ text_model: MllamaTextModel = language_model.model
+ else:
+ text_model = language_model
  elif isinstance(model, MllamaForCausalLM):
  text_model = model.model
  vision_model = None
  elif isinstance(model, MllamaTextModel):
  text_model = model
  vision_model = None
+
  else:
  raise ValueError(f"Unsupported Mllama model type: {type(model)}")
 
@@ -501,7 +686,17 @@ def apply_liger_kernel_to_mistral(
  if cross_entropy:
  modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
- modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+ if transformer_version >= version.parse("4.49.0"):
+ if model is not None:
+ model.forward = MethodType(mistral_lce_forward, model)
+ else:
+ modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+ else:
+ logger.warning(
+ "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+ )
+ logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
  if swiglu:
  modeling_mistral.MistralMLP = LigerSwiGLUMLP
 
@@ -569,10 +764,16 @@ def apply_liger_kernel_to_mixtral(
 
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+ if model is not None:
+ model.forward = MethodType(mixtral_lce_forward, model)
+ else:
+ modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+ else:
+ modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
  if swiglu:
  modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
@@ -626,8 +827,8 @@ def apply_liger_kernel_to_gemma(
  from transformers.models.gemma import modeling_gemma
  from transformers.models.gemma.modeling_gemma import GemmaModel
 
- # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
- LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
  _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
  if rope:
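The Gemma hunk above (and the Gemma2, Gemma3, Olmo2, and Glm4 hunks elsewhere in this file) stops building `functools.partial` wrappers inline and instead imports preconfigured classes such as `LigerRMSNormForGemma` from `liger_kernel.transformers.rms_norm`, which this release grows by +40 -4. A plausible shape for such a class is sketched below; the class name `LigerRMSNormForGemmaSketch` and the constructor signature are assumptions for illustration, not the actual contents of rms_norm.py. One practical difference is that a subclass supports `isinstance` checks and keeps a usable `__name__`, which a `partial` object does not:

```python
# Sketch (assumptions, not the real implementation in rms_norm.py): a thin subclass
# carrying the same defaults the old inline partial supplied for Gemma.
from liger_kernel.transformers.rms_norm import LigerRMSNorm

class LigerRMSNormForGemmaSketch(LigerRMSNorm):  # hypothetical name
    def __init__(self, hidden_size, eps=1e-6):
        # offset=1.0, zeros init, and Gemma casting mode, matching the removed partial.
        super().__init__(hidden_size, eps=eps, offset=1.0, casting_mode="gemma", init_fn="zeros")
```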
@@ -646,10 +847,16 @@ def apply_liger_kernel_to_gemma(
  modeling_gemma.GemmaMLP = LigerGEGLUMLP
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+ if model is not None:
+ model.forward = MethodType(gemma_lce_forward, model)
+ else:
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(gemma_lce_forward_deprecated, model)
+ else:
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
 
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -700,7 +907,8 @@ def apply_liger_kernel_to_gemma2(
  from transformers.models.gemma2 import modeling_gemma2
  from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
- LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
  _patch_rms_norm_module_for_gemma2 = partial(
  _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
  )
@@ -720,10 +928,16 @@ def apply_liger_kernel_to_gemma2(
  modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+ if model is not None:
+ model.forward = MethodType(gemma2_lce_forward, model)
+ else:
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
  else:
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+ if model is not None:
+ model.forward = MethodType(gemma2_lce_forward_deprected, model)
+ else:
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
  if geglu:
  modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
 
@@ -777,9 +991,10 @@ def apply_liger_kernel_to_gemma3_text(
  from transformers.models.gemma3 import modeling_gemma3
  from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
  from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
 
- from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
  from liger_kernel.transformers.model.gemma3 import causal_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
 
  _patch_rms_norm_module_for_gemma3 = partial(
  _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
@@ -801,15 +1016,18 @@ def apply_liger_kernel_to_gemma3_text(
  nn.functional.cross_entropy = liger_cross_entropy
 
  if fused_linear_cross_entropy:
- modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+ if model is not None:
+ model.forward = MethodType(causal_forward, model)
+ else:
+ modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
 
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
 
- if isinstance(model, Gemma3ForCausalLM):
+ if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
  # get the base model from the model instance
- base_model = model.model
+ base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
 
  if rms_norm:
  _patch_rms_norm_module_for_gemma3(base_model.norm)
@@ -871,7 +1089,7 @@ def apply_liger_kernel_to_gemma3(
  _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
  )
 
- if layer_norm:
+ if layer_norm and model is None:
  modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
  apply_liger_kernel_to_gemma3_text(
@@ -882,7 +1100,10 @@ def apply_liger_kernel_to_gemma3(
  modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
 
  if fused_linear_cross_entropy:
- modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+ if model is not None:
+ model.forward = MethodType(multimodal_forward, model)
+ else:
+ modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
 
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -950,7 +1171,9 @@ def apply_liger_kernel_to_paligemma(
  # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
 
  from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+ from transformers.models.gemma.modeling_gemma import GemmaModel
  from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
  from transformers.models.paligemma import modeling_paligemma
  from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
  from transformers.models.siglip import modeling_siglip
@@ -961,7 +1184,7 @@ def apply_liger_kernel_to_paligemma(
  from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
 
  # The vision_tower is a SiglipVisionModel
- if layer_norm:
+ if layer_norm and model is None:
  modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
  # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -979,10 +1202,16 @@ def apply_liger_kernel_to_paligemma(
  modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+ if model is not None:
+ model.forward = MethodType(lce_forward, model)
+ else:
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(lce_forward_deprecated, model)
+ else:
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
 
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -1003,7 +1232,7 @@ def apply_liger_kernel_to_paligemma(
 
  language_model = model.language_model
 
- if isinstance(language_model, GemmaForCausalLM):
+ if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
  apply_liger_kernel_to_gemma(
  rope=rope,
  cross_entropy=False,
@@ -1013,7 +1242,7 @@ def apply_liger_kernel_to_paligemma(
  model=language_model,
  )
 
- elif isinstance(language_model, Gemma2ForCausalLM):
+ elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
  apply_liger_kernel_to_gemma2(
  rope=rope,
  cross_entropy=False,
@@ -1074,10 +1303,16 @@ def apply_liger_kernel_to_qwen2(
 
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen2_lce_forward, model)
+ else:
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+ else:
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
 
  if swiglu:
  modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1133,7 +1368,10 @@ def apply_liger_kernel_to_qwen3(
  nn.functional.cross_entropy = liger_cross_entropy
 
  if fused_linear_cross_entropy:
- modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen3_lce_forward, model)
+ else:
+ modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
 
  if swiglu:
  modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1188,7 +1426,10 @@ def apply_liger_kernel_to_qwen3_moe(
  nn.functional.cross_entropy = liger_cross_entropy
 
  if fused_linear_cross_entropy:
- modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen3_lce_forward, model)
+ else:
+ modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
 
  if swiglu:
  modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
@@ -1204,7 +1445,8 @@ def apply_liger_kernel_to_qwen3_moe(
  _patch_rms_norm_module(base_model.norm)
  for decoder_layer in base_model.layers:
  if swiglu:
- _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+ for mlp_expert in decoder_layer.mlp.experts:
+ _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -1221,7 +1463,7 @@ def apply_liger_kernel_to_qwen2_vl(
  ) -> None:
  """
  Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
- NOTE: Qwen2-VL is not available in transformers<4.45.0
+ NOTE: Qwen2-VL is not supported in transformers<4.52.4
 
  Args:
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -1235,12 +1477,19 @@ def apply_liger_kernel_to_qwen2_vl(
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
  loaded. Default is None.
  """
+ if transformer_version < version.parse("4.52.4"):
+ logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
+ return
+
  assert not (cross_entropy and fused_linear_cross_entropy), (
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
  )
 
  from transformers.models.qwen2_vl import modeling_qwen2_vl
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
  from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
 
  from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
 
@@ -1249,12 +1498,15 @@ def apply_liger_kernel_to_qwen2_vl(
  if rms_norm:
  # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
  modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
- if layer_norm:
+ if layer_norm and model is None:
  modeling_qwen2_vl.LayerNorm = LigerLayerNorm
  if cross_entropy:
  modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
- modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen2_vl_lce_forward, model)
+ else:
+ modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
  if swiglu:
  modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
 
@@ -1262,24 +1514,38 @@ def apply_liger_kernel_to_qwen2_vl(
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
 
- # get the base model from the model instance
- base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)
+ if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
+ # Note: language_model and visual properties can be accessed throught conditional class for BC.
+ # Not sure if it is subject to changes in the future.
+ # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+ text_model: Qwen2VLTextModel = model.language_model
+ vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
+ elif isinstance(model, Qwen2VLTextModel):
+ text_model: Qwen2VLTextModel = model
+ vision_model = None
+ else:
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+ raise TypeError(
+ f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+ )
 
- if hasattr(model, "visual"):
- # Patch Qwen2VisionTransformerPretrainedModel
- for vision_block in model.visual.blocks:
+ # Patch Qwen2VisionTransformerPretrainedModel
+ if vision_model is not None:
+ for vision_block in vision_model.blocks:
  if layer_norm:
  _patch_layer_norm_module(vision_block.norm1)
  _patch_layer_norm_module(vision_block.norm2)
 
- if rms_norm:
- _patch_rms_norm_module(base_model.norm)
- for decoder_layer in base_model.layers:
- if swiglu:
- _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ # Patch Qwen2VisionTextModel
+ if text_model is not None:
  if rms_norm:
- _patch_rms_norm_module(decoder_layer.input_layernorm)
- _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+ _patch_rms_norm_module(text_model.norm)
+ for decoder_layer in text_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
  def apply_liger_kernel_to_qwen2_5_vl(
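The reworked Qwen2-VL path above assumes the post-4.52 model layout (`model.language_model` and `model.visual` reachable from the conditional-generation class) and now bails out early on older transformers. A small sketch of how a caller might mirror that guard (not part of the diff):

```python
# Sketch: mirror the version guard that apply_liger_kernel_to_qwen2_vl now applies
# internally. On transformers < 4.52.4 the call only logs a warning and returns.
import transformers
from packaging import version

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl

if version.parse(transformers.__version__) >= version.parse("4.52.4"):
    apply_liger_kernel_to_qwen2_vl()  # class-level patching before loading the model
else:
    print("transformers too old for the Liger Qwen2-VL patch; upgrade to >= 4.52.4")
```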
@@ -1305,12 +1571,19 @@ def apply_liger_kernel_to_qwen2_5_vl(
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
  loaded. Default is None.
  """
+ if transformer_version < version.parse("4.52.4"):
+ logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+ return
+
  assert not (cross_entropy and fused_linear_cross_entropy), (
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
  )
 
  from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
  from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
 
  from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
 
@@ -1321,7 +1594,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
  if cross_entropy:
  modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
- modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+ else:
+ modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
  if swiglu:
  modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
 
@@ -1329,24 +1605,37 @@ def apply_liger_kernel_to_qwen2_5_vl(
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
 
- # get the base model from the model instance
- base_model: Qwen2_5_VLModel = getattr(model, model.base_model_prefix, model)
+ if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
+ # Note: language_model and visual properties can be accessed throught conditional class for BC.
+ # Not sure if it is subject to changes in the future.
+ # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+ text_model: Qwen2_5_VLTextModel = model.language_model
+ vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+ elif isinstance(model, Qwen2_5_VLTextModel):
+ text_model: Qwen2_5_VLTextModel = model
+ vision_model = None
+ else:
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+ raise TypeError(
+ f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+ )
 
- if hasattr(model, "visual"):
+ if vision_model is not None:
  # Patch Qwen2_5_VisionTransformerPretrainedModel
  for vision_block in model.visual.blocks:
  if rms_norm:
  _patch_rms_norm_module(vision_block.norm1)
  _patch_rms_norm_module(vision_block.norm2)
 
- if rms_norm:
- _patch_rms_norm_module(base_model.norm)
- for decoder_layer in base_model.layers:
- if swiglu:
- _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if text_model is not None:
  if rms_norm:
- _patch_rms_norm_module(decoder_layer.input_layernorm)
- _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+ _patch_rms_norm_module(text_model.norm)
+ for decoder_layer in text_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
  def apply_liger_kernel_to_phi3(
@@ -1395,10 +1684,16 @@ def apply_liger_kernel_to_phi3(
  modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+ if model is not None:
+ model.forward = MethodType(phi3_lce_forward, model)
+ else:
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(phi3_lce_forward_deprecated, model)
+ else:
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
 
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -1449,11 +1744,12 @@ def apply_liger_kernel_to_olmo2(
  from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
 
  from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
 
  if rope:
  modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
  if rms_norm:
- modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
+ modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
  if swiglu:
  modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
  if cross_entropy:
@@ -1461,7 +1757,10 @@ def apply_liger_kernel_to_olmo2(
 
  nn.functional.cross_entropy = liger_cross_entropy
  if fused_linear_cross_entropy:
- modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+ if model is not None:
+ model.forward = MethodType(olmo2_lce_forward, model)
+ else:
+ modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
 
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -1512,11 +1811,12 @@ def apply_liger_kernel_to_glm4(
  from transformers.models.glm4.modeling_glm4 import Glm4Model
 
  from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
 
  if rope:
  raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
  if rms_norm:
- modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
+ modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
  if swiglu:
  modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
  if cross_entropy:
@@ -1524,7 +1824,10 @@ def apply_liger_kernel_to_glm4(
 
  nn.functional.cross_entropy = liger_cross_entropy
  if fused_linear_cross_entropy:
- modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
+ if model is not None:
+ model.forward = MethodType(glm4_lce_forward, model)
+ else:
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
 
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -1554,6 +1857,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "gemma3": apply_liger_kernel_to_gemma3,
  "glm4": apply_liger_kernel_to_glm4,
  "llama": apply_liger_kernel_to_llama,
+ "llama4_text": apply_liger_kernel_to_llama4,
+ "llama4": apply_liger_kernel_to_llama4,
  "llava": apply_liger_kernel_to_llava,
  "granite": apply_liger_kernel_to_granite,
  "mllama": apply_liger_kernel_to_mllama,
@@ -1565,7 +1870,10 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "qwen3": apply_liger_kernel_to_qwen3,
  "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
  "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+ "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+ "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+ "smollm3": apply_liger_kernel_to_smollm3,
  "phi3": apply_liger_kernel_to_phi3,
  "paligemma": apply_liger_kernel_to_paligemma,
  }
@@ -1625,7 +1933,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
  return
 
  apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
  apply_fn_signature = inspect.signature(apply_fn)
 
  # Filter out the keyword arguments that are not supported by the apply function
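Finally, a usage sketch for the instance dispatch shown in the last two hunks (not part of the diff; the checkpoint name is a placeholder). `_apply_liger_kernel_to_instance` looks up `model.config.model_type` in `MODEL_TYPE_TO_APPLY_LIGER_FN`, so newly registered types such as `smollm3`, `llama4`, `qwen2_vl_text`, and `qwen2_5_vl_text` are resolved automatically, and keyword arguments the resolved function does not accept are filtered out:

```python
# Usage sketch: dispatch-based patching of an already-loaded model.
# The checkpoint name is a placeholder.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")

# Resolves apply_liger_kernel_to_smollm3 from model.config.model_type == "smollm3"
# and forwards only the kwargs that function accepts.
_apply_liger_kernel_to_instance(model=model, rms_norm=True, swiglu=True)
```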