liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
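The bulk of the change is in liger_kernel/transformers/monkey_patch.py (entry 96 above, the only file whose diff is reproduced below): the apply_liger_kernel_to_* functions now accept an already-instantiated model and, in that case, bind the Liger forward onto the instance with types.MethodType instead of overwriting the class-level forward, so other instances of the same class are left untouched. A minimal sketch of that pattern follows; DummyForCausalLM and liger_style_forward are illustrative stand-ins, not names from the package.

    from types import MethodType

    class DummyForCausalLM:
        def forward(self, x):
            return f"original: {x}"

    def liger_style_forward(self, x):
        # Stand-in for a Liger lce_forward: a real patch would compute the fused
        # linear + cross-entropy loss here without materializing the full logits.
        return f"patched: {x}"

    model = DummyForCausalLM()
    # Instance-level patch: only this object is affected; the class keeps its forward.
    model.forward = MethodType(liger_style_forward, model)
    assert model.forward("x") == "patched: x"
    assert DummyForCausalLM().forward("x") == "original: x"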
@@ -2,7 +2,9 @@ import inspect
 import logging
 
 from functools import partial
+from types import MethodType
 from typing import Callable
+from typing import Optional
 
 import transformers
 
@@ -13,10 +15,12 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
+from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -25,16 +29,24 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_f
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
+try:
+    import peft
+
+    PEFT_AVAILABLE = True
+except ImportError:
+    PEFT_AVAILABLE = False
+
 transformer_version = version.parse(transformers.__version__)
 
 logger = logging.getLogger(__name__)
@@ -47,33 +59,82 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
-    module.offset = offset
-    module.casting_mode = casting_mode
-    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    module.in_place = in_place
-    _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
-    _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-    module.__class__.__name__ = LigerRMSNorm.__name__
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.modules_to_save.default.offset = offset
+        module.modules_to_save.default.casting_mode = casting_mode
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.modules_to_save.default.in_place = in_place
+        module.modules_to_save.default.row_mode = row_mode
+        module.original_module.offset = offset
+        module.original_module.casting_mode = casting_mode
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.in_place = in_place
+        module.original_module.row_mode = row_mode
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
+    else:
+        module.offset = offset
+        module.casting_mode = casting_mode
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.in_place = in_place
+        module.row_mode = row_mode
+        _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
-    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
-
-    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
-    _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-    module.__class__.__name__ = LigerLayerNorm.__name__
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.hidden_size = module.normalized_shape
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
+    else:
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
 
 
 def _patch_swiglu_module(module, liger_module):
     _bind_method_to_module(module, "forward", liger_module.forward)
-    module.__class__.__name__ = liger_module.__name__
+    _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)
 
 
 def _patch_geglu_module(module):
     _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
-    module.__class__.__name__ = LigerGEGLUMLP.__name__
+    _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
 
 
 def apply_liger_kernel_to_granite(
@@ -204,10 +265,16 @@ def apply_liger_kernel_to_llama(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -227,6 +294,77 @@ def apply_liger_kernel_to_llama(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_smollm3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smollm3 import modeling_smollm3
+    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+    if rope:
+        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smollm3_lce_forward, model)
+        else:
+            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+        # get the base model from the model instance
+        base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llava(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
@@ -261,13 +399,20 @@ def apply_liger_kernel_to_llava(
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
             modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse("4.49.0"):
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        if transformer_version >= version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else:  # if version < 4.49.0
             logger.warning(
-                "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
             )
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
 
     if model is not None:
         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
@@ -285,7 +430,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                 )
-            text_kwargs["model"] = model.language_model
+            text_kwargs["model"] = model.model.language_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
@@ -300,12 +445,103 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
                 )
-            vision_kwargs["model"] = model.vision_tower
+            vision_kwargs["model"] = model.model.vision_tower
             vision_liger_fn(**vision_kwargs)
         elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
 
 
+def apply_liger_kernel_to_llama4(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+        apply_liger_llama4_rope_full(modeling_llama4)
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    if decoder_layer.is_moe_layer:
+                        _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+                    else:
+                        _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -347,7 +583,7 @@ def apply_liger_kernel_to_mllama(
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -363,25 +599,35 @@ def apply_liger_kernel_to_mllama(
             modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules
 
         if isinstance(model, MllamaForConditionalGeneration):
-            language_model: MllamaForCausalLM = model.language_model
-            vision_model: MllamaVisionModel = model.vision_model
-            text_model: MllamaTextModel = language_model.model
+            language_model: MllamaForCausalLM = model.model.language_model
+            vision_model: MllamaVisionModel = model.model.vision_model
+            if isinstance(language_model, MllamaForCausalLM):
+                text_model: MllamaTextModel = language_model.model
+            else:
+                text_model = language_model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
             vision_model = None
         elif isinstance(model, MllamaTextModel):
             text_model = model
             vision_model = None
+
         else:
             raise ValueError(f"Unsupported Mllama model type: {type(model)}")
 
@@ -448,7 +694,17 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        if transformer_version >= version.parse("4.49.0"):
+            if model is not None:
+                model.forward = MethodType(mistral_lce_forward, model)
+            else:
+                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        else:
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP
 
@@ -516,10 +772,16 @@ def apply_liger_kernel_to_mixtral(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
@@ -573,8 +835,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
-    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
     _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
@@ -593,10 +855,16 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -647,7 +915,8 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -667,10 +936,16 @@ def apply_liger_kernel_to_gemma2(
             modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
         else:
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward_deprected, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
 
@@ -694,17 +969,16 @@ def apply_liger_kernel_to_gemma2(
                 _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
 
 
-def apply_liger_kernel_to_paligemma(
+def apply_liger_kernel_to_gemma3_text(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
-    layer_norm: bool = True,
     rms_norm: bool = True,
     geglu: bool = True,
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma3
 
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -713,7 +987,6 @@ def apply_liger_kernel_to_paligemma(
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
-        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
@@ -723,97 +996,77 @@ def apply_liger_kernel_to_paligemma(
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
-    # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
+    from transformers.models.gemma3 import modeling_gemma3
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
 
-    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
-    from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
-    from transformers.models.paligemma import modeling_paligemma
-    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
-    from transformers.models.siglip import modeling_siglip
-    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
-    from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+    from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
 
-    from liger_kernel.transformers.model.paligemma import lce_forward
-    from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
+    _patch_rms_norm_module_for_gemma3 = partial(
+        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
+    )
 
-    # The vision_tower is a SiglipVisionModel
-    if layer_norm:
-        modeling_siglip.nn.LayerNorm = LigerLayerNorm
+    if rope:
+        modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
 
-    # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
-    # The multi_modal_projector is Linear, nothing to do
+    if rms_norm:
+        modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3
+
+    if geglu:
+        modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
 
-    # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
-    apply_liger_kernel_to_gemma(
-        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
-    )
-    apply_liger_kernel_to_gemma2(
-        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
-    )
     # Handle loss function
     if cross_entropy:
-        modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
-        else:  # if version < 4.46.1
-            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if not isinstance(model, PaliGemmaForConditionalGeneration):
-            raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
-
-        vision_tower: SiglipVisionModel = model.vision_tower
-
-        _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
-
-        for layer in vision_tower.vision_model.encoder.layers:
-            layer: SiglipEncoderLayer
-            if layer_norm:
-                _patch_layer_norm_module(layer.layer_norm1)
-                _patch_layer_norm_module(layer.layer_norm2)
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
+            # get the base model from the model instance
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
 
-        language_model = model.language_model
+            if rms_norm:
+                _patch_rms_norm_module_for_gemma3(base_model.norm)
 
-        if isinstance(language_model, GemmaForCausalLM):
-            apply_liger_kernel_to_gemma(
-                rope=rope,
-                cross_entropy=False,
-                fused_linear_cross_entropy=False,
-                rms_norm=rms_norm,
-                geglu=geglu,
-                model=language_model,
-            )
+            for decoder_layer in base_model.layers:
+                decoder_layer: Gemma3DecoderLayer
+                if geglu:
+                    _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                if rms_norm:
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
+                    _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)
 
-        elif isinstance(language_model, Gemma2ForCausalLM):
-            apply_liger_kernel_to_gemma2(
-                rope=rope,
-                cross_entropy=False,
-                fused_linear_cross_entropy=False,
-                rms_norm=rms_norm,
-                geglu=geglu,
-                model=language_model,
-            )
         else:
-            raise TypeError(
-                "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
-            )
+            raise TypeError("The model must be Gemma3ForCausalLM.")
 
 
-def apply_liger_kernel_to_qwen2(
+def apply_liger_kernel_to_gemma3(
     rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
     rms_norm: bool = True,
-    swiglu: bool = True,
+    geglu: bool = True,
     model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma3
 
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -822,8 +1075,9 @@ def apply_liger_kernel_to_qwen2(
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
@@ -831,64 +1085,1378 @@ def apply_liger_kernel_to_qwen2(
831
1085
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
832
1086
  )
833
1087
 
834
- from transformers.models.qwen2 import modeling_qwen2
835
- from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
1088
+ from transformers.models.gemma3 import modeling_gemma3
1089
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
1090
+ from transformers.models.siglip import modeling_siglip
1091
+ from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
1092
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
836
1093
 
837
- if rope:
838
- modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
839
- if rms_norm:
840
- modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
1094
+ from liger_kernel.transformers.model.gemma3 import multimodal_forward
841
1095
 
842
- if cross_entropy:
843
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
844
- from transformers.loss.loss_utils import nn
1096
+ _patch_rms_norm_module_for_gemma3 = partial(
1097
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
1098
+ )
845
1099
 
846
- nn.functional.cross_entropy = liger_cross_entropy
847
- else:
848
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
849
- modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
1100
+ if layer_norm and model is None:
1101
+ modeling_siglip.nn.LayerNorm = LigerLayerNorm
850
1102
 
851
- if fused_linear_cross_entropy:
852
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
853
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
854
- else: # if version < 4.46.1
855
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
856
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
1103
+ apply_liger_kernel_to_gemma3_text(
1104
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
1105
+ )
857
1106
 
858
- if swiglu:
859
- modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
1107
+ if cross_entropy:
1108
+ modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
1109
+
1110
+ if fused_linear_cross_entropy:
1111
+ if model is not None:
1112
+ model.forward = MethodType(multimodal_forward, model)
1113
+ else:
1114
+ modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
860
1115
 
861
1116
  if model is not None:
862
1117
  # The model instance already exists, so we need to additionally patch the
863
1118
  # instance variables that reference already-instantiated modules
864
1119
 
865
- # get the base model from the model instance
866
- base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)
1120
+ if isinstance(model, Gemma3ForConditionalGeneration):
1121
+ if isinstance(model.model.vision_tower, SiglipVisionModel):
1122
+ vision_tower = model.model.vision_tower
867
1123
 
868
- if rms_norm:
869
- _patch_rms_norm_module(base_model.norm)
1124
+ _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
1125
+
1126
+ for layer in vision_tower.vision_model.encoder.layers:
1127
+ layer: SiglipEncoderLayer
1128
+ if layer_norm:
1129
+ _patch_layer_norm_module(layer.layer_norm1)
1130
+ _patch_layer_norm_module(layer.layer_norm2)
1131
+ else:
1132
+ raise TypeError("The vision tower must be SiglipVisionModel")
870
1133
 
871
- for decoder_layer in base_model.layers:
872
- if swiglu:
873
- _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
874
1134
  if rms_norm:
875
- _patch_rms_norm_module(decoder_layer.input_layernorm)
876
- _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
877
- print("Applied Liger kernels to Qwen2")
1135
+ _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)
878
1136
 
1137
+ apply_liger_kernel_to_gemma3_text(
1138
+ rope=rope,
1139
+ cross_entropy=False,
1140
+ fused_linear_cross_entropy=False,
1141
+ rms_norm=rms_norm,
1142
+ geglu=geglu,
1143
+ model=model.model.language_model,
1144
+ )
1145
+
1146
+ else:
1147
+ raise TypeError("The model must be Gemma3ForConditionalGeneration.")
1148
+
1149
+
1150
+ def apply_liger_kernel_to_paligemma(
1151
+ rope: bool = True,
1152
+ cross_entropy: bool = False,
1153
+ fused_linear_cross_entropy: bool = True,
1154
+ layer_norm: bool = True,
1155
+ rms_norm: bool = True,
1156
+ geglu: bool = True,
1157
+ model: PreTrainedModel = None,
1158
+ ) -> None:
1159
+ """
1160
+ Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
1161
+
1162
+ Args:
1163
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1164
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1165
+ fused_linear_cross_entropy (bool):
1166
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1167
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1168
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
1169
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
1170
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1171
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
1172
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1173
+ loaded. Default is None.
1174
+ """
1175
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1176
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1177
+ )
1178
+
1179
+ # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
1180
+
1181
+ from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
1182
+ from transformers.models.gemma.modeling_gemma import GemmaModel
1183
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
1184
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
1185
+ from transformers.models.paligemma import modeling_paligemma
1186
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
1187
+ from transformers.models.siglip import modeling_siglip
1188
+ from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
1189
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
1190
+
1191
+ from liger_kernel.transformers.model.paligemma import lce_forward
1192
+ from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
1193
+
1194
+ # The vision_tower is a SiglipVisionModel
1195
+ if layer_norm and model is None:
1196
+ modeling_siglip.nn.LayerNorm = LigerLayerNorm
1197
+
1198
+ # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
1199
+ # The multi_modal_projector is Linear, nothing to do
1200
+
1201
+ # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
1202
+ apply_liger_kernel_to_gemma(
1203
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
1204
+ )
1205
+ apply_liger_kernel_to_gemma2(
1206
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
1207
+ )
1208
+ # Handle loss function
1209
+ if cross_entropy:
1210
+ modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
1211
+ if fused_linear_cross_entropy:
1212
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1213
+ if model is not None:
1214
+ model.forward = MethodType(lce_forward, model)
1215
+ else:
1216
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
1217
+ else: # if version < 4.46.1
1218
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1219
+ if model is not None:
1220
+ model.forward = MethodType(lce_forward_deprecated, model)
1221
+ else:
1222
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
1223
+
1224
+ if model is not None:
1225
+ # The model instance already exists, so we need to additionally patch the
1226
+ # instance variables that reference already-instantiated modules
1227
+
1228
+ if not isinstance(model, PaliGemmaForConditionalGeneration):
1229
+ raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
1230
+
1231
+ vision_tower: SiglipVisionModel = model.model.vision_tower
1232
+
1233
+ _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
1234
+
1235
+ for layer in vision_tower.vision_model.encoder.layers:
1236
+ layer: SiglipEncoderLayer
1237
+ if layer_norm:
1238
+ _patch_layer_norm_module(layer.layer_norm1)
1239
+ _patch_layer_norm_module(layer.layer_norm2)
1240
+
1241
+ language_model = model.model.language_model
1242
+
1243
+ if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
1244
+ apply_liger_kernel_to_gemma(
1245
+ rope=rope,
1246
+ cross_entropy=False,
1247
+ fused_linear_cross_entropy=False,
1248
+ rms_norm=rms_norm,
1249
+ geglu=geglu,
1250
+ model=language_model,
1251
+ )
1252
+
1253
+ elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
1254
+ apply_liger_kernel_to_gemma2(
1255
+ rope=rope,
1256
+ cross_entropy=False,
1257
+ fused_linear_cross_entropy=False,
1258
+ rms_norm=rms_norm,
1259
+ geglu=geglu,
1260
+ model=language_model,
1261
+ )
1262
+ else:
1263
+ raise TypeError(
1264
+ "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
1265
+ )
1266
+
1267
+
1268
+ def apply_liger_kernel_to_qwen2(
1269
+ rope: bool = True,
1270
+ cross_entropy: bool = False,
1271
+ fused_linear_cross_entropy: bool = True,
1272
+ rms_norm: bool = True,
1273
+ swiglu: bool = True,
1274
+ model: PreTrainedModel = None,
1275
+ ) -> None:
1276
+ """
1277
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
1278
+
1279
+ Args:
1280
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1281
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1282
+ fused_linear_cross_entropy (bool):
1283
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1284
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1285
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
1286
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1287
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1288
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1289
+ loaded. Default is None.
1290
+ """
1291
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1292
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1293
+ )
1294
+
1295
+ from transformers.models.qwen2 import modeling_qwen2
1296
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
1297
+
1298
+ if rope:
1299
+ modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
1300
+ if rms_norm:
1301
+ modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
1302
+
1303
+ if cross_entropy:
1304
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1305
+ from transformers.loss.loss_utils import nn
1306
+
1307
+ nn.functional.cross_entropy = liger_cross_entropy
1308
+ else:
1309
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1310
+ modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
1311
+
1312
+ if fused_linear_cross_entropy:
1313
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1314
+ if model is not None:
1315
+ model.forward = MethodType(qwen2_lce_forward, model)
1316
+ else:
1317
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
1318
+ else: # if version < 4.46.1
1319
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1320
+ if model is not None:
1321
+ model.forward = MethodType(qwen2_lce_forward_deprecated, model)
1322
+ else:
1323
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
1324
+
1325
+ if swiglu:
1326
+ modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
1327
+
1328
+ if model is not None:
1329
+ # The model instance already exists, so we need to additionally patch the
1330
+ # instance variables that reference already-instantiated modules
1331
+
1332
+ # get the base model from the model instance
1333
+ base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)
1334
+
1335
+ if rms_norm:
1336
+ _patch_rms_norm_module(base_model.norm)
1337
+
1338
+ for decoder_layer in base_model.layers:
1339
+ if swiglu:
1340
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1341
+ if rms_norm:
1342
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1343
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1344
+
1345
+
1346
+ def apply_liger_kernel_to_qwen3(
1347
+ rope: bool = True,
1348
+ cross_entropy: bool = False,
1349
+ fused_linear_cross_entropy: bool = True,
1350
+ rms_norm: bool = True,
1351
+ swiglu: bool = True,
1352
+ model: PreTrainedModel = None,
1353
+ ) -> None:
1354
+ """
1355
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
1356
+ """
1357
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1358
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1359
+ )
1360
+
1361
+ from transformers.models.qwen3 import modeling_qwen3
1362
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
1363
+
1364
+ from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward
1365
+
1366
+ if rope:
1367
+ modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
1368
+
1369
+ if rms_norm:
1370
+ modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
1371
+
1372
+ if cross_entropy:
1373
+ from transformers.loss.loss_utils import nn
1374
+
1375
+ nn.functional.cross_entropy = liger_cross_entropy
1376
+
1377
+ if fused_linear_cross_entropy:
1378
+ if model is not None:
1379
+ model.forward = MethodType(qwen3_lce_forward, model)
1380
+ else:
1381
+ modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
1382
+
1383
+ if swiglu:
1384
+ modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
1385
+
1386
+ if model is not None:
1387
+ # The model instance already exists, so we need to additionally patch the
1388
+ # instance variables that reference already-instantiated modules
1389
+
1390
+ # get the base model from the model instance
1391
+ base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)
1392
+
1393
+ if rms_norm:
1394
+ _patch_rms_norm_module(base_model.norm)
1395
+ for decoder_layer in base_model.layers:
1396
+ if swiglu:
1397
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1398
+ if rms_norm:
1399
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1400
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1401
+
1402
+
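The forward replacement above relies on `types.MethodType` to bind the Liger forward to a single model instance without touching the class. A self-contained toy sketch of that pattern (all names here are illustrative, not from this file):

```python
from types import MethodType


class Dummy:
    def forward(self, x):
        return x


def lce_forward(self, x):
    # stand-in for a fused linear-cross-entropy forward
    return 2 * x


m = Dummy()
m.forward = MethodType(lce_forward, m)  # bind the new forward to this instance only
print(m.forward(3))        # 6 -> the patched instance uses the replacement
print(Dummy().forward(3))  # 3 -> other instances and the class are untouched
```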
1403
+ def apply_liger_kernel_to_qwen3_moe(
1404
+ rope: bool = True,
1405
+ cross_entropy: bool = False,
1406
+ fused_linear_cross_entropy: bool = True,
1407
+ rms_norm: bool = True,
1408
+ swiglu: bool = True,
1409
+ model: PreTrainedModel = None,
1410
+ ) -> None:
1411
+ """
1412
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3 MoE models.
1413
+ """
1414
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1415
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1416
+ )
1417
+
1418
+ from transformers.models.qwen3_moe import modeling_qwen3_moe
1419
+ from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
1420
+
1421
+ from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
1422
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
1423
+
1424
+ if rope:
1425
+ modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
1426
+
1427
+ if rms_norm:
1428
+ modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
1429
+
1430
+ if cross_entropy:
1431
+ from transformers.loss.loss_utils import nn
1432
+
1433
+ nn.functional.cross_entropy = liger_cross_entropy
1434
+
1435
+ if fused_linear_cross_entropy:
1436
+ if model is not None:
1437
+ model.forward = MethodType(qwen3_lce_forward, model)
1438
+ else:
1439
+ modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
1440
+
1441
+ if swiglu:
1442
+ modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
1443
+
1444
+ if model is not None:
1445
+ # The model instance already exists, so we need to additionally patch the
1446
+ # instance variables that reference already-instantiated modules
1447
+
1448
+ # get the base model from the model instance
1449
+ base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
1450
+
1451
+ if rms_norm:
1452
+ _patch_rms_norm_module(base_model.norm)
1453
+ for decoder_layer in base_model.layers:
1454
+ if swiglu:
1455
+ for mlp_expert in decoder_layer.mlp.experts:
1456
+ _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
1457
+ if rms_norm:
1458
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1459
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1460
+
1461
+
1462
+ def apply_liger_kernel_to_gpt_oss(
1463
+ rope: bool = True,
1464
+ cross_entropy: bool = False,
1465
+ fused_linear_cross_entropy: bool = True,
1466
+ rms_norm: bool = True,
1467
+ swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation
1468
+ model: PreTrainedModel = None,
1469
+ ) -> None:
1470
+ """
1471
+ Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
1472
+ NOTE: GPT-OSS is supported in transformers >= 4.55.0
1473
+ NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
1474
+ implementation with clamping and MXFP4 quantization.
1475
+
1476
+ Args:
1477
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1478
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1479
+ fused_linear_cross_entropy (bool):
1480
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1481
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1482
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1483
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1484
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1485
+ Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
1486
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1487
+ loaded. Default is None.
1488
+ """
1489
+ if version.parse(transformers.__version__) < version.parse("4.55.0"):
1490
+ logger.warning("GPT-OSS support requires transformers >= 4.55.0")
1491
+ return
1492
+
1493
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1494
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1495
+ )
1496
+
1497
+ from transformers.models.gpt_oss import modeling_gpt_oss
1498
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
1499
+
1500
+ if rope:
1501
+ modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
1502
+
1503
+ if rms_norm:
1504
+ modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
1505
+
1506
+ if cross_entropy:
1507
+ from transformers.loss.loss_utils import nn
1508
+
1509
+ nn.functional.cross_entropy = liger_cross_entropy
1510
+
1511
+ if fused_linear_cross_entropy:
1512
+ if model is not None:
1513
+ model.forward = MethodType(gpt_oss_lce_forward, model)
1514
+ else:
1515
+ modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
1516
+
1517
+ # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
1518
+ # with clamping (swiglu_limit=7.0) and MXFP4 quantization
1519
+
1520
+ if model is not None:
1521
+ # The model instance already exists, so we need to additionally patch the
1522
+ # instance variables that reference already-instantiated modules
1523
+
1524
+ # get the base model from the model instance
1525
+ base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
1526
+
1527
+ if rms_norm:
1528
+ _patch_rms_norm_module(base_model.norm)
1529
+ for decoder_layer in base_model.layers:
1530
+ if rms_norm:
1531
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1532
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1533
+
1534
+
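The guard above simply skips patching on older transformers releases. A sketch of how a caller might mirror the same check (assuming `apply_liger_kernel_to_gpt_oss` is re-exported from `liger_kernel.transformers`):

```python
import transformers
from packaging import version

from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss  # assumed re-export

if version.parse(transformers.__version__) >= version.parse("4.55.0"):
    # swiglu stays False: GPT-OSS uses custom experts with clamping and MXFP4 quantization.
    apply_liger_kernel_to_gpt_oss(rope=True, rms_norm=True, fused_linear_cross_entropy=True)
else:
    print("GPT-OSS support requires transformers >= 4.55.0; skipping Liger patches.")
```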
1535
+ def apply_liger_kernel_to_qwen2_vl(
1536
+ rope: bool = True,
1537
+ cross_entropy: bool = False,
1538
+ fused_linear_cross_entropy: bool = True,
1539
+ rms_norm: bool = True,
1540
+ layer_norm: bool = True,
1541
+ swiglu: bool = True,
1542
+ model: PreTrainedModel = None,
1543
+ ) -> None:
1544
+ """
1545
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
1546
+ NOTE: Qwen2-VL is not supported in transformers<4.52.4
1547
+
1548
+ Args:
1549
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1550
+ fused_linear_cross_entropy (bool):
1551
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1552
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1553
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1554
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1555
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
1556
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1557
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1558
+ loaded. Default is None.
1559
+ """
1560
+ if transformer_version < version.parse("4.52.4"):
1561
+ logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
1562
+ return
1563
+
1564
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1565
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1566
+ )
1567
+
1568
+ from transformers.models.qwen2_vl import modeling_qwen2_vl
1569
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
1570
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
1571
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
1572
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
1573
+
1574
+ from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
1575
+
1576
+ if rope:
1577
+ modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
1578
+ if rms_norm:
1579
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
1580
+ modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
1581
+ if layer_norm and model is None:
1582
+ modeling_qwen2_vl.LayerNorm = LigerLayerNorm
1583
+ if cross_entropy:
1584
+ modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
1585
+ if fused_linear_cross_entropy:
1586
+ if model is not None:
1587
+ model.forward = MethodType(qwen2_vl_lce_forward, model)
1588
+ else:
1589
+ modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
1590
+ if swiglu:
1591
+ modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
1592
+
1593
+ if model is not None:
1594
+ # The model instance already exists, so we need to additionally patch the
1595
+ # instance variables that reference already-instantiated modules
1596
+ if isinstance(model, Qwen2VLForConditionalGeneration):
1597
+ text_model: Qwen2VLTextModel = model.model.language_model
1598
+ vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
1599
+ elif isinstance(model, Qwen2VLModel):
1600
+ text_model: Qwen2VLTextModel = model.language_model
1601
+ vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
1602
+ elif isinstance(model, Qwen2VLTextModel):
1603
+ text_model: Qwen2VLTextModel = model
1604
+ vision_model = None
1605
+ else:
1606
+ # Note: Patching only the vision model is currently not supported. Feel free to raise an issue if needed.
1607
+ raise TypeError(
1608
+ f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
1609
+ )
1610
+
1611
+ # Patch Qwen2VisionTransformerPretrainedModel
1612
+ if vision_model is not None:
1613
+ for vision_block in vision_model.blocks:
1614
+ if layer_norm:
1615
+ _patch_layer_norm_module(vision_block.norm1)
1616
+ _patch_layer_norm_module(vision_block.norm2)
1617
+
1618
+ # Patch Qwen2VLTextModel
1619
+ if text_model is not None:
1620
+ if rms_norm:
1621
+ _patch_rms_norm_module(text_model.norm)
1622
+ for decoder_layer in text_model.layers:
1623
+ if swiglu:
1624
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1625
+ if rms_norm:
1626
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1627
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1628
+
1629
+
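A post-load sketch for the multimodal patcher above (checkpoint name is illustrative). Class-level `LayerNorm` patching only applies when the function is called before the model is created; with `model=` the vision blocks' LayerNorms and the text tower's RMSNorm/SwiGLU modules are patched in place:

```python
from transformers import Qwen2VLForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl  # assumed re-export

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
apply_liger_kernel_to_qwen2_vl(model=model, rms_norm=True, layer_norm=True, swiglu=True)
```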
1630
+ def apply_liger_kernel_to_qwen2_5_vl(
1631
+ rope: bool = True,
1632
+ cross_entropy: bool = False,
1633
+ fused_linear_cross_entropy: bool = True,
1634
+ rms_norm: bool = True,
1635
+ swiglu: bool = True,
1636
+ model: PreTrainedModel = None,
1637
+ ) -> None:
1638
+ """
1639
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
1640
+ NOTE: Qwen2.5-VL support requires transformers >= 4.52.4
1641
+
1642
+ Args:
1643
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1644
+ fused_linear_cross_entropy (bool):
1645
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1646
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1647
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1648
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1649
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1650
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1651
+ loaded. Default is None.
1652
+ """
1653
+ if transformer_version < version.parse("4.52.4"):
1654
+ logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
1655
+ return
1656
+
1657
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1658
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1659
+ )
1660
+
1661
+ from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
1662
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
1663
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
1664
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
1665
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
1666
+
1667
+ from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
1668
+
1669
+ if rope:
1670
+ modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
1671
+ if rms_norm:
1672
+ modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
1673
+ if cross_entropy:
1674
+ modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
1675
+ if fused_linear_cross_entropy:
1676
+ if model is not None:
1677
+ model.forward = MethodType(qwen2_5_vl_lce_forward, model)
1678
+ else:
1679
+ modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
1680
+ if swiglu:
1681
+ modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
1682
+
1683
+ if model is not None:
1684
+ # The model instance already exists, so we need to additionally patch the
1685
+ # instance variables that reference already-instantiated modules
1686
+ if isinstance(model, Qwen2_5_VLForConditionalGeneration):
1687
+ text_model: Qwen2_5_VLTextModel = model.model.language_model
1688
+ vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
1689
+ elif isinstance(model, Qwen2_5_VLModel):
1690
+ text_model: Qwen2_5_VLTextModel = model.language_model
1691
+ vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
1692
+ elif isinstance(model, Qwen2_5_VLTextModel):
1693
+ text_model: Qwen2_5_VLTextModel = model
1694
+ vision_model = None
1695
+ else:
1696
+ # Note: Patching only the vision model is currently not supported. Feel free to raise an issue if needed.
1697
+ raise TypeError(
1698
+ f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
1699
+ )
1700
+
1701
+ if vision_model is not None:
1702
+ # Patch Qwen2_5_VisionTransformerPretrainedModel
1703
+ for vision_block in vision_model.blocks:
1704
+ if rms_norm:
1705
+ _patch_rms_norm_module(vision_block.norm1)
1706
+ _patch_rms_norm_module(vision_block.norm2)
1707
+
1708
+ if text_model is not None:
1709
+ if rms_norm:
1710
+ _patch_rms_norm_module(text_model.norm)
1711
+ for decoder_layer in text_model.layers:
1712
+ if swiglu:
1713
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1714
+ if rms_norm:
1715
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1716
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1717
+
1718
+
1719
+ def apply_liger_kernel_to_qwen3_vl(
1720
+ rope: bool = True,
1721
+ cross_entropy: bool = False,
1722
+ fused_linear_cross_entropy: bool = True,
1723
+ rms_norm: bool = True,
1724
+ swiglu: bool = False,
1725
+ model: PreTrainedModel = None,
1726
+ ) -> None:
1727
+ """
1728
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
1729
+
1730
+ Args:
1731
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1732
+ fused_linear_cross_entropy (bool):
1733
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1734
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1735
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1736
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1737
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1738
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1739
+ loaded. Default is None.
1740
+ """
1741
+
1742
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1743
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1744
+ )
1745
+
1746
+ from transformers.models.qwen3_vl import modeling_qwen3_vl
1747
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
1748
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
1749
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
1750
+
1751
+ from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
1752
+
1753
+ if rope:
1754
+ modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
1755
+ modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
1756
+
1757
+ if rms_norm:
1758
+ modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
1759
+
1760
+ if cross_entropy:
1761
+ from transformers.loss.loss_utils import nn
1762
+
1763
+ nn.functional.cross_entropy = liger_cross_entropy
1764
+
1765
+ if fused_linear_cross_entropy:
1766
+ if model is not None:
1767
+ model.forward = MethodType(qwen3_vl_lce_forward, model)
1768
+ else:
1769
+ modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
1770
+
1771
+ if model is not None and rms_norm:
1772
+ if isinstance(model, Qwen3VLForConditionalGeneration):
1773
+ text_model: Qwen3VLTextModel = model.model.language_model
1774
+ elif isinstance(model, Qwen3VLModel):
1775
+ text_model: Qwen3VLTextModel = model.language_model
1776
+ elif isinstance(model, Qwen3VLTextModel):
1777
+ text_model = model
1778
+ else:
1779
+ raise TypeError(
1780
+ f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
1781
+ )
1782
+
1783
+ _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
1784
+
1785
+ if text_model is not None:
1786
+ _patch_qwen3_vl_rms_norm(text_model.norm)
1787
+ for decoder_layer in text_model.layers:
1788
+ _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
1789
+ _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
1790
+ self_attn = getattr(decoder_layer, "self_attn", None)
1791
+ if self_attn is not None:
1792
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
1793
+ _patch_qwen3_vl_rms_norm(self_attn.q_norm)
1794
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
1795
+ _patch_qwen3_vl_rms_norm(self_attn.k_norm)
1796
+
1797
+
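The Qwen3-VL branch above pre-binds the RMSNorm patch settings with `functools.partial` so that every norm module (final norm, per-layer norms, q/k norms) receives the same `offset` and `casting_mode`. A self-contained toy sketch of the pattern, using a hypothetical stand-in for the private `_patch_rms_norm_module` helper:

```python
from functools import partial


def patch_rms_norm(module, offset=0.0, casting_mode="llama"):
    # hypothetical stand-in: just record the settings instead of swapping the forward
    module.update(offset=offset, casting_mode=casting_mode)


patch_qwen3_vl_rms_norm = partial(patch_rms_norm, offset=0.0, casting_mode="llama")

norms = [{"name": "input_layernorm"}, {"name": "post_attention_layernorm"}]
for norm in norms:
    patch_qwen3_vl_rms_norm(norm)  # identical settings applied to every module
print(norms)
```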
1798
+ def apply_liger_kernel_to_qwen3_vl_moe(
1799
+ rope: bool = True,
1800
+ cross_entropy: bool = False,
1801
+ fused_linear_cross_entropy: bool = True,
1802
+ rms_norm: bool = True,
1803
+ swiglu: bool = False,
1804
+ model: PreTrainedModel = None,
1805
+ ) -> None:
1806
+ """
1807
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
1808
+
1809
+ Args:
1810
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1811
+ fused_linear_cross_entropy (bool):
1812
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1813
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1814
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1815
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1816
+ loaded. Default is None.
1817
+ """
1818
+
1819
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1820
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1821
+ )
1822
+
1823
+ from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
1824
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
1825
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
1826
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
1827
+
1828
+ from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
1829
+
1830
+ if rope:
1831
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
1832
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
1833
+
1834
+ if rms_norm:
1835
+ modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
1836
+
1837
+ if cross_entropy:
1838
+ from transformers.loss.loss_utils import nn
1839
+
1840
+ nn.functional.cross_entropy = liger_cross_entropy
1841
+
1842
+ if fused_linear_cross_entropy:
1843
+ if model is not None:
1844
+ model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
1845
+ else:
1846
+ modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
1847
+
1848
+ if model is not None and rms_norm:
1849
+ if isinstance(model, Qwen3VLMoeForConditionalGeneration):
1850
+ text_model: Qwen3VLMoeTextModel = model.model.language_model
1851
+ elif isinstance(model, Qwen3VLMoeModel):
1852
+ text_model: Qwen3VLMoeTextModel = model.language_model
1853
+ elif isinstance(model, Qwen3VLMoeTextModel):
1854
+ text_model = model
1855
+ else:
1856
+ raise TypeError(
1857
+ f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
1858
+ )
1859
+
1860
+ _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
1861
+
1862
+ if text_model is not None:
1863
+ _patch_qwen3_vl_moe_rms_norm(text_model.norm)
1864
+ for decoder_layer in text_model.layers:
1865
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
1866
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
1867
+ self_attn = getattr(decoder_layer, "self_attn", None)
1868
+ if self_attn is not None:
1869
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
1870
+ _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
1871
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
1872
+ _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
1873
+
1874
+
1875
+ def apply_liger_kernel_to_phi3(
1876
+ rope: bool = True,
1877
+ cross_entropy: bool = False,
1878
+ fused_linear_cross_entropy: bool = True,
1879
+ rms_norm: bool = True,
1880
+ swiglu: bool = True,
1881
+ model: PreTrainedModel = None,
1882
+ ) -> None:
1883
+ """
1884
+ Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
1885
+
1886
+ Args:
1887
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1888
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1889
+ fused_linear_cross_entropy (bool):
1890
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1891
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1892
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1893
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1894
+ swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
1895
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1896
+ loaded. Default is None.
1897
+ """
1898
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1899
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1900
+ )
1901
+
1902
+ from transformers.models.phi3 import modeling_phi3
1903
+ from transformers.models.phi3.modeling_phi3 import Phi3Model
1904
+
1905
+ if rope:
1906
+ modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
1907
+ if rms_norm:
1908
+ modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
1909
+ if swiglu:
1910
+ modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
1911
+ if cross_entropy:
1912
+ from transformers.loss.loss_utils import nn
1913
+
1914
+ nn.functional.cross_entropy = liger_cross_entropy
1915
+ if fused_linear_cross_entropy:
1916
+ if model is not None:
1917
+ model.forward = MethodType(phi3_lce_forward, model)
1918
+ else:
1919
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1920
+
1921
+ if model is not None:
1922
+ # The model instance already exists, so we need to additionally patch the
1923
+ # instance variables that reference already-instantiated modules
1924
+
1925
+ # get the base model from the model instance
1926
+ base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
1927
+
1928
+ if rms_norm:
1929
+ _patch_rms_norm_module(base_model.norm)
1930
+
1931
+ for decoder_layer in base_model.layers:
1932
+ if swiglu:
1933
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
1934
+ if rms_norm:
1935
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1936
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1937
+
1938
+
1939
+ def apply_liger_kernel_to_olmo2(
1940
+ rope: bool = True,
1941
+ cross_entropy: bool = False,
1942
+ fused_linear_cross_entropy: bool = True,
1943
+ rms_norm: bool = True,
1944
+ swiglu: bool = True,
1945
+ model: PreTrainedModel = None,
1946
+ ) -> None:
1947
+ """
1948
+ Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
1949
+
1950
+ Args:
1951
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1952
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1953
+ fused_linear_cross_entropy (bool):
1954
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1955
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1956
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1957
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1958
+ swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
1959
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1960
+ loaded. Default is None.
1961
+ """
1962
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1963
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1964
+ )
1965
+
1966
+ from transformers.models.olmo2 import modeling_olmo2
1967
+ from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
1968
+
1969
+ from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
1970
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
1971
+
1972
+ if rope:
1973
+ modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
1974
+ if rms_norm:
1975
+ modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
1976
+ if swiglu:
1977
+ modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
1978
+ if cross_entropy:
1979
+ from transformers.loss.loss_utils import nn
1980
+
1981
+ nn.functional.cross_entropy = liger_cross_entropy
1982
+ if fused_linear_cross_entropy:
1983
+ if model is not None:
1984
+ model.forward = MethodType(olmo2_lce_forward, model)
1985
+ else:
1986
+ modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
1987
+
1988
+ if model is not None:
1989
+ # The model instance already exists, so we need to additionally patch the
1990
+ # instance variables that reference already-instantiated modules
1991
+
1992
+ # get the base model from the model instance
1993
+ base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
1994
+
1995
+ if rms_norm:
1996
+ _patch_rms_norm_module(base_model.norm)
1997
+
1998
+ for decoder_layer in base_model.layers:
1999
+ if swiglu:
2000
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2001
+ if rms_norm:
2002
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2003
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2004
+
2005
+
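Rather than calling the model-specific function by hand, the package also exposes an auto wrapper (see `liger_kernel/transformers/auto_model.py` in the file list) that dispatches to the matching `apply_liger_kernel_to_*` function based on the loaded config. A sketch, assuming OLMo-2 is registered in that dispatch table and using an illustrative checkpoint:

```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Loads the checkpoint and applies the matching Liger patches automatically,
# e.g. apply_liger_kernel_to_olmo2 for an OLMo-2 config (checkpoint name illustrative).
model = AutoLigerKernelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")
```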
2006
+ def apply_liger_kernel_to_olmo3(
2007
+ rope: bool = True,
2008
+ cross_entropy: bool = False,
2009
+ fused_linear_cross_entropy: bool = True,
2010
+ rms_norm: bool = True,
2011
+ swiglu: bool = True,
2012
+ model: PreTrainedModel = None,
2013
+ ) -> None:
2014
+ """
2015
+ Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
2016
+
2017
+ Args:
2018
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2019
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2020
+ fused_linear_cross_entropy (bool):
2021
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2022
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2023
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2024
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2025
+ swiglu (bool): Whether to apply Liger's SwiGLU Olmo3MLP. Default is True.
2026
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2027
+ loaded. Default is None.
2028
+ """
2029
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2030
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2031
+ )
2032
+
2033
+ from transformers.models.olmo3 import modeling_olmo3
2034
+ from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
2035
+
2036
+ from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
2037
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
2038
+
2039
+ # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
2040
+ if rope:
2041
+ modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
2042
+ if rms_norm:
2043
+ modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2 # same as olmo2
2044
+ if swiglu:
2045
+ modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
2046
+ if cross_entropy:
2047
+ from transformers.loss.loss_utils import nn
2048
+
2049
+ nn.functional.cross_entropy = liger_cross_entropy
2050
+ if fused_linear_cross_entropy:
2051
+ if model is not None:
2052
+ model.forward = MethodType(olmo3_lce_forward, model)
2053
+ else:
2054
+ modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
2055
+
2056
+ if model is not None:
2057
+ # The model instance already exists, so we need to additionally patch the
2058
+ # instance variables that reference already-instantiated modules
2059
+
2060
+ # get the base model from the model instance
2061
+ base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
2062
+
2063
+ if rms_norm:
2064
+ _patch_rms_norm_module(base_model.norm)
2065
+
2066
+ for decoder_layer in base_model.layers:
2067
+ if swiglu:
2068
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2069
+ if rms_norm:
2070
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2071
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2072
+
2073
+
2074
+ def apply_liger_kernel_to_glm4(
2075
+ rope: bool = False,
2076
+ cross_entropy: bool = False,
2077
+ fused_linear_cross_entropy: bool = True,
2078
+ rms_norm: bool = True,
2079
+ swiglu: bool = True,
2080
+ model: PreTrainedModel = None,
2081
+ ) -> None:
2082
+ """
2083
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
2084
+
2085
+ Args:
2086
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2087
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2088
+ fused_linear_cross_entropy (bool):
2089
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2090
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2091
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2092
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2093
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
2094
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2095
+ loaded. Default is None.
2096
+ """
2097
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2098
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2099
+ )
2100
+
2101
+ from transformers.models.glm4 import modeling_glm4
2102
+ from transformers.models.glm4.modeling_glm4 import Glm4Model
2103
+
2104
+ from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
2105
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2106
+
2107
+ if rope:
2108
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2109
+ if rms_norm:
2110
+ modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
2111
+ if swiglu:
2112
+ modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
2113
+ if cross_entropy:
2114
+ from transformers.loss.loss_utils import nn
2115
+
2116
+ nn.functional.cross_entropy = liger_cross_entropy
2117
+ if fused_linear_cross_entropy:
2118
+ if model is not None:
2119
+ model.forward = MethodType(glm4_lce_forward, model)
2120
+ else:
2121
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
2122
+
2123
+ if model is not None:
2124
+ # The model instance already exists, so we need to additionally patch the
2125
+ # instance variables that reference already-instantiated modules
2126
+
2127
+ # get the base model from the model instance
2128
+ base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
2129
+
2130
+ if rms_norm:
2131
+ _patch_rms_norm_module(base_model.norm, in_place=False)
2132
+
2133
+ for decoder_layer in base_model.layers:
2134
+ if swiglu:
2135
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
2136
+ if rms_norm:
2137
+ _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
2138
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2139
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
2140
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
2141
+
2142
+
2143
+ def apply_liger_kernel_to_glm4v(
2144
+ rope: bool = False,
2145
+ cross_entropy: bool = False,
2146
+ fused_linear_cross_entropy: bool = True,
2147
+ rms_norm: bool = True,
2148
+ swiglu: bool = True,
2149
+ model: PreTrainedModel = None,
2150
+ ) -> None:
2151
+ """
2152
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
2153
+
2154
+ Args:
2155
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2156
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2157
+ fused_linear_cross_entropy (bool):
2158
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2159
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2160
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2161
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2162
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
2163
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2164
+ loaded. Default is None.
2165
+ """
2166
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2167
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2168
+ )
2169
+
2170
+ from transformers.models.glm4v import modeling_glm4v
2171
+ from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
2172
+ from transformers.models.glm4v.modeling_glm4v import Glm4vModel
2173
+ from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
2174
+ from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
2175
+
2176
+ from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
2177
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2178
+
2179
+ if rope:
2180
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2181
+ if rms_norm:
2182
+ modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
2183
+ if cross_entropy:
2184
+ from transformers.loss.loss_utils import nn
2185
+
2186
+ nn.functional.cross_entropy = liger_cross_entropy
2187
+ if fused_linear_cross_entropy:
2188
+ if model is not None:
2189
+ model.forward = MethodType(glm4v_lce_forward, model)
2190
+ else:
2191
+ modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
2192
+
2193
+ if model is not None:
2194
+ # The model instance already exists, so we need to additionally patch the
2195
+ # instance variables that reference already-instantiated modules
2196
+ if isinstance(model, Glm4vForConditionalGeneration):
2197
+ text_model: Glm4vTextModel = model.model.language_model
2198
+ vision_model: Glm4vVisionModel = model.model.visual
2199
+ elif isinstance(model, Glm4vModel):
2200
+ text_model: Glm4vTextModel = model.language_model
2201
+ vision_model: Glm4vVisionModel = model.visual
2202
+ elif isinstance(model, Glm4vTextModel):
2203
+ text_model: Glm4vTextModel = model
2204
+ vision_model = None
2205
+ else:
2206
+ # Note: Patching only the vision model is currently not supported. Feel free to raise an issue if needed.
2207
+ raise TypeError(
2208
+ f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
2209
+ )
2210
+
2211
+ if vision_model is not None:
2212
+ for vision_block in vision_model.blocks:
2213
+ if rms_norm:
2214
+ _patch_rms_norm_module(vision_block.norm1)
2215
+ _patch_rms_norm_module(vision_block.norm2)
2216
+ if swiglu:
2217
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
2218
+
2219
+ if text_model is not None:
2220
+ if rms_norm:
2221
+ _patch_rms_norm_module(text_model.norm)
2222
+ for decoder_layer in text_model.layers:
2223
+ if swiglu:
2224
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
2225
+ if rms_norm:
2226
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2227
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2228
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
2229
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
2230
+
2231
+
2232
+ def apply_liger_kernel_to_glm4v_moe(
2233
+ rope: bool = False,
2234
+ cross_entropy: bool = False,
2235
+ fused_linear_cross_entropy: bool = True,
2236
+ rms_norm: bool = True,
2237
+ swiglu: bool = True,
2238
+ model: PreTrainedModel = None,
2239
+ ) -> None:
2240
+ """
2241
+ Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
2242
+
2243
+ Args:
2244
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2245
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2246
+ fused_linear_cross_entropy (bool):
2247
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2248
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2249
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2250
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2251
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
2252
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2253
+ loaded. Default is None.
2254
+ """
2255
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2256
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2257
+ )
2258
+
2259
+ from transformers.models.glm4v_moe import modeling_glm4v_moe
2260
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
2261
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
2262
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
2263
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
2264
+
2265
+ from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
2266
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2267
+
2268
+ if rope:
2269
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2270
+ if rms_norm:
2271
+ modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
2272
+ modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
2273
+ if cross_entropy:
2274
+ from transformers.loss.loss_utils import nn
2275
+
2276
+ nn.functional.cross_entropy = liger_cross_entropy
2277
+ if fused_linear_cross_entropy:
2278
+ if model is not None:
2279
+ model.forward = MethodType(glm4v_moe_lce_forward, model)
2280
+ else:
2281
+ modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
2282
+
2283
+ if model is not None:
2284
+ # The model instance already exists, so we need to additionally patch the
2285
+ # instance variables that reference already-instantiated modules
2286
+ if isinstance(model, Glm4vMoeForConditionalGeneration):
2287
+ text_model: Glm4vMoeTextModel = model.model.language_model
2288
+ vision_model: Glm4vMoeVisionModel = model.model.visual
2289
+ Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
2290
+ elif isinstance(model, Glm4vMoeModel):
2291
+ text_model: Glm4vMoeTextModel = model.language_model
2292
+ vision_model: Glm4vMoeVisionModel = model.visual
2293
+ Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
2294
+ elif isinstance(model, Glm4vMoeTextModel):
2295
+ text_model: Glm4vMoeTextModel = model
2296
+ vision_model = None
2297
+ else:
2298
+ # Note: Patching only the vision model is currently not supported. Feel free to raise an issue if needed.
2299
+ raise TypeError(
2300
+ f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
2301
+ )
2302
+
2303
+ if vision_model is not None:
2304
+ _patch_rms_norm_module(vision_model.post_conv_layernorm)
2305
+ _patch_rms_norm_module(vision_model.post_layernorm)
2306
+ for vision_block in vision_model.blocks:
2307
+ if rms_norm:
2308
+ _patch_rms_norm_module(vision_block.norm1)
2309
+ _patch_rms_norm_module(vision_block.norm2)
2310
+ if swiglu:
2311
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
2312
+
2313
+ if text_model is not None:
2314
+ if rms_norm:
2315
+ _patch_rms_norm_module(text_model.norm)
2316
+ for decoder_layer in text_model.layers:
2317
+ if swiglu:
2318
+ decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2319
+ if rms_norm:
2320
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2321
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2322
+ if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
2323
+ experts = getattr(decoder_layer.mlp, "experts", None)
2324
+ if experts is not None:
2325
+ for expert in experts:
2326
+ _patch_swiglu_module(expert, LigerSwiGLUMLP)
2327
+ if decoder_layer.mlp.shared_experts is not None:
2328
+ _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
2329
+ for decoder_layer in text_model.layers:
2330
+ if rms_norm:
2331
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2332
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2333
+
2334
+
2335
+ def apply_liger_kernel_to_internvl(
2336
+ cross_entropy: bool = False,
2337
+ fused_linear_cross_entropy: bool = True,
2338
+ rms_norm: bool = True,
2339
+ layer_norm: bool = True,
2340
+ model: Optional[PreTrainedModel] = None,
2341
+ **kwargs,
2342
+ ) -> None:
2343
+ """
2344
+ Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
2345
+ Because InternVL wraps a separate language model, the model instance must be passed in so that Liger-Kernel's patches can also be applied to the models connected to InternVL.
2346
+ However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
2347
+ NOTE: InternVL is not available in transformers<4.52.1
2348
+
2349
+ Args:
2350
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2351
+ fused_linear_cross_entropy (bool):
2352
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2353
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2354
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2355
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2356
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
2357
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2358
+ loaded. Default is None.
2359
+ """
2360
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2361
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2362
+ )
2363
+ import torch.nn as torch_nn
2364
+
2365
+ from transformers.models.internvl import modeling_internvl
2366
+ from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
2367
+ from transformers.models.internvl.modeling_internvl import InternVLModel
2368
+ from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
2369
+ from transformers.models.internvl.modeling_internvl import InternVLVisionModel
2370
+ from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
2371
+
2372
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
2373
+ from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
2374
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
2375
+
2376
+ if layer_norm and model is None:
2377
+ modeling_internvl.nn.LayerNorm = LigerLayerNorm
2378
+
2379
+ if cross_entropy:
2380
+ logger.info("Apply liger cross entropy")
2381
+
2382
+ from transformers.loss.loss_utils import nn
879
2383
 
880
- def apply_liger_kernel_to_qwen2_vl(
881
- rope: bool = True,
2384
+ nn.functional.cross_entropy = liger_cross_entropy
2385
+ if fused_linear_cross_entropy:
2386
+ modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
2387
+ if rms_norm:
2388
+ modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
2389
+
2390
+ if model is not None:
2391
+ # The model instance already exists, so we need to additionally patch the
2392
+ # instance variables that reference already-instantiated modules
2393
+ if isinstance(model, InternVLForConditionalGeneration):
2394
+ text_model = model.model.language_model
2395
+ vision_model: InternVLVisionModel = model.model.vision_tower
2396
+ elif isinstance(model, InternVLModel):
2397
+ text_model = model.language_model
2398
+ vision_model: InternVLVisionModel = model.vision_tower
2399
+ else:
2400
+ raise TypeError(
2401
+ f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
2402
+ )
2403
+
2404
+ text_model_name = model.config.text_config.model_type
2405
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
2406
+
2407
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
2408
+ if text_liger_fn:
2409
+ accept_params = inspect.signature(text_liger_fn).parameters
2410
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
2411
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
2412
+
2413
+ if remain_params:
2414
+ logger.warning(
2415
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
2416
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
2417
+ )
2418
+ text_kwargs["model"] = text_model
2419
+ text_liger_fn(**text_kwargs)
2420
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
2421
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
2422
+
2423
+ # Patch vision model RMSNorm layers
2424
+ if rms_norm:
2425
+ for encoder_layer in vision_model.encoder.layer:
2426
+ encoder_layer: InternVLVisionLayer
2427
+ if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
2428
+ _patch_rms_norm_module(encoder_layer.attention.q_norm)
2429
+ if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
2430
+ _patch_rms_norm_module(encoder_layer.attention.k_norm)
2431
+
2432
+ # Patch vision model LayerNorm layers
2433
+ if layer_norm:
2434
+ # Patch layernorm
2435
+ if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
2436
+ _patch_layer_norm_module(vision_model.layernorm)
2437
+
2438
+ # Patch encoder layers
2439
+ for encoder_layer in vision_model.encoder.layer:
2440
+ encoder_layer: InternVLVisionLayer
2441
+ if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
2442
+ _patch_layer_norm_module(encoder_layer.layernorm_before)
2443
+ if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
2444
+ _patch_layer_norm_module(encoder_layer.layernorm_after)
2445
+
2446
+
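The InternVL wrapper above hands the remaining kwargs to the connected language model's own `apply_liger_kernel_to_*` function, keeping only the parameters that function actually accepts. A self-contained toy sketch of that signature-filtering dispatch (all names are illustrative):

```python
import inspect


def apply_to_llama(rms_norm=True, swiglu=True, model=None):
    print(f"patching llama: rms_norm={rms_norm}, swiglu={swiglu}, model={model}")


MODEL_TYPE_TO_APPLY_FN = {"llama": apply_to_llama}  # illustrative registry

kwargs = {"rms_norm": True, "layer_norm": True, "cross_entropy": False}
fn = MODEL_TYPE_TO_APPLY_FN["llama"]
accepted = set(inspect.signature(fn).parameters)
ignored = set(kwargs) - accepted
if ignored:
    print(f"ignoring parameters not accepted by the target: {sorted(ignored)}")
fn(**{k: v for k, v in kwargs.items() if k in accepted}, model="<text_model>")
```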
2447
+ def apply_liger_kernel_to_smolvlm(
882
2448
  cross_entropy: bool = False,
883
2449
  fused_linear_cross_entropy: bool = True,
884
2450
  rms_norm: bool = True,
885
2451
  layer_norm: bool = True,
886
- swiglu: bool = True,
887
- model: PreTrainedModel = None,
2452
+ model: Optional[PreTrainedModel] = None,
2453
+ **kwargs,
888
2454
  ) -> None:
889
2455
  """
890
- Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
891
- NOTE: Qwen2-VL is not available in transformers<4.45.0
2456
+ Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
2457
+ Because SmolVLM wraps a separate language model, the model instance must be passed in so that Liger-Kernel's patches can also be applied to the models connected to SmolVLM.
2458
+ However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
2459
+ NOTE: SmolVLM is not available in transformers<4.50.0
892
2460
 
893
2461
  Args:
894
2462
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -898,7 +2466,6 @@ def apply_liger_kernel_to_qwen2_vl(
898
2466
  If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
899
2467
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
900
2468
  layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
901
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
902
2469
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
903
2470
  loaded. Default is None.
904
2471
  """
@@ -906,51 +2473,148 @@ def apply_liger_kernel_to_qwen2_vl(
906
2473
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
907
2474
  )
908
2475
 
909
- from transformers.models.qwen2_vl import modeling_qwen2_vl
910
- from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
2476
+ from transformers.models.smolvlm import modeling_smolvlm
2477
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
2478
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
2479
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
2480
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
911
2481
 
912
- from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
2482
+ from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
2483
+
2484
+ # Patch LayerNorm for vision model if model is not provided (pre-initialization)
2485
+ if layer_norm and model is None:
2486
+ modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
2487
+
2488
+ if cross_entropy:
2489
+ logger.info("Apply liger cross entropy")
2490
+
2491
+ from transformers.loss.loss_utils import nn
2492
+
2493
+ nn.functional.cross_entropy = liger_cross_entropy
2494
+ if fused_linear_cross_entropy:
2495
+ if model is not None:
2496
+ model.forward = MethodType(smolvlm_lce_forward, model)
2497
+ else:
2498
+ modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
2499
+ if rms_norm:
2500
+ modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
2501
+
2502
+ if model is not None:
2503
+ # The model instance already exists, so we need to additionally patch the
2504
+ # instance variables that reference already-instantiated modules
2505
+ if isinstance(model, SmolVLMForConditionalGeneration):
2506
+ text_model = model.model.text_model
2507
+ vision_model: SmolVLMVisionTransformer = model.model.vision_model
2508
+ elif isinstance(model, SmolVLMModel):
2509
+ text_model = model.text_model
2510
+ vision_model: SmolVLMVisionTransformer = model.vision_model
2511
+ else:
2512
+ raise TypeError(
2513
+ f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
2514
+ )
2515
+
2516
+ text_model_name = model.config.text_config.model_type
2517
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
2518
+
2519
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
2520
+ if text_liger_fn:
2521
+ accept_params = inspect.signature(text_liger_fn).parameters
2522
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
2523
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
2524
+
2525
+ if remain_params:
2526
+ logger.warning(
2527
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
2528
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
2529
+ )
2530
+ text_kwargs["model"] = text_model
2531
+ text_liger_fn(**text_kwargs)
2532
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
2533
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
2534
+
2535
+ # Patch vision model LayerNorm layers
2536
+ if layer_norm:
2537
+ # Patch post_layernorm
2538
+ _patch_layer_norm_module(vision_model.post_layernorm)
2539
+
2540
+ # Patch encoder layers
2541
+ for encoder_layer in vision_model.encoder.layers:
2542
+ encoder_layer: SmolVLMEncoderLayer
2543
+ _patch_layer_norm_module(encoder_layer.layer_norm1)
2544
+ _patch_layer_norm_module(encoder_layer.layer_norm2)
2545
+
2546
+
2547
+ def apply_liger_kernel_to_falcon_h1(
2548
+ rope: bool = True,
2549
+ cross_entropy: bool = False,
2550
+ fused_linear_cross_entropy: bool = True,
2551
+ rms_norm: bool = True,
2552
+ swiglu: bool = False,
2553
+ model: PreTrainedModel = None,
2554
+ ) -> None:
2555
+ """
2556
+ Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models
2557
+ Args:
2558
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2559
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2560
+ fused_linear_cross_entropy (bool):
2561
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2562
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2563
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2564
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2565
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False. Note: LigerSwiGLUMLP is not currently available for Falcon-H1, so this flag has no effect.
2566
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2567
+ loaded. Default is None.
2568
+ """
2569
+
2570
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2571
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2572
+ )
2573
+
2574
+ from transformers.models.falcon_h1 import modeling_falcon_h1
2575
+ from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
913
2576
 
914
2577
  if rope:
915
- modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
2578
+ logger.info("Apply liger rotary pos emb.")
2579
+ modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
916
2580
  if rms_norm:
917
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
918
- modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
919
- if layer_norm:
920
- modeling_qwen2_vl.LayerNorm = LigerLayerNorm
2581
+ logger.info("Apply liger RMSNorm")
2582
+ modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
2583
+ if swiglu:
2584
+ logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
2585
+
921
2586
  if cross_entropy:
922
- modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
2587
+ logger.info("Apply liger cross entropy")
2588
+ from transformers.loss.loss_utils import nn
2589
+
2590
+ nn.functional.cross_entropy = liger_cross_entropy
2591
+
923
2592
  if fused_linear_cross_entropy:
924
- modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
925
- if swiglu:
926
- modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
2593
+ if model is not None:
2594
+ model.forward = MethodType(falcon_h1_lce_forward, model)
2595
+ else:
2596
+ modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
927
2597
 
928
2598
  if model is not None:
929
2599
  # The model instance already exists, so we need to additionally patch the
930
- # instance variables that reference already-instantiated modules
2600
+ # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
931
2601
 
932
2602
  # get the base model from the model instance
933
- base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)
934
-
935
- if hasattr(model, "visual"):
936
- # Patch Qwen2VisionTransformerPretrainedModel
937
- for vision_block in model.visual.blocks:
938
- if layer_norm:
939
- _patch_layer_norm_module(vision_block.norm1)
940
- _patch_layer_norm_module(vision_block.norm2)
2603
+ base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
941
2604
 
942
2605
  if rms_norm:
943
- _patch_rms_norm_module(base_model.norm)
2606
+ _patch_rms_norm_module(base_model.final_layernorm)
2607
+
944
2608
  for decoder_layer in base_model.layers:
945
2609
  if swiglu:
946
2610
  _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
947
2611
  if rms_norm:
948
2612
  _patch_rms_norm_module(decoder_layer.input_layernorm)
949
- _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2613
+ _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
950
2614
 
951
2615
 
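A hedged sketch of the two ways the Falcon-H1 patch above can be applied; the checkpoint id is a placeholder and the import path is an assumption.

```python
# Illustrative only; checkpoint id is a placeholder, import path is assumed.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1

# (a) Patch the module-level classes before the model is constructed.
apply_liger_kernel_to_falcon_h1(rope=True, rms_norm=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-1B-Base")  # placeholder

# (b) Or patch an existing instance: forward is rebound with MethodType and the
# already-instantiated norms (final_layernorm, input_layernorm, pre_ff_layernorm)
# are patched in place.
apply_liger_kernel_to_falcon_h1(model=model)
```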
952
- def apply_liger_kernel_to_qwen2_5_vl(
953
- rope: bool = True,
2616
+ def apply_liger_kernel_to_qwen3_next(
2617
+ rope: bool = False,
954
2618
  cross_entropy: bool = False,
955
2619
  fused_linear_cross_entropy: bool = True,
956
2620
  rms_norm: bool = True,
@@ -958,17 +2622,17 @@ def apply_liger_kernel_to_qwen2_5_vl(
958
2622
  model: PreTrainedModel = None,
959
2623
  ) -> None:
960
2624
  """
961
- Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
962
- NOTE: Qwen2.5-VL is not available in transformers<4.48.2
2625
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-Next models.
963
2626
 
964
2627
  Args:
2628
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
965
2629
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
966
2630
  fused_linear_cross_entropy (bool):
967
2631
  Whether to apply Liger's fused linear cross entropy loss. Default is True.
968
2632
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
969
2633
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
970
2634
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
971
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
2635
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
972
2636
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
973
2637
  loaded. Default is None.
974
2638
  """
@@ -976,47 +2640,129 @@ def apply_liger_kernel_to_qwen2_5_vl(
976
2640
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
977
2641
  )
978
2642
 
979
- from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
980
- from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
2643
+ from transformers.models.qwen3_next import modeling_qwen3_next
2644
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
2645
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
2646
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
2647
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
981
2648
 
982
- from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
2649
+ from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
2650
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
2651
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
983
2652
 
984
2653
  if rope:
985
- modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
2654
+ # It may encounter a NaN issue
2655
+ # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
2656
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
986
2657
  if rms_norm:
987
- modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
2658
+ modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
988
2659
  if cross_entropy:
989
- modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
2660
+ from transformers.loss.loss_utils import nn
2661
+
2662
+ nn.functional.cross_entropy = liger_cross_entropy
990
2663
  if fused_linear_cross_entropy:
991
- modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
2664
+ if model is not None:
2665
+ if isinstance(model, Qwen3NextForCausalLM):
2666
+ model.forward = MethodType(qwen3_next_lce_forward, model)
2667
+ else:
2668
+ raise TypeError(
2669
+ f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
2670
+ )
2671
+ else:
2672
+ modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
992
2673
  if swiglu:
993
- modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
2674
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
2675
+ modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
994
2676
 
995
2677
  if model is not None:
996
2678
  # The model instance already exists, so we need to additionally patch the
997
2679
  # instance variables that reference already-instantiated modules
2680
+ if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
2681
+ base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
2682
+ else:
2683
+ raise TypeError(
2684
+ f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
2685
+ )
998
2686
 
999
- # get the base model from the model instance
1000
- base_model: Qwen2_5_VLModel = getattr(model, model.base_model_prefix, model)
2687
+ if rms_norm:
2688
+ _patch_rms_norm_module(base_model.norm)
1001
2689
 
1002
- if hasattr(model, "visual"):
1003
- # Patch Qwen2_5_VisionTransformerPretrainedModel
1004
- for vision_block in model.visual.blocks:
1005
- if rms_norm:
1006
- _patch_rms_norm_module(vision_block.norm1)
1007
- _patch_rms_norm_module(vision_block.norm2)
2690
+ for decoder_layer in base_model.layers:
2691
+ if rms_norm:
2692
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2693
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2694
+
2695
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
2696
+ if swiglu:
2697
+ if isinstance(decoder_layer.mlp, Qwen3NextMLP):
2698
+ _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
2699
+ if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
2700
+ _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
2701
+ experts = getattr(decoder_layer.mlp, "experts", None)
2702
+ if experts is not None:
2703
+ for expert in experts:
2704
+ _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
2705
+
2706
+
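For reference, the gated MLP that the Liger SwiGLU modules above replace computes the standard form sketched below; this is a generic reference implementation, not the fused Triton kernel itself.

```python
import torch.nn.functional as F

def swiglu_mlp_reference(x, gate_proj, up_proj, down_proj):
    # Standard SwiGLU MLP: down( silu(gate(x)) * up(x) ).
    # Liger's SwiGLU modules fuse the SiLU-and-multiply step into one kernel
    # while reusing the module's existing gate/up/down projection weights.
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))
```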
2707
+ def apply_liger_kernel_to_hunyuan_v1_dense(
2708
+ rope: bool = True,
2709
+ cross_entropy: bool = False,
2710
+ fused_linear_cross_entropy: bool = True,
2711
+ rms_norm: bool = True,
2712
+ swiglu: bool = True,
2713
+ model: PreTrainedModel = None,
2714
+ ) -> None:
2715
+ """
2716
+ Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
2717
+ """
2718
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2719
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2720
+ )
2721
+
2722
+ from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
2723
+ from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
2724
+
2725
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
2726
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
2727
+
2728
+ if rope:
2729
+ modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
2730
+
2731
+ if rms_norm:
2732
+ modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
2733
+
2734
+ if cross_entropy:
2735
+ from transformers.loss.loss_utils import nn
2736
+
2737
+ nn.functional.cross_entropy = liger_cross_entropy
2738
+
2739
+ if fused_linear_cross_entropy:
2740
+ if model is not None:
2741
+ model.forward = MethodType(hunyuan_v1_lce_forward, model)
2742
+ else:
2743
+ modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
2744
+
2745
+ if swiglu:
2746
+ modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
2747
+
2748
+ if model is not None:
2749
+ # The model instance already exists, so we need to additionally patch the
2750
+ # instance variables that reference already-instantiated modules
2751
+
2752
+ # get the base model from the model instance
2753
+ base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
1008
2754
 
1009
2755
  if rms_norm:
1010
2756
  _patch_rms_norm_module(base_model.norm)
1011
2757
  for decoder_layer in base_model.layers:
1012
2758
  if swiglu:
1013
- _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2759
+ _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
1014
2760
  if rms_norm:
1015
2761
  _patch_rms_norm_module(decoder_layer.input_layernorm)
1016
2762
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1017
2763
 
1018
2764
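A hedged note on the `cross_entropy` branch used by several of the patches above: the `nn` imported from `transformers.loss.loss_utils` is `torch.nn`, so the assignment rebinds the cross-entropy function that transformers' loss utilities resolve at call time. The Liger import path below is an assumption.

```python
# Illustrative only: how the cross_entropy=True branch takes effect.
from transformers.loss.loss_utils import nn  # this `nn` is torch.nn

from liger_kernel.transformers.functional import liger_cross_entropy  # import path assumed

# Rebinds torch.nn.functional.cross_entropy, so any loss computation that looks
# the function up through this attribute (as transformers' loss_utils does) will
# call Liger's implementation instead.
nn.functional.cross_entropy = liger_cross_entropy
```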
 
1019
- def apply_liger_kernel_to_phi3(
2765
+ def apply_liger_kernel_to_hunyuan_v1_moe(
1020
2766
  rope: bool = True,
1021
2767
  cross_entropy: bool = False,
1022
2768
  fused_linear_cross_entropy: bool = True,
@@ -1025,67 +2771,57 @@ def apply_liger_kernel_to_phi3(
1025
2771
  model: PreTrainedModel = None,
1026
2772
  ) -> None:
1027
2773
  """
1028
- Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
1029
-
1030
- Args:
1031
- rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1032
- cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1033
- fused_linear_cross_entropy (bool):
1034
- Whether to apply Liger's fused linear cross entropy loss. Default is True.
1035
- `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1036
- If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
1037
- rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1038
- swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
1039
- model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1040
- loaded. Default is None.
2774
+ Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.
1041
2775
  """
1042
2776
  assert not (cross_entropy and fused_linear_cross_entropy), (
1043
2777
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
1044
2778
  )
1045
2779
 
1046
- from transformers.models.phi3 import modeling_phi3
1047
- from transformers.models.phi3.modeling_phi3 import Phi3Model
2780
+ from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
2781
+ from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
2782
+
2783
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
2784
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
1048
2785
 
1049
2786
  if rope:
1050
- modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
2787
+ modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
2788
+
1051
2789
  if rms_norm:
1052
- modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
1053
- if swiglu:
1054
- modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
2790
+ modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
2791
+
1055
2792
  if cross_entropy:
1056
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1057
- from transformers.loss.loss_utils import nn
2793
+ from transformers.loss.loss_utils import nn
2794
+
2795
+ nn.functional.cross_entropy = liger_cross_entropy
1058
2796
 
1059
- nn.functional.cross_entropy = liger_cross_entropy
1060
- else:
1061
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1062
- modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
1063
2797
  if fused_linear_cross_entropy:
1064
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1065
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1066
- else: # if version < 4.46.1
1067
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1068
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
2798
+ if model is not None:
2799
+ model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
2800
+ else:
2801
+ modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
2802
+
2803
+ if swiglu:
2804
+ modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
1069
2805
 
1070
2806
  if model is not None:
1071
2807
  # The model instance already exists, so we need to additionally patch the
1072
2808
  # instance variables that reference already-instantiated modules
1073
2809
 
1074
2810
  # get the base model from the model instance
1075
- base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
2811
+ base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
1076
2812
 
1077
2813
  if rms_norm:
1078
2814
  _patch_rms_norm_module(base_model.norm)
1079
-
1080
2815
  for decoder_layer in base_model.layers:
1081
2816
  if swiglu:
1082
- _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
2817
+ for mlp_expert in decoder_layer.mlp.experts:
2818
+ _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
1083
2819
  if rms_norm:
1084
2820
  _patch_rms_norm_module(decoder_layer.input_layernorm)
1085
2821
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1086
2822
 
1087
2823
 
1088
- def apply_liger_kernel_to_olmo2(
2824
+ def apply_liger_kernel_to_exaone4(
1089
2825
  rope: bool = True,
1090
2826
  cross_entropy: bool = False,
1091
2827
  fused_linear_cross_entropy: bool = True,
@@ -1094,7 +2830,7 @@ def apply_liger_kernel_to_olmo2(
1094
2830
  model: PreTrainedModel = None,
1095
2831
  ) -> None:
1096
2832
  """
1097
- Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
2833
+ Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models.
1098
2834
 
1099
2835
  Args:
1100
2836
  rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -1104,7 +2840,7 @@ def apply_liger_kernel_to_olmo2(
1104
2840
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1105
2841
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
1106
2842
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1107
- swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
2843
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1108
2844
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1109
2845
  loaded. Default is None.
1110
2846
  """
@@ -1112,47 +2848,70 @@ def apply_liger_kernel_to_olmo2(
1112
2848
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
1113
2849
  )
1114
2850
 
1115
- from transformers.models.olmo2 import modeling_olmo2
1116
- from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
2851
+ from transformers.models.exaone4 import modeling_exaone4
2852
+ from transformers.models.exaone4.modeling_exaone4 import Exaone4Model
1117
2853
 
1118
- from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
2854
+ from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward
1119
2855
 
1120
2856
  if rope:
1121
- modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
2857
+ modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb
2858
+
1122
2859
  if rms_norm:
1123
- modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
1124
- if swiglu:
1125
- modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
2860
+ # EXAONE4 requires in_place=False to avoid gradient issues
2861
+ class Exaone4LigerRMSNorm(LigerRMSNorm):
2862
+ def __init__(self, hidden_size, eps=1e-6, **kwargs):
2863
+ super().__init__(hidden_size, eps, **kwargs)
2864
+ self.in_place = False
2865
+
2866
+ modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm
2867
+
1126
2868
  if cross_entropy:
1127
2869
  from transformers.loss.loss_utils import nn
1128
2870
 
1129
2871
  nn.functional.cross_entropy = liger_cross_entropy
2872
+
1130
2873
  if fused_linear_cross_entropy:
1131
- modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
2874
+ if model is not None:
2875
+ model.forward = MethodType(exaone4_lce_forward, model)
2876
+ else:
2877
+ modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward
2878
+
2879
+ if swiglu:
2880
+ modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP
1132
2881
 
1133
2882
  if model is not None:
1134
2883
  # The model instance already exists, so we need to additionally patch the
1135
2884
  # instance variables that reference already-instantiated modules
1136
2885
 
1137
2886
  # get the base model from the model instance
1138
- base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
2887
+ base_model: Exaone4Model = getattr(model, model.base_model_prefix, model)
1139
2888
 
1140
2889
  if rms_norm:
1141
- _patch_rms_norm_module(base_model.norm)
1142
-
2890
+ _patch_rms_norm_module(base_model.norm, in_place=False)
1143
2891
  for decoder_layer in base_model.layers:
1144
2892
  if swiglu:
1145
- _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2893
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
1146
2894
  if rms_norm:
1147
2895
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
1148
2896
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2897
+ _patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False)
2898
+ _patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False)
1149
2899
 
1150
2900
 
1151
2901
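A minimal sketch of what instance-level forward rebinding (as in `_bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)` above) amounts to; the helper shown here is illustrative, not the library's actual implementation.

```python
from types import MethodType

def bind_forward(module, new_forward):
    # Rebind `forward` on this single instance. The module keeps its existing
    # parameters (e.g. gate_proj/up_proj/down_proj); only the computation path
    # taken by module(...) changes.
    module.forward = MethodType(new_forward, module)
```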
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
1152
2902
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
1153
2903
  "gemma": apply_liger_kernel_to_gemma,
1154
2904
  "gemma2": apply_liger_kernel_to_gemma2,
2905
+ "gemma3_text": apply_liger_kernel_to_gemma3_text,
2906
+ "gemma3": apply_liger_kernel_to_gemma3,
2907
+ "glm4": apply_liger_kernel_to_glm4,
2908
+ "glm4v": apply_liger_kernel_to_glm4v,
2909
+ "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
2910
+ "gpt_oss": apply_liger_kernel_to_gpt_oss,
2911
+ "internvl": apply_liger_kernel_to_internvl,
1155
2912
  "llama": apply_liger_kernel_to_llama,
2913
+ "llama4_text": apply_liger_kernel_to_llama4,
2914
+ "llama4": apply_liger_kernel_to_llama4,
1156
2915
  "llava": apply_liger_kernel_to_llava,
1157
2916
  "granite": apply_liger_kernel_to_granite,
1158
2917
  "mllama": apply_liger_kernel_to_mllama,
@@ -1160,11 +2919,27 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
1160
2919
  "mistral": apply_liger_kernel_to_mistral,
1161
2920
  "mixtral": apply_liger_kernel_to_mixtral,
1162
2921
  "olmo2": apply_liger_kernel_to_olmo2,
2922
+ "olmo3": apply_liger_kernel_to_olmo3,
1163
2923
  "qwen2": apply_liger_kernel_to_qwen2,
2924
+ "qwen3": apply_liger_kernel_to_qwen3,
2925
+ "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
1164
2926
  "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
2927
+ "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
1165
2928
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
2929
+ "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
2930
+ "qwen3_next": apply_liger_kernel_to_qwen3_next,
2931
+ "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
2932
+ "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
2933
+ "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
2934
+ "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
2935
+ "smollm3": apply_liger_kernel_to_smollm3,
1166
2936
  "phi3": apply_liger_kernel_to_phi3,
1167
2937
  "paligemma": apply_liger_kernel_to_paligemma,
2938
+ "falcon_h1": apply_liger_kernel_to_falcon_h1,
2939
+ "smolvlm": apply_liger_kernel_to_smolvlm,
2940
+ "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
2941
+ "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
2942
+ "exaone4": apply_liger_kernel_to_exaone4,
1168
2943
  }
1169
2944
 
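Hedged example of consuming the registry above: the model's `config.model_type` selects the patch function, mirroring what `_apply_liger_kernel_to_instance` below does automatically. The checkpoint id is a placeholder and the registry import path is an assumption.

```python
# Illustrative only; checkpoint id is a placeholder.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN  # import path assumed

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model.config.model_type]  # e.g. apply_liger_kernel_to_llama
apply_fn(model=model, rms_norm=True, swiglu=True)
```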
1170
2945
 
@@ -1222,7 +2997,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
1222
2997
  return
1223
2998
 
1224
2999
  apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
1225
-
1226
3000
  apply_fn_signature = inspect.signature(apply_fn)
1227
3001
 
1228
3002
  # Filter out the keyword arguments that are not supported by the apply function