liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126) hide show
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,219 @@
1
+ import inspect
2
+ import logging
3
+
4
+ from functools import partial
5
+ from types import MethodType
6
+ from typing import Callable
7
+ from typing import Optional
8
+
9
+ import transformers
10
+
11
+ from packaging import version
12
+ from transformers import PreTrainedModel
13
+
1
14
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
15
+ from liger_kernel.transformers.functional import liger_cross_entropy
2
16
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
3
- from liger_kernel.transformers.model.llama import lce_forward
17
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
18
+ from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
19
+ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
20
+ from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
21
+ from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
22
+ from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
23
+ from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
24
+ from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
25
+ from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
26
+ from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
27
+ from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
28
+ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
29
+ from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
30
+ from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
31
+ from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
32
+ from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
33
+ from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
34
+ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
35
+ from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
4
36
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
5
37
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
6
- from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
38
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
39
+ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
40
+ from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
41
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
42
+
43
+ try:
44
+ import peft
45
+
46
+ PEFT_AVAILABLE = True
47
+ except ImportError:
48
+ PEFT_AVAILABLE = False
49
+
50
+ transformer_version = version.parse(transformers.__version__)
51
+
52
+ logger = logging.getLogger(__name__)
53
+ SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
54
+ TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
55
+
56
+
57
+ def _bind_method_to_module(module, method_name: str, new_method: Callable):
58
+ # Binds a new method to a module instance so that self is passed as the first argument
59
+ module.__dict__[method_name] = new_method.__get__(module, module.__class__)
60
+
61
+
62
+ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
63
+ # Check if the module is a PEFT ModulesToSaveWrapper
64
+ # If it is, we need to patch the modules_to_save.default and original_modules
65
+ if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
66
+ module.modules_to_save.default.offset = offset
67
+ module.modules_to_save.default.casting_mode = casting_mode
68
+ module.modules_to_save.default.variance_epsilon = (
69
+ getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
70
+ )
71
+ module.modules_to_save.default.in_place = in_place
72
+ module.modules_to_save.default.row_mode = row_mode
73
+ module.original_module.offset = offset
74
+ module.original_module.casting_mode = casting_mode
75
+ module.original_module.variance_epsilon = (
76
+ getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
77
+ )
78
+ module.original_module.in_place = in_place
79
+ module.original_module.row_mode = row_mode
80
+ _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
81
+ _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
82
+ _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
83
+ _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
84
+ _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
85
+ _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
86
+ else:
87
+ module.offset = offset
88
+ module.casting_mode = casting_mode
89
+ module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
90
+ module.in_place = in_place
91
+ module.row_mode = row_mode
92
+ _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
93
+ _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
94
+ _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)
95
+
96
+
97
+ def _patch_layer_norm_module(module, eps=1e-6):
98
+ # Check if the module is a PEFT ModulesToSaveWrapper
99
+ # If it is, we need to patch the modules_to_save.default and original_modules
100
+ if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
101
+ module.hidden_size = module.normalized_shape
102
+ _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
103
+ _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
104
+ module.modules_to_save.default.variance_epsilon = (
105
+ getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
106
+ )
107
+ module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
108
+ module, "normalized_shape", None
109
+ )
110
+ module.original_module.variance_epsilon = (
111
+ getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
112
+ )
113
+ module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
114
+ module, "normalized_shape", None
115
+ )
116
+ _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
117
+ _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
118
+ _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
119
+ _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
120
+ _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
121
+ _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
122
+ else:
123
+ module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
124
+ module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
125
+ _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
126
+ _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
127
+ _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
128
+
129
+
130
+ def _patch_swiglu_module(module, liger_module):
131
+ _bind_method_to_module(module, "forward", liger_module.forward)
132
+ _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)
133
+
134
+
135
+ def _patch_geglu_module(module):
136
+ _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
137
+ _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
138
+
139
+
140
+ def apply_liger_kernel_to_granite(
141
+ rope: bool = True,
142
+ cross_entropy: bool = True,
143
+ fused_linear_cross_entropy: bool = False,
144
+ rms_norm: bool = True,
145
+ swiglu: bool = True,
146
+ model: PreTrainedModel = None,
147
+ ) -> None:
148
+ """
149
+ Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models
150
+
151
+ Args:
152
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
153
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
154
+ fused_linear_cross_entropy (bool):
155
+ Whether to apply Liger's fused linear cross entropy loss. Default is False.
156
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
157
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
158
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
159
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
160
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
161
+ loaded. Default is None.
162
+
163
+
164
+
165
+ Debugging notes:
166
+ If LigerSwiGLUMLP is OK for Llama, it should be fine for Granite, but it's not.
167
+ """
168
+
169
+ assert not (cross_entropy and fused_linear_cross_entropy), (
170
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
171
+ )
172
+
173
+ from transformers.models.granite import modeling_granite
174
+ from transformers.models.granite.modeling_granite import GraniteModel
175
+
176
+ if swiglu:
177
+ modeling_granite.GraniteMLP = LigerSwiGLUMLP
178
+
179
+ if rms_norm:
180
+ modeling_granite.GraniteRMSNorm = LigerRMSNorm
181
+
182
+ if rope:
183
+ modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb
184
+
185
+ if cross_entropy:
186
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
187
+ from transformers.loss.loss_utils import nn
188
+
189
+ nn.functional.cross_entropy = liger_cross_entropy
190
+ else:
191
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
192
+ modeling_granite.CrossEntropyLoss = LigerCrossEntropyLoss
193
+
194
+ if fused_linear_cross_entropy:
195
+ raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")
196
+ # NOTE: Granite model `GraniteForCausalLM.forward` scales logits each
197
+ # call, so we can't sidestep logit materialization. A bit more work
198
+ # would be needed to add a scaling term to the `LigerFusedLinearCrossEntropyFunction`
199
+ # for the logit output.
200
+
201
+ if model is not None:
202
+ # The model instance already exists, so we need to additionally patch the
203
+ # instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP)
204
+
205
+ # get the base model from the model instance
206
+ base_model: GraniteModel = getattr(model, model.base_model_prefix, model)
207
+
208
+ if rms_norm:
209
+ _patch_rms_norm_module(base_model.norm)
210
+
211
+ for decoder_layer in base_model.layers:
212
+ if swiglu:
213
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
214
+ if rms_norm:
215
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
216
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
7
217
 
8
218
 
9
219
  def apply_liger_kernel_to_llama(
@@ -12,6 +222,7 @@ def apply_liger_kernel_to_llama(
12
222
  fused_linear_cross_entropy: bool = True,
13
223
  rms_norm: bool = True,
14
224
  swiglu: bool = True,
225
+ model: PreTrainedModel = None,
15
226
  ) -> None:
16
227
  """
17
228
  Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
@@ -20,18 +231,21 @@ def apply_liger_kernel_to_llama(
20
231
  rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
21
232
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
22
233
  fused_linear_cross_entropy (bool):
23
- Whether to apply Liger's fused lienar cross entropy loss. Default is True.
234
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
24
235
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
25
236
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
26
237
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
27
238
  swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
239
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
240
+ loaded. Default is None.
28
241
  """
29
242
 
30
- assert not (
31
- cross_entropy and fused_linear_cross_entropy
32
- ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
243
+ assert not (cross_entropy and fused_linear_cross_entropy), (
244
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
245
+ )
33
246
 
34
247
  from transformers.models.llama import modeling_llama
248
+ from transformers.models.llama.modeling_llama import LlamaModel
35
249
 
36
250
  if rope:
37
251
  modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -39,29 +253,439 @@ def apply_liger_kernel_to_llama(
39
253
  modeling_llama.LlamaRMSNorm = LigerRMSNorm
40
254
  if swiglu:
41
255
  modeling_llama.LlamaMLP = LigerSwiGLUMLP
256
+
257
+ if cross_entropy:
258
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
259
+ from transformers.loss.loss_utils import nn
260
+
261
+ nn.functional.cross_entropy = liger_cross_entropy
262
+ else:
263
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
264
+ modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
265
+
266
+ if fused_linear_cross_entropy:
267
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
268
+ if model is not None:
269
+ model.forward = MethodType(llama_lce_forward, model)
270
+ else:
271
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
272
+ else: # if version < 4.46.1
273
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
274
+ if model is not None:
275
+ model.forward = MethodType(llama_lce_forward_deprecated, model)
276
+ else:
277
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
278
+
279
+ if model is not None:
280
+ # The model instance already exists, so we need to additionally patch the
281
+ # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
282
+
283
+ # get the base model from the model instance
284
+ base_model: LlamaModel = getattr(model, model.base_model_prefix, model)
285
+
286
+ if rms_norm:
287
+ _patch_rms_norm_module(base_model.norm)
288
+
289
+ for decoder_layer in base_model.layers:
290
+ if swiglu:
291
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
292
+ if rms_norm:
293
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
294
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
295
+
296
+
297
+ def apply_liger_kernel_to_smollm3(
298
+ rope: bool = True,
299
+ cross_entropy: bool = False,
300
+ fused_linear_cross_entropy: bool = True,
301
+ rms_norm: bool = True,
302
+ swiglu: bool = True,
303
+ model: PreTrainedModel = None,
304
+ ) -> None:
305
+ """
306
+ Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
307
+
308
+ Args:
309
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
310
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
311
+ fused_linear_cross_entropy (bool):
312
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
313
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
314
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
315
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
316
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
317
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
318
+ loaded. Default is None.
319
+ """
320
+
321
+ assert not (cross_entropy and fused_linear_cross_entropy), (
322
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
323
+ )
324
+
325
+ from transformers.models.smollm3 import modeling_smollm3
326
+ from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
327
+
328
+ if rope:
329
+ modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
330
+ if rms_norm:
331
+ modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
332
+ if swiglu:
333
+ modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
334
+
335
+ if cross_entropy:
336
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
337
+ from transformers.loss.loss_utils import nn
338
+
339
+ nn.functional.cross_entropy = liger_cross_entropy
340
+ else:
341
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
342
+ modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
343
+
344
+ if fused_linear_cross_entropy:
345
+ if model is not None:
346
+ model.forward = MethodType(smollm3_lce_forward, model)
347
+ else:
348
+ modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
349
+
350
+ if model is not None:
351
+ # The model instance already exists, so we need to additionally patch the
352
+ # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
353
+
354
+ # get the base model from the model instance
355
+ base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
356
+
357
+ if rms_norm:
358
+ _patch_rms_norm_module(base_model.norm)
359
+
360
+ for decoder_layer in base_model.layers:
361
+ if swiglu:
362
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
363
+ if rms_norm:
364
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
365
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
366
+
367
+
368
+ def apply_liger_kernel_to_llava(
369
+ cross_entropy: bool = False,
370
+ fused_linear_cross_entropy: bool = True,
371
+ model: PreTrainedModel = None,
372
+ **kwargs,
373
+ ) -> None:
374
+ """
375
+ Apply Liger kernels to replace original implementation in HuggingFace Llava models.
376
+ Due to the characteristics of LlaVa, the model must be passed to apply Liger-Kernel's patch to other models connected to LLaVa.
377
+ However, if an LM not supported by Liger-Kernel is connected to LLaVa, unexpected side effects may occur.
378
+ NOTE: Llava is not available in transformers<4.36.0
379
+
380
+ Args:
381
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
382
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
383
+ fused_linear_cross_entropy (bool):
384
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
385
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
386
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
387
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
388
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
389
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
390
+ loaded. Default is None.
391
+ """
392
+ assert not (cross_entropy and fused_linear_cross_entropy), (
393
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
394
+ )
395
+
396
+ from transformers.models.llava import modeling_llava
397
+
398
+ if cross_entropy:
399
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
400
+ modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
401
+ if fused_linear_cross_entropy:
402
+ if transformer_version >= version.parse("4.52.0"):
403
+ if model is not None:
404
+ model.forward = MethodType(llava_lce_forward, model)
405
+ else:
406
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
407
+ elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
408
+ if model is not None:
409
+ model.forward = MethodType(llava_lce_forward_deprecated, model)
410
+ else:
411
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
412
+ else: # if version < 4.49.0
413
+ logger.warning(
414
+ "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
415
+ )
416
+
417
+ if model is not None:
418
+ text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
419
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
420
+ vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
421
+
422
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
423
+ if text_liger_fn:
424
+ accept_params = inspect.signature(text_liger_fn).parameters
425
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
426
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
427
+
428
+ if remain_params:
429
+ logger.warning(
430
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
431
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
432
+ )
433
+ text_kwargs["model"] = model.language_model
434
+ text_liger_fn(**text_kwargs)
435
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
436
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
437
+
438
+ if vision_liger_fn:
439
+ accept_params = inspect.signature(vision_liger_fn).parameters
440
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
441
+ vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
442
+
443
+ if remain_params:
444
+ logger.warning(
445
+ f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
446
+ f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
447
+ )
448
+ vision_kwargs["model"] = model.vision_tower
449
+ vision_liger_fn(**vision_kwargs)
450
+ elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
451
+ logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
452
+
453
+
454
+ def apply_liger_kernel_to_llama4(
455
+ rope: bool = True,
456
+ cross_entropy: bool = False,
457
+ fused_linear_cross_entropy: bool = True,
458
+ rms_norm: bool = True,
459
+ swiglu: bool = True,
460
+ model: PreTrainedModel = None,
461
+ layer_norm: bool = True,
462
+ ) -> None:
463
+ """
464
+ Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
465
+
466
+ Args:
467
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
468
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
469
+ fused_linear_cross_entropy (bool):
470
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
471
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
472
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
473
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
474
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
475
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
476
+ loaded. Default is None.
477
+ """
478
+ assert not (cross_entropy and fused_linear_cross_entropy), (
479
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
480
+ )
481
+
482
+ from transformers.models.llama4 import modeling_llama4
483
+ from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
484
+ from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
485
+ from transformers.models.llama4.modeling_llama4 import Llama4TextModel
486
+ from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
487
+
488
+ from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
489
+
490
+ if rope:
491
+ from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
492
+
493
+ apply_liger_llama4_rope_full(modeling_llama4)
494
+ if rms_norm:
495
+ modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
496
+ if swiglu:
497
+ modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
498
+
499
+ if cross_entropy:
500
+ modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
501
+
502
+ if fused_linear_cross_entropy:
503
+ modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
504
+
505
+ if model is not None:
506
+ # The model instance already exists, so we need to additionally patch the
507
+ # instance variables that reference already-instantiated modules
508
+ if isinstance(model, Llama4ForConditionalGeneration):
509
+ language_model: Llama4ForCausalLM = model.language_model
510
+ vision_model: Llama4VisionModel = model.vision_model
511
+ text_model: Llama4TextModel = language_model.model
512
+ elif isinstance(model, Llama4ForCausalLM):
513
+ text_model = model.model
514
+ vision_model = None
515
+ elif isinstance(model, Llama4TextModel):
516
+ text_model = model
517
+ vision_model = None
518
+
519
+ else:
520
+ raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
521
+
522
+ if text_model:
523
+ if rms_norm:
524
+ _patch_rms_norm_module(text_model.norm)
525
+ for decoder_layer in text_model.layers:
526
+ if swiglu:
527
+ if decoder_layer.is_moe_layer:
528
+ _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
529
+ else:
530
+ _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
531
+ if rms_norm:
532
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
533
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
534
+
535
+ if vision_model:
536
+ _patch_layer_norm_module(vision_model.layernorm_pre)
537
+ _patch_layer_norm_module(vision_model.layernorm_post)
538
+
539
+ for layer in vision_model.model.layers:
540
+ if layer_norm:
541
+ _patch_layer_norm_module(layer.input_layernorm)
542
+ _patch_layer_norm_module(layer.post_attention_layernorm)
543
+
544
+
545
+ def apply_liger_kernel_to_mllama(
546
+ rope: bool = True,
547
+ cross_entropy: bool = False,
548
+ fused_linear_cross_entropy: bool = True,
549
+ layer_norm: bool = True,
550
+ rms_norm: bool = True,
551
+ swiglu: bool = True,
552
+ model: PreTrainedModel = None,
553
+ ) -> None:
554
+ """
555
+ Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
556
+ NOTE: MLlama is not available in transformers<4.45.0
557
+
558
+ Args:
559
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
560
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
561
+ fused_linear_cross_entropy (bool):
562
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
563
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
564
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
565
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
566
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
567
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
568
+ loaded. Default is None.
569
+ """
570
+
571
+ assert not (cross_entropy and fused_linear_cross_entropy), (
572
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
573
+ )
574
+
575
+ from transformers.models.mllama import modeling_mllama
576
+ from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
577
+ from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
578
+ from transformers.models.mllama.modeling_mllama import MllamaTextModel
579
+ from transformers.models.mllama.modeling_mllama import MllamaVisionModel
580
+
581
+ from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
582
+ from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated
583
+
584
+ if rope:
585
+ modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
586
+ if layer_norm and model is None:
587
+ modeling_mllama.nn.LayerNorm = LigerLayerNorm
588
+ if rms_norm:
589
+ modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
590
+ if swiglu:
591
+ modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
42
592
  if cross_entropy:
43
- modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
593
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
594
+ from transformers.loss.loss_utils import nn
595
+
596
+ nn.functional.cross_entropy = liger_cross_entropy
597
+ else:
598
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
599
+ modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
44
600
  if fused_linear_cross_entropy:
45
- modeling_llama.LlamaForCausalLM.forward = lce_forward
601
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
602
+ if model is not None:
603
+ model.forward = MethodType(mllama_lce_forward, model)
604
+ else:
605
+ modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
606
+ else: # if version < 4.46.1
607
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
608
+ if model is not None:
609
+ model.forward = MethodType(mllama_lce_forward_deprecated, model)
610
+ else:
611
+ modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
612
+
613
+ if model is not None:
614
+ # The model instance already exists, so we need to additionally patch the
615
+ # instance variables that reference already-instantiated modules
616
+
617
+ if isinstance(model, MllamaForConditionalGeneration):
618
+ language_model: MllamaForCausalLM = model.language_model
619
+ vision_model: MllamaVisionModel = model.vision_model
620
+ if isinstance(language_model, MllamaForCausalLM):
621
+ text_model: MllamaTextModel = language_model.model
622
+ else:
623
+ text_model = language_model
624
+ elif isinstance(model, MllamaForCausalLM):
625
+ text_model = model.model
626
+ vision_model = None
627
+ elif isinstance(model, MllamaTextModel):
628
+ text_model = model
629
+ vision_model = None
630
+
631
+ else:
632
+ raise ValueError(f"Unsupported Mllama model type: {type(model)}")
633
+
634
+ if text_model:
635
+ if rms_norm:
636
+ _patch_rms_norm_module(text_model.norm)
637
+ for decoder_layer in text_model.layers:
638
+ if swiglu:
639
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
640
+ if rms_norm:
641
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
642
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
643
+
644
+ if vision_model:
645
+ _patch_layer_norm_module(vision_model.layernorm_pre)
646
+ _patch_layer_norm_module(vision_model.layernorm_post)
647
+
648
+ for layer in vision_model.transformer.layers:
649
+ if layer_norm:
650
+ _patch_layer_norm_module(layer.input_layernorm)
651
+ _patch_layer_norm_module(layer.post_attention_layernorm)
652
+
653
+ for layer in vision_model.global_transformer.layers:
654
+ if layer_norm:
655
+ _patch_layer_norm_module(layer.input_layernorm)
656
+ _patch_layer_norm_module(layer.post_attention_layernorm)
46
657
 
47
658
 
48
659
  def apply_liger_kernel_to_mistral(
49
660
  rope: bool = True,
50
- cross_entropy: bool = True,
661
+ cross_entropy: bool = False,
662
+ fused_linear_cross_entropy: bool = True,
51
663
  rms_norm: bool = True,
52
664
  swiglu: bool = True,
665
+ model: PreTrainedModel = None,
53
666
  ) -> None:
54
667
  """
55
668
  Apply Liger kernels to replace original implementation in HuggingFace Mistral models
56
669
 
57
670
  Args:
58
- rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
671
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
59
672
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
673
+ fused_linear_cross_entropy (bool):
674
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
675
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
676
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
677
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
60
678
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
61
679
  swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
680
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
681
+ loaded. Default is None.
62
682
  """
683
+ assert not (cross_entropy and fused_linear_cross_entropy), (
684
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
685
+ )
63
686
 
64
687
  from transformers.models.mistral import modeling_mistral
688
+ from transformers.models.mistral.modeling_mistral import MistralModel
65
689
 
66
690
  if rope:
67
691
  modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -69,62 +693,2233 @@ def apply_liger_kernel_to_mistral(
69
693
  modeling_mistral.MistralRMSNorm = LigerRMSNorm
70
694
  if cross_entropy:
71
695
  modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
696
+ if fused_linear_cross_entropy:
697
+ if transformer_version >= version.parse("4.49.0"):
698
+ if model is not None:
699
+ model.forward = MethodType(mistral_lce_forward, model)
700
+ else:
701
+ modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
702
+ else:
703
+ logger.warning(
704
+ "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
705
+ )
706
+ logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
707
+
72
708
  if swiglu:
73
709
  modeling_mistral.MistralMLP = LigerSwiGLUMLP
74
710
 
711
+ if model is not None:
712
+ # The model instance already exists, so we need to additionally patch the
713
+ # instance variables that reference already-instantiated modules
714
+
715
+ # get the base model from the model instance
716
+ base_model: MistralModel = getattr(model, model.base_model_prefix, model)
717
+
718
+ if rms_norm:
719
+ _patch_rms_norm_module(base_model.norm)
720
+
721
+ for decoder_layer in base_model.layers:
722
+ if swiglu:
723
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
724
+ if rms_norm:
725
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
726
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
727
+
75
728
 
76
729
  def apply_liger_kernel_to_mixtral(
77
730
  rope: bool = True,
78
- cross_entropy: bool = True,
731
+ cross_entropy: bool = False,
732
+ fused_linear_cross_entropy: bool = True,
79
733
  rms_norm: bool = True,
80
734
  swiglu: bool = True,
735
+ model: PreTrainedModel = None,
81
736
  ) -> None:
82
737
  """
83
738
  Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
84
739
 
85
740
  Args:
86
741
  rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
87
- cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
742
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
743
+ fused_linear_cross_entropy (bool):
744
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
745
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
746
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
88
747
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
89
748
  swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
749
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
750
+ loaded. Default is None.
90
751
  """
91
752
 
753
+ assert not (cross_entropy and fused_linear_cross_entropy), (
754
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
755
+ )
756
+
92
757
  from transformers.models.mixtral import modeling_mixtral
758
+ from transformers.models.mixtral.modeling_mixtral import MixtralModel
93
759
 
94
760
  if rope:
95
761
  modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
96
762
  if rms_norm:
97
- modeling_mixtral.MistralRMSNorm = LigerRMSNorm
763
+ modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
98
764
  if cross_entropy:
99
- modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
765
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
766
+ from transformers.loss.loss_utils import nn
767
+
768
+ nn.functional.cross_entropy = liger_cross_entropy
769
+ else:
770
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
771
+ modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
772
+
773
+ if fused_linear_cross_entropy:
774
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
775
+ if model is not None:
776
+ model.forward = MethodType(mixtral_lce_forward, model)
777
+ else:
778
+ modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
779
+ else: # if version < 4.46.1
780
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
781
+ if model is not None:
782
+ model.forward = MethodType(mixtral_lce_forward_deprecated, model)
783
+ else:
784
+ modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
100
785
  if swiglu:
101
786
  modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
102
787
 
788
+ if model is not None:
789
+ # The model instance already exists, so we need to additionally patch the
790
+ # instance variables that reference already-instantiated modules
791
+
792
+ # get the base model from the model instance
793
+ base_model: MixtralModel = getattr(model, model.base_model_prefix, model)
794
+
795
+ if rms_norm:
796
+ _patch_rms_norm_module(base_model.norm)
797
+
798
+ for decoder_layer in base_model.layers:
799
+ if swiglu:
800
+ for expert in decoder_layer.block_sparse_moe.experts:
801
+ _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
802
+ if rms_norm:
803
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
804
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
805
+
103
806
 
104
807
  def apply_liger_kernel_to_gemma(
105
808
  rope: bool = True,
106
- cross_entropy: bool = True,
809
+ cross_entropy: bool = False,
810
+ fused_linear_cross_entropy: bool = True,
107
811
  rms_norm: bool = True,
108
812
  geglu: bool = True,
813
+ model: PreTrainedModel = None,
109
814
  ) -> None:
110
815
  """
111
- Apply Liger kernels to replace original implementation in HuggingFace Gemma2 models
112
- to make GPU go burrr.
816
+ Apply Liger kernels to replace original implementation in HuggingFace Gemma
817
+ (Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2` ) to make GPU go burrr.
113
818
 
114
819
  Args:
115
820
  rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
116
- cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
821
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
822
+ fused_linear_cross_entropy (bool):
823
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
824
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
825
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
117
826
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
118
827
  geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
828
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
829
+ loaded. Default is None.
119
830
  """
120
- # TODO(yundai424): add convergence test for gemma
831
+ assert not (cross_entropy and fused_linear_cross_entropy), (
832
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
833
+ )
834
+
121
835
  from transformers.models.gemma import modeling_gemma
836
+ from transformers.models.gemma.modeling_gemma import GemmaModel
837
+
838
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
839
+
840
+ _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
122
841
 
123
842
  if rope:
124
843
  modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
125
844
  if rms_norm:
126
- modeling_gemma.GemmaRMSNorm = LigerRMSNorm
845
+ modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
127
846
  if cross_entropy:
128
- modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
847
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
848
+ from transformers.loss.loss_utils import nn
849
+
850
+ nn.functional.cross_entropy = liger_cross_entropy
851
+ else:
852
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
853
+ modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
129
854
  if geglu:
130
855
  modeling_gemma.GemmaMLP = LigerGEGLUMLP
856
+ if fused_linear_cross_entropy:
857
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
858
+ if model is not None:
859
+ model.forward = MethodType(gemma_lce_forward, model)
860
+ else:
861
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
862
+ else: # if version < 4.46.1
863
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
864
+ if model is not None:
865
+ model.forward = MethodType(gemma_lce_forward_deprecated, model)
866
+ else:
867
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
868
+
869
+ if model is not None:
870
+ # The model instance already exists, so we need to additionally patch the
871
+ # instance variables that reference already-instantiated modules
872
+
873
+ # get the base model from the model instance
874
+ base_model: GemmaModel = getattr(model, model.base_model_prefix, model)
875
+
876
+ if rms_norm:
877
+ _patch_rms_norm_module_for_gemma(base_model.norm)
878
+
879
+ for decoder_layer in base_model.layers:
880
+ if geglu:
881
+ _patch_geglu_module(decoder_layer.mlp)
882
+ if rms_norm:
883
+ _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
884
+ _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
885
+
886
+
887
+ def apply_liger_kernel_to_gemma2(
888
+ rope: bool = True,
889
+ cross_entropy: bool = False,
890
+ fused_linear_cross_entropy: bool = True,
891
+ rms_norm: bool = True,
892
+ geglu: bool = True,
893
+ model: PreTrainedModel = None,
894
+ ) -> None:
895
+ """
896
+ Apply Liger kernels to replace original implementation in HuggingFace Gemma2
897
+ (for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.
898
+
899
+ Args:
900
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
901
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
902
+ fused_linear_cross_entropy (bool):
903
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
904
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
905
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
906
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
907
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
908
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
909
+ loaded. Default is None.
910
+ """
911
+ assert not (cross_entropy and fused_linear_cross_entropy), (
912
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
913
+ )
914
+
915
+ from transformers.models.gemma2 import modeling_gemma2
916
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
917
+
918
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
919
+
920
+ _patch_rms_norm_module_for_gemma2 = partial(
921
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
922
+ )
923
+
924
+ if rope:
925
+ modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
926
+ if rms_norm:
927
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
928
+ modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
929
+ if cross_entropy:
930
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
931
+ from transformers.loss.loss_utils import nn
932
+
933
+ nn.functional.cross_entropy = liger_cross_entropy
934
+ else:
935
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
936
+ modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
937
+ if fused_linear_cross_entropy:
938
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
939
+ if model is not None:
940
+ model.forward = MethodType(gemma2_lce_forward, model)
941
+ else:
942
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
943
+ else:
944
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
945
+ if model is not None:
946
+ model.forward = MethodType(gemma2_lce_forward_deprected, model)
947
+ else:
948
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
949
+ if geglu:
950
+ modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
951
+
952
+ if model is not None:
953
+ # The model instance already exists, so we need to additionally patch the
954
+ # instance variables that reference already-instantiated modules
955
+
956
+ # get the base model from the model instance
957
+ base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)
958
+
959
+ if rms_norm:
960
+ _patch_rms_norm_module_for_gemma2(base_model.norm)
961
+
962
+ for decoder_layer in base_model.layers:
963
+ if geglu:
964
+ _patch_geglu_module(decoder_layer.mlp)
965
+ if rms_norm:
966
+ _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
967
+ _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
968
+ _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
969
+ _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
970
+
971
+
972
+ def apply_liger_kernel_to_gemma3_text(
973
+ rope: bool = True,
974
+ cross_entropy: bool = False,
975
+ fused_linear_cross_entropy: bool = True,
976
+ rms_norm: bool = True,
977
+ geglu: bool = True,
978
+ model: PreTrainedModel = None,
979
+ ) -> None:
980
+ """
981
+ Apply Liger kernels to replace original implementation in HuggingFace Gemma3
982
+
983
+ Args:
984
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
985
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
986
+ fused_linear_cross_entropy (bool):
987
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
988
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
989
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
990
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
991
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
992
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
993
+ loaded. Default is None.
994
+ """
995
+ assert not (cross_entropy and fused_linear_cross_entropy), (
996
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
997
+ )
998
+
999
+ from transformers.models.gemma3 import modeling_gemma3
1000
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
1001
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
1002
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
1003
+
1004
+ from liger_kernel.transformers.model.gemma3 import causal_forward
1005
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
1006
+
1007
+ _patch_rms_norm_module_for_gemma3 = partial(
1008
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
1009
+ )
1010
+
1011
+ if rope:
1012
+ modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
1013
+
1014
+ if rms_norm:
1015
+ modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3
1016
+
1017
+ if geglu:
1018
+ modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
1019
+
1020
+ # Handle loss function
1021
+ if cross_entropy:
1022
+ from transformers.loss.loss_utils import nn
1023
+
1024
+ nn.functional.cross_entropy = liger_cross_entropy
1025
+
1026
+ if fused_linear_cross_entropy:
1027
+ if model is not None:
1028
+ model.forward = MethodType(causal_forward, model)
1029
+ else:
1030
+ modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
1031
+
1032
+ if model is not None:
1033
+ # The model instance already exists, so we need to additionally patch the
1034
+ # instance variables that reference already-instantiated modules
1035
+
1036
+ if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
1037
+ # get the base model from the model instance
1038
+ base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
1039
+
1040
+ if rms_norm:
1041
+ _patch_rms_norm_module_for_gemma3(base_model.norm)
1042
+
1043
+ for decoder_layer in base_model.layers:
1044
+ decoder_layer: Gemma3DecoderLayer
1045
+ if geglu:
1046
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
1047
+ if rms_norm:
1048
+ _patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
1049
+ _patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
1050
+ _patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
1051
+ _patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
1052
+ _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
1053
+ _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)
1054
+
1055
+ else:
1056
+ raise TypeError("The model must be Gemma3ForCausalLM.")
1057
+
1058
+
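When called without a `model`, `apply_liger_kernel_to_gemma3_text` only swaps the class-level components (RoPE, `Gemma3RMSNorm`, `Gemma3MLP`, and the causal-LM forward), so it must run before the checkpoint is loaded. A minimal sketch under that assumption, with a placeholder checkpoint id and an assumed `liger_kernel.transformers` re-export:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text  # assumed re-export

# Patch the transformers classes first; nothing instance-level happens here.
apply_liger_kernel_to_gemma3_text(rope=True, rms_norm=True, geglu=True, fused_linear_cross_entropy=True)

# Any Gemma3 text checkpoint loaded afterwards picks up the Liger modules.
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")  # placeholder id
```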
1059
+ def apply_liger_kernel_to_gemma3(
1060
+ rope: bool = True,
1061
+ cross_entropy: bool = False,
1062
+ fused_linear_cross_entropy: bool = True,
1063
+ layer_norm: bool = True,
1064
+ rms_norm: bool = True,
1065
+ geglu: bool = True,
1066
+ model: PreTrainedModel = None,
1067
+ ) -> None:
1068
+ """
1069
+ Apply Liger kernels to replace original implementation in HuggingFace Gemma3 multimodal models
1070
+
1071
+ Args:
1072
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1073
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1074
+ fused_linear_cross_entropy (bool):
1075
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1076
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1077
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1078
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
1079
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1080
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
1081
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1082
+ loaded. Default is None.
1083
+ """
1084
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1085
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1086
+ )
1087
+
1088
+ from transformers.models.gemma3 import modeling_gemma3
1089
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
1090
+ from transformers.models.siglip import modeling_siglip
1091
+ from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
1092
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
1093
+
1094
+ from liger_kernel.transformers.model.gemma3 import multimodal_forward
1095
+
1096
+ _patch_rms_norm_module_for_gemma3 = partial(
1097
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
1098
+ )
1099
+
1100
+ if layer_norm and model is None:
1101
+ modeling_siglip.nn.LayerNorm = LigerLayerNorm
1102
+
1103
+ apply_liger_kernel_to_gemma3_text(
1104
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
1105
+ )
1106
+
1107
+ if cross_entropy:
1108
+ modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
1109
+
1110
+ if fused_linear_cross_entropy:
1111
+ if model is not None:
1112
+ model.forward = MethodType(multimodal_forward, model)
1113
+ else:
1114
+ modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
1115
+
1116
+ if model is not None:
1117
+ # The model instance already exists, so we need to additionally patch the
1118
+ # instance variables that reference already-instantiated modules
1119
+
1120
+ if isinstance(model, Gemma3ForConditionalGeneration):
1121
+ if isinstance(model.vision_tower, SiglipVisionModel):
1122
+ vision_tower = model.vision_tower
1123
+
1124
+ _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
1125
+
1126
+ for layer in vision_tower.vision_model.encoder.layers:
1127
+ layer: SiglipEncoderLayer
1128
+ if layer_norm:
1129
+ _patch_layer_norm_module(layer.layer_norm1)
1130
+ _patch_layer_norm_module(layer.layer_norm2)
1131
+ else:
1132
+ raise TypeError("The vision tower must be SiglipVisionModel")
1133
+
1134
+ if rms_norm:
1135
+ _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
1136
+
1137
+ apply_liger_kernel_to_gemma3_text(
1138
+ rope=rope,
1139
+ cross_entropy=False,
1140
+ fused_linear_cross_entropy=False,
1141
+ rms_norm=rms_norm,
1142
+ geglu=geglu,
1143
+ model=model.language_model,
1144
+ )
1145
+
1146
+ else:
1147
+ raise TypeError("The model must be Gemma3ForConditionalGeneration.")
1148
+
1149
+
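For the multimodal variant, passing an existing `Gemma3ForConditionalGeneration` patches the SigLIP vision tower's LayerNorms, the projector's `mm_soft_emb_norm`, and then delegates the language model to `apply_liger_kernel_to_gemma3_text`. A hedged sketch (placeholder checkpoint id, assumed re-export):

```python
from transformers import Gemma3ForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_gemma3  # assumed re-export

model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")  # placeholder id

# Patches SigLIP LayerNorms, the multi-modal projector's RMSNorm, and the
# language model's layers; a non-Gemma3 instance would raise TypeError.
apply_liger_kernel_to_gemma3(model=model)
```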
1150
+ def apply_liger_kernel_to_paligemma(
1151
+ rope: bool = True,
1152
+ cross_entropy: bool = False,
1153
+ fused_linear_cross_entropy: bool = True,
1154
+ layer_norm: bool = True,
1155
+ rms_norm: bool = True,
1156
+ geglu: bool = True,
1157
+ model: PreTrainedModel = None,
1158
+ ) -> None:
1159
+ """
1160
+ Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
1161
+
1162
+ Args:
1163
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1164
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1165
+ fused_linear_cross_entropy (bool):
1166
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1167
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1168
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
1169
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
1170
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1171
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
1172
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1173
+ loaded. Default is None.
1174
+ """
1175
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1176
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1177
+ )
1178
+
1179
+ # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
1180
+
1181
+ from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
1182
+ from transformers.models.gemma.modeling_gemma import GemmaModel
1183
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
1184
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
1185
+ from transformers.models.paligemma import modeling_paligemma
1186
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
1187
+ from transformers.models.siglip import modeling_siglip
1188
+ from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
1189
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
1190
+
1191
+ from liger_kernel.transformers.model.paligemma import lce_forward
1192
+ from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
1193
+
1194
+ # The vision_tower is a SiglipVisionModel
1195
+ if layer_norm and model is None:
1196
+ modeling_siglip.nn.LayerNorm = LigerLayerNorm
1197
+
1198
+ # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
1199
+ # The multi_modal_projector is Linear, nothing to do
1200
+
1201
+ # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
1202
+ apply_liger_kernel_to_gemma(
1203
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
1204
+ )
1205
+ apply_liger_kernel_to_gemma2(
1206
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
1207
+ )
1208
+ # Handle loss function
1209
+ if cross_entropy:
1210
+ modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
1211
+ if fused_linear_cross_entropy:
1212
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1213
+ if model is not None:
1214
+ model.forward = MethodType(lce_forward, model)
1215
+ else:
1216
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
1217
+ else: # if version < 4.46.1
1218
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1219
+ if model is not None:
1220
+ model.forward = MethodType(lce_forward_deprecated, model)
1221
+ else:
1222
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
1223
+
1224
+ if model is not None:
1225
+ # The model instance already exists, so we need to additionally patch the
1226
+ # instance variables that reference already-instantiated modules
1227
+
1228
+ if not isinstance(model, PaliGemmaForConditionalGeneration):
1229
+ raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
1230
+
1231
+ vision_tower: SiglipVisionModel = model.vision_tower
1232
+
1233
+ _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
1234
+
1235
+ for layer in vision_tower.vision_model.encoder.layers:
1236
+ layer: SiglipEncoderLayer
1237
+ if layer_norm:
1238
+ _patch_layer_norm_module(layer.layer_norm1)
1239
+ _patch_layer_norm_module(layer.layer_norm2)
1240
+
1241
+ language_model = model.language_model
1242
+
1243
+ if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
1244
+ apply_liger_kernel_to_gemma(
1245
+ rope=rope,
1246
+ cross_entropy=False,
1247
+ fused_linear_cross_entropy=False,
1248
+ rms_norm=rms_norm,
1249
+ geglu=geglu,
1250
+ model=language_model,
1251
+ )
1252
+
1253
+ elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
1254
+ apply_liger_kernel_to_gemma2(
1255
+ rope=rope,
1256
+ cross_entropy=False,
1257
+ fused_linear_cross_entropy=False,
1258
+ rms_norm=rms_norm,
1259
+ geglu=geglu,
1260
+ model=language_model,
1261
+ )
1262
+ else:
1263
+ raise TypeError(
1264
+ "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
1265
+ )
1266
+
1267
+
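PaliGemma follows the same pattern: the SigLIP vision tower is patched directly, while the attached Gemma or Gemma2 language model is routed through the corresponding helper above. A minimal sketch with a placeholder checkpoint id and an assumed re-export:

```python
from transformers import PaliGemmaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_paligemma  # assumed re-export

model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")  # placeholder id

# LayerNorms in the vision tower are patched in place; the language model is
# dispatched to apply_liger_kernel_to_gemma or apply_liger_kernel_to_gemma2.
apply_liger_kernel_to_paligemma(model=model)
```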
1268
+ def apply_liger_kernel_to_qwen2(
1269
+ rope: bool = True,
1270
+ cross_entropy: bool = False,
1271
+ fused_linear_cross_entropy: bool = True,
1272
+ rms_norm: bool = True,
1273
+ swiglu: bool = True,
1274
+ model: PreTrainedModel = None,
1275
+ ) -> None:
1276
+ """
1277
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
1278
+
1279
+ Args:
1280
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1281
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1282
+ fused_linear_cross_entropy (bool):
1283
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1284
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1285
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1286
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1287
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1288
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1289
+ loaded. Default is None.
1290
+ """
1291
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1292
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1293
+ )
1294
+
1295
+ from transformers.models.qwen2 import modeling_qwen2
1296
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
1297
+
1298
+ if rope:
1299
+ modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
1300
+ if rms_norm:
1301
+ modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
1302
+
1303
+ if cross_entropy:
1304
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1305
+ from transformers.loss.loss_utils import nn
1306
+
1307
+ nn.functional.cross_entropy = liger_cross_entropy
1308
+ else:
1309
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1310
+ modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
1311
+
1312
+ if fused_linear_cross_entropy:
1313
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1314
+ if model is not None:
1315
+ model.forward = MethodType(qwen2_lce_forward, model)
1316
+ else:
1317
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
1318
+ else: # if version < 4.46.1
1319
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1320
+ if model is not None:
1321
+ model.forward = MethodType(qwen2_lce_forward_deprecated, model)
1322
+ else:
1323
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
1324
+
1325
+ if swiglu:
1326
+ modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
1327
+
1328
+ if model is not None:
1329
+ # The model instance already exists, so we need to additionally patch the
1330
+ # instance variables that reference already-instantiated modules
1331
+
1332
+ # get the base model from the model instance
1333
+ base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)
1334
+
1335
+ if rms_norm:
1336
+ _patch_rms_norm_module(base_model.norm)
1337
+
1338
+ for decoder_layer in base_model.layers:
1339
+ if swiglu:
1340
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1341
+ if rms_norm:
1342
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1343
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1344
+
1345
+
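Because `cross_entropy` and `fused_linear_cross_entropy` are mutually exclusive (the assertion at the top of every helper), enabling the standalone Liger cross entropy requires disabling the fused path explicitly. A sketch, assuming the helper is re-exported from `liger_kernel.transformers`:

```python
from liger_kernel.transformers import apply_liger_kernel_to_qwen2  # assumed re-export

# Defaults: RoPE, RMSNorm, SwiGLU and the fused linear cross entropy forward.
apply_liger_kernel_to_qwen2()

# To use the standalone Liger cross entropy instead, switch the fused path off,
# otherwise the assertion above fires.
apply_liger_kernel_to_qwen2(cross_entropy=True, fused_linear_cross_entropy=False)
```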
1346
+ def apply_liger_kernel_to_qwen3(
1347
+ rope: bool = True,
1348
+ cross_entropy: bool = False,
1349
+ fused_linear_cross_entropy: bool = True,
1350
+ rms_norm: bool = True,
1351
+ swiglu: bool = True,
1352
+ model: PreTrainedModel = None,
1353
+ ) -> None:
1354
+ """
1355
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
1356
+ """
1357
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1358
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1359
+ )
1360
+
1361
+ from transformers.models.qwen3 import modeling_qwen3
1362
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
1363
+
1364
+ from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward
1365
+
1366
+ if rope:
1367
+ modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
1368
+
1369
+ if rms_norm:
1370
+ modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
1371
+
1372
+ if cross_entropy:
1373
+ from transformers.loss.loss_utils import nn
1374
+
1375
+ nn.functional.cross_entropy = liger_cross_entropy
1376
+
1377
+ if fused_linear_cross_entropy:
1378
+ if model is not None:
1379
+ model.forward = MethodType(qwen3_lce_forward, model)
1380
+ else:
1381
+ modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
1382
+
1383
+ if swiglu:
1384
+ modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
1385
+
1386
+ if model is not None:
1387
+ # The model instance already exists, so we need to additionally patch the
1388
+ # instance variables that reference already-instantiated modules
1389
+
1390
+ # get the base model from the model instance
1391
+ base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)
1392
+
1393
+ if rms_norm:
1394
+ _patch_rms_norm_module(base_model.norm)
1395
+ for decoder_layer in base_model.layers:
1396
+ if swiglu:
1397
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1398
+ if rms_norm:
1399
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1400
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1401
+
1402
+
1403
+ def apply_liger_kernel_to_qwen3_moe(
1404
+ rope: bool = True,
1405
+ cross_entropy: bool = False,
1406
+ fused_linear_cross_entropy: bool = True,
1407
+ rms_norm: bool = True,
1408
+ swiglu: bool = True,
1409
+ model: PreTrainedModel = None,
1410
+ ) -> None:
1411
+ """
1412
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3 MoE models.
1413
+ """
1414
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1415
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1416
+ )
1417
+
1418
+ from transformers.models.qwen3_moe import modeling_qwen3_moe
1419
+ from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
1420
+
1421
+ from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
1422
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
1423
+
1424
+ if rope:
1425
+ modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
1426
+
1427
+ if rms_norm:
1428
+ modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
1429
+
1430
+ if cross_entropy:
1431
+ from transformers.loss.loss_utils import nn
1432
+
1433
+ nn.functional.cross_entropy = liger_cross_entropy
1434
+
1435
+ if fused_linear_cross_entropy:
1436
+ if model is not None:
1437
+ model.forward = MethodType(qwen3_lce_forward, model)
1438
+ else:
1439
+ modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
1440
+
1441
+ if swiglu:
1442
+ modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
1443
+
1444
+ if model is not None:
1445
+ # The model instance already exists, so we need to additionally patch the
1446
+ # instance variables that reference already-instantiated modules
1447
+
1448
+ # get the base model from the model instance
1449
+ base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
1450
+
1451
+ if rms_norm:
1452
+ _patch_rms_norm_module(base_model.norm)
1453
+ for decoder_layer in base_model.layers:
1454
+ if swiglu:
1455
+ for mlp_expert in decoder_layer.mlp.experts:
1456
+ _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
1457
+ if rms_norm:
1458
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1459
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1460
+
1461
+
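For the MoE variant, instance patching walks `decoder_layer.mlp.experts` and rebinds each expert to `LigerQwen3MoeSwiGLUMLP` instead of replacing the MoE block itself. A hedged sketch (placeholder checkpoint id, assumed re-export):

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe  # assumed re-export

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")  # placeholder id

# Each expert inside every MoE block is patched individually, alongside the
# usual RMSNorm and fused-linear-cross-entropy patches.
apply_liger_kernel_to_qwen3_moe(model=model)
```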
1462
+ def apply_liger_kernel_to_gpt_oss(
1463
+ rope: bool = True,
1464
+ cross_entropy: bool = False,
1465
+ fused_linear_cross_entropy: bool = True,
1466
+ rms_norm: bool = True,
1467
+ swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation
1468
+ model: PreTrainedModel = None,
1469
+ ) -> None:
1470
+ """
1471
+ Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
1472
+ NOTE: GPT-OSS is supported in transformers >= 4.55.0
1473
+ NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
1474
+ implementation with clamping and MXFP4 quantization.
1475
+
1476
+ Args:
1477
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1478
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1479
+ fused_linear_cross_entropy (bool):
1480
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1481
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1482
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1483
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1484
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1485
+ Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
1486
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1487
+ loaded. Default is None.
1488
+ """
1489
+ if version.parse(transformers.__version__) < version.parse("4.55.0"):
1490
+ logger.warning("GPT-OSS support requires transformers >= 4.55.0")
1491
+ return
1492
+
1493
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1494
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1495
+ )
1496
+
1497
+ from transformers.models.gpt_oss import modeling_gpt_oss
1498
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
1499
+
1500
+ if rope:
1501
+ modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
1502
+
1503
+ if rms_norm:
1504
+ modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
1505
+
1506
+ if cross_entropy:
1507
+ from transformers.loss.loss_utils import nn
1508
+
1509
+ nn.functional.cross_entropy = liger_cross_entropy
1510
+
1511
+ if fused_linear_cross_entropy:
1512
+ if model is not None:
1513
+ model.forward = MethodType(gpt_oss_lce_forward, model)
1514
+ else:
1515
+ modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
1516
+
1517
+ # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
1518
+ # with clamping (swiglu_limit=7.0) and MXFP4 quantization
1519
+
1520
+ if model is not None:
1521
+ # The model instance already exists, so we need to additionally patch the
1522
+ # instance variables that reference already-instantiated modules
1523
+
1524
+ # get the base model from the model instance
1525
+ base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
1526
+
1527
+ if rms_norm:
1528
+ _patch_rms_norm_module(base_model.norm)
1529
+ for decoder_layer in base_model.layers:
1530
+ if rms_norm:
1531
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1532
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1533
+
1534
+
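Since the GPT-OSS helper returns early (with a warning) on transformers older than 4.55.0, callers may want to check the installed version themselves rather than rely on the silent no-op. A sketch under the assumption that the helper is re-exported from `liger_kernel.transformers`:

```python
import transformers
from packaging import version

from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss  # assumed re-export

if version.parse(transformers.__version__) >= version.parse("4.55.0"):
    # RoPE, RMSNorm and the fused-linear-cross-entropy forward are patched;
    # SwiGLU is left alone because of the custom clamped/MXFP4 experts.
    apply_liger_kernel_to_gpt_oss()
else:
    # Below 4.55.0 the call would only log a warning and return.
    print("transformers is too old for GPT-OSS Liger patching")
```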
1535
+ def apply_liger_kernel_to_qwen2_vl(
1536
+ rope: bool = True,
1537
+ cross_entropy: bool = False,
1538
+ fused_linear_cross_entropy: bool = True,
1539
+ rms_norm: bool = True,
1540
+ layer_norm: bool = True,
1541
+ swiglu: bool = True,
1542
+ model: PreTrainedModel = None,
1543
+ ) -> None:
1544
+ """
1545
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
1546
+ NOTE: Liger support for Qwen2-VL requires transformers >= 4.52.4
1547
+
1548
+ Args:
1549
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1550
+ fused_linear_cross_entropy (bool):
1551
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1552
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1553
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1554
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1555
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
1556
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1557
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1558
+ loaded. Default is None.
1559
+ """
1560
+ if transformer_version < version.parse("4.52.4"):
1561
+ logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
1562
+ return
1563
+
1564
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1565
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1566
+ )
1567
+
1568
+ from transformers.models.qwen2_vl import modeling_qwen2_vl
1569
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
1570
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
1571
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
1572
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
1573
+
1574
+ from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
1575
+
1576
+ if rope:
1577
+ modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
1578
+ if rms_norm:
1579
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
1580
+ modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
1581
+ if layer_norm and model is None:
1582
+ modeling_qwen2_vl.LayerNorm = LigerLayerNorm
1583
+ if cross_entropy:
1584
+ modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
1585
+ if fused_linear_cross_entropy:
1586
+ if model is not None:
1587
+ model.forward = MethodType(qwen2_vl_lce_forward, model)
1588
+ else:
1589
+ modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
1590
+ if swiglu:
1591
+ modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
1592
+
1593
+ if model is not None:
1594
+ # The model instance already exists, so we need to additionally patch the
1595
+ # instance variables that reference already-instantiated modules
1596
+
1597
+ if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
1598
+ # Note: language_model and visual properties can be accessed through the conditional generation class for BC (backward compatibility).
1599
+ # This may be subject to change in future transformers releases.
1600
+ # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
1601
+ text_model: Qwen2VLTextModel = model.language_model
1602
+ vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
1603
+ elif isinstance(model, Qwen2VLTextModel):
1604
+ text_model: Qwen2VLTextModel = model
1605
+ vision_model = None
1606
+ else:
1607
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
1608
+ raise TypeError(
1609
+ f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
1610
+ )
1611
+
1612
+ # Patch Qwen2VisionTransformerPretrainedModel
1613
+ if vision_model is not None:
1614
+ for vision_block in vision_model.blocks:
1615
+ if layer_norm:
1616
+ _patch_layer_norm_module(vision_block.norm1)
1617
+ _patch_layer_norm_module(vision_block.norm2)
1618
+
1619
+ # Patch Qwen2VLTextModel
1620
+ if text_model is not None:
1621
+ if rms_norm:
1622
+ _patch_rms_norm_module(text_model.norm)
1623
+ for decoder_layer in text_model.layers:
1624
+ if swiglu:
1625
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1626
+ if rms_norm:
1627
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1628
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1629
+
1630
+
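For Qwen2-VL, instance patching covers both halves of the model: LayerNorms in the vision blocks and RMSNorm/SwiGLU in the text decoder layers. A hedged sketch (placeholder checkpoint id, assumed re-export):

```python
from transformers import Qwen2VLForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl  # assumed re-export

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")  # placeholder id

# norm1/norm2 in every vision block are patched as LayerNorms, while the text
# decoder layers get the Liger RMSNorm and SwiGLU treatment.
apply_liger_kernel_to_qwen2_vl(model=model)
```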
1631
+ def apply_liger_kernel_to_qwen2_5_vl(
1632
+ rope: bool = True,
1633
+ cross_entropy: bool = False,
1634
+ fused_linear_cross_entropy: bool = True,
1635
+ rms_norm: bool = True,
1636
+ swiglu: bool = True,
1637
+ model: PreTrainedModel = None,
1638
+ ) -> None:
1639
+ """
1640
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
1641
+ NOTE: Liger support for Qwen2.5-VL requires transformers >= 4.52.4
1642
+
1643
+ Args:
1644
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1645
+ fused_linear_cross_entropy (bool):
1646
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1647
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1648
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1649
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1650
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1651
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1652
+ loaded. Default is None.
1653
+ """
1654
+ if transformer_version < version.parse("4.52.4"):
1655
+ logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
1656
+ return
1657
+
1658
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1659
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1660
+ )
1661
+
1662
+ from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
1663
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
1664
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
1665
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
1666
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
1667
+
1668
+ from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
1669
+
1670
+ if rope:
1671
+ modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
1672
+ if rms_norm:
1673
+ modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
1674
+ if cross_entropy:
1675
+ modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
1676
+ if fused_linear_cross_entropy:
1677
+ if model is not None:
1678
+ model.forward = MethodType(qwen2_5_vl_lce_forward, model)
1679
+ else:
1680
+ modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
1681
+ if swiglu:
1682
+ modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
1683
+
1684
+ if model is not None:
1685
+ # The model instance already exists, so we need to additionally patch the
1686
+ # instance variables that reference already-instantiated modules
1687
+
1688
+ if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
1689
+ # Note: language_model and visual properties can be accessed through the conditional generation class for BC (backward compatibility).
1690
+ # This may be subject to change in future transformers releases.
1691
+ # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
1692
+ text_model: Qwen2_5_VLTextModel = model.language_model
1693
+ vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
1694
+ elif isinstance(model, Qwen2_5_VLTextModel):
1695
+ text_model: Qwen2_5_VLTextModel = model
1696
+ vision_model = None
1697
+ else:
1698
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
1699
+ raise TypeError(
1700
+ f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
1701
+ )
1702
+
1703
+ if vision_model is not None:
1704
+ # Patch Qwen2_5_VisionTransformerPretrainedModel
1705
+ for vision_block in model.visual.blocks:
1706
+ if rms_norm:
1707
+ _patch_rms_norm_module(vision_block.norm1)
1708
+ _patch_rms_norm_module(vision_block.norm2)
1709
+
1710
+ if text_model is not None:
1711
+ if rms_norm:
1712
+ _patch_rms_norm_module(text_model.norm)
1713
+ for decoder_layer in text_model.layers:
1714
+ if swiglu:
1715
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1716
+ if rms_norm:
1717
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1718
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1719
+
1720
+
1721
+ def apply_liger_kernel_to_qwen3_vl(
1722
+ rope: bool = True,
1723
+ cross_entropy: bool = False,
1724
+ fused_linear_cross_entropy: bool = True,
1725
+ rms_norm: bool = True,
1726
+ swiglu: bool = False,
1727
+ model: PreTrainedModel = None,
1728
+ ) -> None:
1729
+ """
1730
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
1731
+
1732
+ Args:
1733
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1734
+ fused_linear_cross_entropy (bool):
1735
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1736
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1737
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1738
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1739
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1740
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1741
+ loaded. Default is None.
1742
+ """
1743
+
1744
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1745
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1746
+ )
1747
+
1748
+ from transformers.models.qwen3_vl import modeling_qwen3_vl
1749
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
1750
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
1751
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
1752
+
1753
+ from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
1754
+
1755
+ if rope:
1756
+ modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
1757
+ modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
1758
+
1759
+ if rms_norm:
1760
+ modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
1761
+
1762
+ if cross_entropy:
1763
+ from transformers.loss.loss_utils import nn
1764
+
1765
+ nn.functional.cross_entropy = liger_cross_entropy
1766
+
1767
+ if fused_linear_cross_entropy:
1768
+ if model is not None:
1769
+ model.forward = MethodType(qwen3_vl_lce_forward, model)
1770
+ else:
1771
+ modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
1772
+
1773
+ if model is not None and rms_norm:
1774
+ if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
1775
+ text_model: Qwen3VLTextModel = model.language_model
1776
+ elif isinstance(model, Qwen3VLTextModel):
1777
+ text_model = model
1778
+ else:
1779
+ raise TypeError(
1780
+ f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
1781
+ )
1782
+
1783
+ _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
1784
+
1785
+ if text_model is not None:
1786
+ _patch_qwen3_vl_rms_norm(text_model.norm)
1787
+ for decoder_layer in text_model.layers:
1788
+ _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
1789
+ _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
1790
+ self_attn = getattr(decoder_layer, "self_attn", None)
1791
+ if self_attn is not None:
1792
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
1793
+ _patch_qwen3_vl_rms_norm(self_attn.q_norm)
1794
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
1795
+ _patch_qwen3_vl_rms_norm(self_attn.k_norm)
1796
+
1797
+
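Note that `swiglu` defaults to False for Qwen3-VL and instance patching only touches the text model's RMSNorm modules (including `q_norm`/`k_norm` when present). A hedged sketch, using the same deep import path as the code above and a placeholder checkpoint id:

```python
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl  # assumed re-export

model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")  # placeholder id

# Only RMSNorm modules are patched on the instance; swiglu stays False by default.
apply_liger_kernel_to_qwen3_vl(model=model)
```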
1798
+ def apply_liger_kernel_to_qwen3_vl_moe(
1799
+ rope: bool = True,
1800
+ cross_entropy: bool = False,
1801
+ fused_linear_cross_entropy: bool = True,
1802
+ rms_norm: bool = True,
1803
+ swiglu: bool = False,
1804
+ model: PreTrainedModel = None,
1805
+ ) -> None:
1806
+ """
1807
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
1808
+
1809
+ Args:
1810
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1811
+ fused_linear_cross_entropy (bool):
1812
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1813
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1814
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1815
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1816
+ loaded. Default is None.
1817
+ """
1818
+
1819
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1820
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1821
+ )
1822
+
1823
+ from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
1824
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
1825
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
1826
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
1827
+
1828
+ from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
1829
+
1830
+ if rope:
1831
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
1832
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
1833
+
1834
+ if rms_norm:
1835
+ modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
1836
+
1837
+ if cross_entropy:
1838
+ from transformers.loss.loss_utils import nn
1839
+
1840
+ nn.functional.cross_entropy = liger_cross_entropy
1841
+
1842
+ if fused_linear_cross_entropy:
1843
+ if model is not None:
1844
+ model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
1845
+ else:
1846
+ modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
1847
+
1848
+ if model is not None and rms_norm:
1849
+ if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
1850
+ text_model: Qwen3VLMoeTextModel = model.language_model
1851
+ elif isinstance(model, Qwen3VLMoeTextModel):
1852
+ text_model = model
1853
+ else:
1854
+ raise TypeError(
1855
+ f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
1856
+ )
1857
+
1858
+ _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
1859
+
1860
+ if text_model is not None:
1861
+ _patch_qwen3_vl_moe_rms_norm(text_model.norm)
1862
+ for decoder_layer in text_model.layers:
1863
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
1864
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
1865
+ self_attn = getattr(decoder_layer, "self_attn", None)
1866
+ if self_attn is not None:
1867
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
1868
+ _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
1869
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
1870
+ _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
1871
+
1872
+
1873
+ def apply_liger_kernel_to_phi3(
1874
+ rope: bool = True,
1875
+ cross_entropy: bool = False,
1876
+ fused_linear_cross_entropy: bool = True,
1877
+ rms_norm: bool = True,
1878
+ swiglu: bool = True,
1879
+ model: PreTrainedModel = None,
1880
+ ) -> None:
1881
+ """
1882
+ Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
1883
+
1884
+ Args:
1885
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1886
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1887
+ fused_linear_cross_entropy (bool):
1888
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1889
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1890
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
1891
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1892
+ swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
1893
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1894
+ loaded. Default is None.
1895
+ """
1896
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1897
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1898
+ )
1899
+
1900
+ from transformers.models.phi3 import modeling_phi3
1901
+ from transformers.models.phi3.modeling_phi3 import Phi3Model
1902
+
1903
+ if rope:
1904
+ modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
1905
+ if rms_norm:
1906
+ modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
1907
+ if swiglu:
1908
+ modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
1909
+ if cross_entropy:
1910
+ from transformers.loss.loss_utils import nn
1911
+
1912
+ nn.functional.cross_entropy = liger_cross_entropy
1913
+ if fused_linear_cross_entropy:
1914
+ if model is not None:
1915
+ model.forward = MethodType(phi3_lce_forward, model)
1916
+ else:
1917
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1918
+
1919
+ if model is not None:
1920
+ # The model instance already exists, so we need to additionally patch the
1921
+ # instance variables that reference already-instantiated modules
1922
+
1923
+ # get the base model from the model instance
1924
+ base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
1925
+
1926
+ if rms_norm:
1927
+ _patch_rms_norm_module(base_model.norm)
1928
+
1929
+ for decoder_layer in base_model.layers:
1930
+ if swiglu:
1931
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
1932
+ if rms_norm:
1933
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1934
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1935
+
1936
+
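Phi3 uses the gated `Phi3MLP` layout, so the SwiGLU swap goes through `LigerPhi3SwiGLUMLP` rather than the generic `LigerSwiGLUMLP`. A pre-load sketch with a placeholder checkpoint id and an assumed re-export:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_phi3  # assumed re-export

# Class-level patching before loading: Phi3MLP -> LigerPhi3SwiGLUMLP,
# Phi3RMSNorm -> LigerRMSNorm, RoPE and the fused-linear-cross-entropy forward.
apply_liger_kernel_to_phi3()

model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")  # placeholder id
```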
1937
+ def apply_liger_kernel_to_olmo2(
1938
+ rope: bool = True,
1939
+ cross_entropy: bool = False,
1940
+ fused_linear_cross_entropy: bool = True,
1941
+ rms_norm: bool = True,
1942
+ swiglu: bool = True,
1943
+ model: PreTrainedModel = None,
1944
+ ) -> None:
1945
+ """
1946
+ Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
1947
+
1948
+ Args:
1949
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1950
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1951
+ fused_linear_cross_entropy (bool):
1952
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1953
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1954
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1955
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1956
+ swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
1957
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1958
+ loaded. Default is None.
1959
+ """
1960
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1961
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1962
+ )
1963
+
1964
+ from transformers.models.olmo2 import modeling_olmo2
1965
+ from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
1966
+
1967
+ from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
1968
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
1969
+
1970
+ if rope:
1971
+ modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
1972
+ if rms_norm:
1973
+ modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
1974
+ if swiglu:
1975
+ modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
1976
+ if cross_entropy:
1977
+ from transformers.loss.loss_utils import nn
1978
+
1979
+ nn.functional.cross_entropy = liger_cross_entropy
1980
+ if fused_linear_cross_entropy:
1981
+ if model is not None:
1982
+ model.forward = MethodType(olmo2_lce_forward, model)
1983
+ else:
1984
+ modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
1985
+
1986
+ if model is not None:
1987
+ # The model instance already exists, so we need to additionally patch the
1988
+ # instance variables that reference already-instantiated modules
1989
+
1990
+ # get the base model from the model instance
1991
+ base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
1992
+
1993
+ if rms_norm:
1994
+ _patch_rms_norm_module(base_model.norm)
1995
+
1996
+ for decoder_layer in base_model.layers:
1997
+ if swiglu:
1998
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1999
+ if rms_norm:
2000
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2001
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2002
+
2003
+
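OLMo2 normalizes after attention and after the feed-forward block, which is why the per-layer norms above are patched with `in_place=False`. A hedged post-load sketch (placeholder checkpoint id, assumed re-export):

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_olmo2  # assumed re-export

model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")  # placeholder id

# post_attention_layernorm / post_feedforward_layernorm are patched with
# in_place=False, mirroring the calls in the helper above.
apply_liger_kernel_to_olmo2(model=model)
```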
2004
+ def apply_liger_kernel_to_olmo3(
2005
+ rope: bool = True,
2006
+ cross_entropy: bool = False,
2007
+ fused_linear_cross_entropy: bool = True,
2008
+ rms_norm: bool = True,
2009
+ swiglu: bool = True,
2010
+ model: PreTrainedModel = None,
2011
+ ) -> None:
2012
+ """
2013
+ Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
2014
+
2015
+ Args:
2016
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2017
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2018
+ fused_linear_cross_entropy (bool):
2019
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2020
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2021
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2022
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2023
+ swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
2024
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2025
+ loaded. Default is None.
2026
+ """
2027
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2028
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2029
+ )
2030
+
2031
+ from transformers.models.olmo3 import modeling_olmo3
2032
+ from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
2033
+
2034
+ from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
2035
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
2036
+
2037
+ # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
2038
+ if rope:
2039
+ modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
2040
+ if rms_norm:
2041
+ modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2 # same as olmo2
2042
+ if swiglu:
2043
+ modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
2044
+ if cross_entropy:
2045
+ from transformers.loss.loss_utils import nn
2046
+
2047
+ nn.functional.cross_entropy = liger_cross_entropy
2048
+ if fused_linear_cross_entropy:
2049
+ if model is not None:
2050
+ model.forward = MethodType(olmo3_lce_forward, model)
2051
+ else:
2052
+ modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
2053
+
2054
+ if model is not None:
2055
+ # The model instance already exists, so we need to additionally patch the
2056
+ # instance variables that reference already-instantiated modules
2057
+
2058
+ # get the base model from the model instance
2059
+ base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
2060
+
2061
+ if rms_norm:
2062
+ _patch_rms_norm_module(base_model.norm)
2063
+
2064
+ for decoder_layer in base_model.layers:
2065
+ if swiglu:
2066
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2067
+ if rms_norm:
2068
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2069
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2070
+
2071
+
2072
+ def apply_liger_kernel_to_glm4(
2073
+ rope: bool = False,
2074
+ cross_entropy: bool = False,
2075
+ fused_linear_cross_entropy: bool = True,
2076
+ rms_norm: bool = True,
2077
+ swiglu: bool = True,
2078
+ model: PreTrainedModel = None,
2079
+ ) -> None:
2080
+ """
2081
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
2082
+
2083
+ Args:
2084
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2085
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2086
+ fused_linear_cross_entropy (bool):
2087
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2088
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2089
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2090
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2091
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
2092
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2093
+ loaded. Default is None.
2094
+ """
2095
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2096
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2097
+ )
2098
+
2099
+ from transformers.models.glm4 import modeling_glm4
2100
+ from transformers.models.glm4.modeling_glm4 import Glm4Model
2101
+
2102
+ from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
2103
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2104
+
2105
+ if rope:
2106
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2107
+ if rms_norm:
2108
+ modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
2109
+ if swiglu:
2110
+ modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
2111
+ if cross_entropy:
2112
+ from transformers.loss.loss_utils import nn
2113
+
2114
+ nn.functional.cross_entropy = liger_cross_entropy
2115
+ if fused_linear_cross_entropy:
2116
+ if model is not None:
2117
+ model.forward = MethodType(glm4_lce_forward, model)
2118
+ else:
2119
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
2120
+
2121
+ if model is not None:
2122
+ # The model instance already exists, so we need to additionally patch the
2123
+ # instance variables that reference already-instantiated modules
2124
+
2125
+ # get the base model from the model instance
2126
+ base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
2127
+
2128
+ if rms_norm:
2129
+ _patch_rms_norm_module(base_model.norm, in_place=False)
2130
+
2131
+ for decoder_layer in base_model.layers:
2132
+ if swiglu:
2133
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
2134
+ if rms_norm:
2135
+ _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
2136
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2137
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
2138
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
2139
+
2140
+
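Because `liger_rotary_pos_emb` is not available for GLM-4, `rope` defaults to False and passing `rope=True` raises `NotImplementedError`. A sketch showing the supported flags, with a placeholder checkpoint id and an assumed re-export:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_glm4  # assumed re-export

# rope must stay False for GLM-4; rope=True raises NotImplementedError above.
apply_liger_kernel_to_glm4(rope=False, rms_norm=True, swiglu=True)

model = AutoModelForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")  # placeholder id
```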
2141
+ def apply_liger_kernel_to_glm4v(
2142
+ rope: bool = False,
2143
+ cross_entropy: bool = False,
2144
+ fused_linear_cross_entropy: bool = True,
2145
+ rms_norm: bool = True,
2146
+ swiglu: bool = True,
2147
+ model: PreTrainedModel = None,
2148
+ ) -> None:
2149
+ """
2150
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
2151
+
2152
+ Args:
2153
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2154
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2155
+ fused_linear_cross_entropy (bool):
2156
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2157
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2158
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2159
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2160
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
2161
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2162
+ loaded. Default is None.
2163
+ """
2164
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2165
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2166
+ )
2167
+
2168
+ from transformers.models.glm4v import modeling_glm4v
2169
+ from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
2170
+ from transformers.models.glm4v.modeling_glm4v import Glm4vModel
2171
+ from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
2172
+ from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
2173
+
2174
+ from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
2175
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2176
+
2177
+ if rope:
2178
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2179
+ if rms_norm:
2180
+ modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
2181
+ if cross_entropy:
2182
+ from transformers.loss.loss_utils import nn
2183
+
2184
+ nn.functional.cross_entropy = liger_cross_entropy
2185
+ if fused_linear_cross_entropy:
2186
+ if model is not None:
2187
+ model.forward = MethodType(glm4v_lce_forward, model)
2188
+ else:
2189
+ modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
2190
+
2191
+ if model is not None:
2192
+ # The model instance already exists, so we need to additionally patch the
2193
+ # instance variables that reference already-instantiated modules
2194
+ if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
2195
+ # Note: language_model and visual properties can be accessed through the conditional generation class for BC (backward compatibility).
2196
+ # This may be subject to change in future transformers releases.
2197
+ # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
2198
+ text_model: Glm4vTextModel = model.language_model
2199
+ vision_model: Glm4vVisionModel = model.visual
2200
+ elif isinstance(model, Glm4vTextModel):
2201
+ text_model: Glm4vTextModel = model
2202
+ vision_model = None
2203
+ else:
2204
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
2205
+ raise TypeError(
2206
+ f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
2207
+ )
2208
+
2209
+ if vision_model is not None:
2210
+ for vision_block in vision_model.blocks:
2211
+ if rms_norm:
2212
+ _patch_rms_norm_module(vision_block.norm1)
2213
+ _patch_rms_norm_module(vision_block.norm2)
2214
+ if swiglu:
2215
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
2216
+
2217
+ if text_model is not None:
2218
+ if rms_norm:
2219
+ _patch_rms_norm_module(text_model.norm)
2220
+ for decoder_layer in text_model.layers:
2221
+ if swiglu:
2222
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
2223
+ if rms_norm:
2224
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2225
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2226
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
2227
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
2228
+
2229
+
2230
+ def apply_liger_kernel_to_glm4v_moe(
2231
+ rope: bool = False,
2232
+ cross_entropy: bool = False,
2233
+ fused_linear_cross_entropy: bool = True,
2234
+ rms_norm: bool = True,
2235
+ swiglu: bool = True,
2236
+ model: PreTrainedModel = None,
2237
+ ) -> None:
2238
+ """
2239
+ Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
2240
+
2241
+ Args:
2242
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2243
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2244
+ fused_linear_cross_entropy (bool):
2245
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2246
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2247
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2248
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2249
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
2250
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2251
+ loaded. Default is None.
2252
+ """
2253
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2254
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2255
+ )
2256
+
2257
+ from transformers.models.glm4v_moe import modeling_glm4v_moe
2258
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
2259
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
2260
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
2261
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
2262
+
2263
+ from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
2264
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2265
+
2266
+ if rope:
2267
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2268
+ if rms_norm:
2269
+ modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
2270
+ modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
2271
+ if cross_entropy:
2272
+ from transformers.loss.loss_utils import nn
2273
+
2274
+ nn.functional.cross_entropy = liger_cross_entropy
2275
+ if fused_linear_cross_entropy:
2276
+ if model is not None:
2277
+ model.forward = MethodType(glm4v_moe_lce_forward, model)
2278
+ else:
2279
+ modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
2280
+
2281
+ if model is not None:
2282
+ # The model instance already exists, so we need to additionally patch the
2283
+ # instance variables that reference already-instantiated modules
2284
+ if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
2285
+ # Note: language_model and visual properties can be accessed throught conditional class for BC.
2286
+ # Not sure if it is subject to changes in the future.
2287
+ # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
2288
+ text_model: Glm4vMoeTextModel = model.language_model
2289
+ vision_model: Glm4vMoeVisionModel = model.visual
2290
+ Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
2291
+ elif isinstance(model, Glm4vMoeTextModel):
2292
+ text_model: Glm4vMoeTextModel = model
2293
+ vision_model = None
2294
+ else:
2295
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
2296
+ raise TypeError(
2297
+ f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
2298
+ )
2299
+
2300
+ if vision_model is not None:
2301
+ _patch_rms_norm_module(vision_model.post_conv_layernorm)
2302
+ _patch_rms_norm_module(vision_model.post_layernorm)
2303
+ for vision_block in vision_model.blocks:
2304
+ if rms_norm:
2305
+ _patch_rms_norm_module(vision_block.norm1)
2306
+ _patch_rms_norm_module(vision_block.norm2)
2307
+ if swiglu:
2308
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
2309
+
2310
+ if text_model is not None:
2311
+ if rms_norm:
2312
+ _patch_rms_norm_module(text_model.norm)
2313
+ for decoder_layer in text_model.layers:
2314
+ if swiglu:
2315
+ decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2316
+ if rms_norm:
2317
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2318
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2319
+ if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
2320
+ experts = getattr(decoder_layer.mlp, "experts", None)
2321
+ if experts is not None:
2322
+ for expert in experts:
2323
+ _patch_swiglu_module(expert, LigerSwiGLUMLP)
2324
+ if decoder_layer.mlp.shared_experts is not None:
2325
+ _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
2326
+ for decoder_layer in text_model.layers:
2327
+ if rms_norm:
2328
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2329
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2330
+
2331
+
2332
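For readers following the patch logic above, a minimal usage sketch (illustrative only, not part of the packaged code; the import assumes the function is re-exported from liger_kernel.transformers like the other apply_* helpers):

from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe  # assumed re-export

# Patch the transformers GLM4v_moe modeling code before the model is constructed,
# so newly created modules pick up the Liger kernels. Passing `model=` instead
# patches an already-instantiated model in place.
apply_liger_kernel_to_glm4v_moe(rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)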
+ def apply_liger_kernel_to_internvl(
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ layer_norm: bool = True,
+ model: Optional[PreTrainedModel] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
+ Due to the characteristics of InternVL, the model instance must be passed in so that Liger-Kernel can also patch the language model connected to InternVL.
+ However, if the connected LM is not supported by Liger-Kernel, unexpected side effects may occur.
+ NOTE: InternVL is not available in transformers<4.52.1
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+ import torch.nn as torch_nn
+
+ from transformers.models.internvl import modeling_internvl
+ from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+ from transformers.models.internvl.modeling_internvl import InternVLModel
+ from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+ from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+ from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
+
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+ if layer_norm and model is None:
+ modeling_internvl.nn.LayerNorm = LigerLayerNorm
+
+ if cross_entropy:
+ logger.info("Apply liger cross entropy")
+
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
+ if rms_norm:
+ modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
+ # NOTE: language_model and visual properties can be accessed through the conditional class.
+ text_model = model.language_model
+ vision_model: InternVLVisionModel = model.vision_tower
+ else:
+ raise TypeError(
+ f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration` or `InternVLModel`. Got: {type(model)}"
+ )
+
+ text_model_name = model.config.text_config.model_type
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+ if text_liger_fn:
+ accept_params = inspect.signature(text_liger_fn).parameters
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+ if remain_params:
+ logger.warning(
+ f"These parameters are not supported by {text_model_name}: {list(remain_params)}. Only the supported parameters {list(text_kwargs.keys())} will be applied.\n"
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+ )
+ text_kwargs["model"] = text_model
+ text_liger_fn(**text_kwargs)
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+ # Patch vision model RMSNorm layers
+ if rms_norm:
+ for encoder_layer in vision_model.encoder.layer:
+ encoder_layer: InternVLVisionLayer
+ if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+ _patch_rms_norm_module(encoder_layer.attention.q_norm)
+ if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+ _patch_rms_norm_module(encoder_layer.attention.k_norm)
+
+ # Patch vision model LayerNorm layers
+ if layer_norm:
+ # Patch layernorm
+ if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+ _patch_layer_norm_module(vision_model.layernorm)
+
+ # Patch encoder layers
+ for encoder_layer in vision_model.encoder.layer:
+ encoder_layer: InternVLVisionLayer
+ if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+ _patch_layer_norm_module(encoder_layer.layernorm_before)
+ if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+ _patch_layer_norm_module(encoder_layer.layernorm_after)
+
+
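A rough illustration of the instance-patching path described above (the checkpoint name is a placeholder, and the re-export from liger_kernel.transformers is assumed):

from transformers import AutoModelForImageTextToText
from liger_kernel.transformers import apply_liger_kernel_to_internvl  # assumed re-export

model = AutoModelForImageTextToText.from_pretrained("OpenGVLab/InternVL3-1B-hf")  # placeholder checkpoint
# Patching after loading walks the instantiated modules, and also dispatches to the
# apply_* function of the connected language model via MODEL_TYPE_TO_APPLY_LIGER_FN.
apply_liger_kernel_to_internvl(model=model, rms_norm=True, layer_norm=True)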
+ def apply_liger_kernel_to_smolvlm(
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ layer_norm: bool = True,
+ model: Optional[PreTrainedModel] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
+ Due to the characteristics of SmolVLM, the model instance must be passed in so that Liger-Kernel can also patch the language model connected to SmolVLM.
+ However, if the connected LM is not supported by Liger-Kernel, unexpected side effects may occur.
+ NOTE: SmolVLM is not available in transformers<4.50.0
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.smolvlm import modeling_smolvlm
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+
+ from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
+
+ # Patch LayerNorm for vision model if model is not provided (pre-initialization)
+ if layer_norm and model is None:
+ modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
+
+ if cross_entropy:
+ logger.info("Apply liger cross entropy")
+
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(smolvlm_lce_forward, model)
+ else:
+ modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
+ if rms_norm:
+ modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, SmolVLMForConditionalGeneration):
+ text_model = model.model.text_model
+ vision_model: SmolVLMVisionTransformer = model.model.vision_model
+ elif isinstance(model, SmolVLMModel):
+ text_model = model.text_model
+ vision_model: SmolVLMVisionTransformer = model.vision_model
+ else:
+ raise TypeError(
+ f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration` or `SmolVLMModel`. Got: {type(model)}"
+ )
+
+ text_model_name = model.config.text_config.model_type
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+ if text_liger_fn:
+ accept_params = inspect.signature(text_liger_fn).parameters
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+ if remain_params:
+ logger.warning(
+ f"These parameters are not supported by {text_model_name}: {list(remain_params)}. Only the supported parameters {list(text_kwargs.keys())} will be applied.\n"
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+ )
+ text_kwargs["model"] = text_model
+ text_liger_fn(**text_kwargs)
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+ # Patch vision model LayerNorm layers
+ if layer_norm:
+ # Patch post_layernorm
+ _patch_layer_norm_module(vision_model.post_layernorm)
+
+ # Patch encoder layers
+ for encoder_layer in vision_model.encoder.layers:
+ encoder_layer: SmolVLMEncoderLayer
+ _patch_layer_norm_module(encoder_layer.layer_norm1)
+ _patch_layer_norm_module(encoder_layer.layer_norm2)
+
+
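As a sketch of the kwargs-forwarding behavior above (hypothetical values; `loaded_model` stands in for an already-loaded SmolVLM instance and is not defined in this diff):

from liger_kernel.transformers import apply_liger_kernel_to_smolvlm  # assumed re-export

# Extra keyword arguments such as swiglu=True are forwarded to the apply_* function
# of the connected text model whenever they appear in that function's signature.
apply_liger_kernel_to_smolvlm(model=loaded_model, rms_norm=True, layer_norm=True, swiglu=True)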
+ def apply_liger_kernel_to_falcon_h1(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models.
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False. Currently has no effect for Falcon-H1.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.falcon_h1 import modeling_falcon_h1
+ from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
+
+ if rope:
+ logger.info("Apply liger rotary pos emb.")
+ modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if rms_norm:
+ logger.info("Apply liger RMSNorm")
+ modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
+ if swiglu:
+ logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
+
+ if cross_entropy:
+ logger.info("Apply liger cross entropy")
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(falcon_h1_lce_forward, model)
+ else:
+ modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules (e.g. FalconH1RMSNorm)
+
+ # get the base model from the model instance
+ base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.final_layernorm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
+
+
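An illustrative call for the Falcon-H1 entry point above (not part of the packaged code; the import assumes the usual re-export from liger_kernel.transformers):

from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1  # assumed re-export

# Patch the Falcon-H1 modeling code before loading the model; swiglu is left at its
# default (False) because the Liger SwiGLU MLP is not wired up for this architecture.
apply_liger_kernel_to_falcon_h1(rope=True, rms_norm=True, fused_linear_cross_entropy=True)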
+ def apply_liger_kernel_to_qwen3_next(
+ rope: bool = False,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3Next models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3_next import modeling_qwen3_next
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+ from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+ if rope:
+ # It might encounter NaN issues
+ # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
+ if rms_norm:
+ modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ if isinstance(model, Qwen3NextForCausalLM):
+ model.forward = MethodType(qwen3_next_lce_forward, model)
+ else:
+ raise TypeError(
+ f"fused_linear_cross_entropy is only applicable to Qwen3NextForCausalLM. Got: {type(model)}"
+ )
+ else:
+ modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
+ if swiglu:
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+ modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+ base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
+ else:
+ raise TypeError(
+ f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM` or `Qwen3NextModel`. Got: {type(model)}"
+ )
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+ if swiglu:
+ if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+ _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+ if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+ _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+ experts = getattr(decoder_layer.mlp, "experts", None)
+ if experts is not None:
+ for expert in experts:
+ _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
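A minimal, illustrative call for the Qwen3Next entry point above (not part of the packaged code; the import path is assumed):

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next  # assumed re-export

# rope must stay False (the function raises NotImplementedError otherwise); RMSNorm,
# SwiGLU and the fused linear cross entropy patches are applied by default.
apply_liger_kernel_to_qwen3_next()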
+ def apply_liger_kernel_to_hunyuan_v1_dense(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+ from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+ if rope:
+ modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(hunyuan_v1_lce_forward, model)
+ else:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+ if swiglu:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_hunyuan_v1_moe(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+ from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+ if rope:
+ modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+ else:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
+
+ if swiglu:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ for mlp_expert in decoder_layer.mlp.experts:
+ _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
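A minimal, illustrative call pattern for the two Hunyuan entry points above (not part of the packaged code; the import path is an assumption based on how the other apply_* helpers are exposed):

from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense  # assumed re-export
from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_moe  # assumed re-export

# Dense and MoE variants share the same flags; for the MoE variant, swiglu=True
# patches every expert MLP in each decoder layer of an already-loaded model.
apply_liger_kernel_to_hunyuan_v1_dense(rope=True, rms_norm=True)
apply_liger_kernel_to_hunyuan_v1_moe(rope=True, rms_norm=True, swiglu=True)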
+ # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
+ MODEL_TYPE_TO_APPLY_LIGER_FN = {
+ "gemma": apply_liger_kernel_to_gemma,
+ "gemma2": apply_liger_kernel_to_gemma2,
+ "gemma3_text": apply_liger_kernel_to_gemma3_text,
+ "gemma3": apply_liger_kernel_to_gemma3,
+ "glm4": apply_liger_kernel_to_glm4,
+ "glm4v": apply_liger_kernel_to_glm4v,
+ "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+ "gpt_oss": apply_liger_kernel_to_gpt_oss,
+ "internvl": apply_liger_kernel_to_internvl,
+ "llama": apply_liger_kernel_to_llama,
+ "llama4_text": apply_liger_kernel_to_llama4,
+ "llama4": apply_liger_kernel_to_llama4,
+ "llava": apply_liger_kernel_to_llava,
+ "granite": apply_liger_kernel_to_granite,
+ "mllama": apply_liger_kernel_to_mllama,
+ "mllama_text_model": apply_liger_kernel_to_mllama,
+ "mistral": apply_liger_kernel_to_mistral,
+ "mixtral": apply_liger_kernel_to_mixtral,
+ "olmo2": apply_liger_kernel_to_olmo2,
+ "olmo3": apply_liger_kernel_to_olmo3,
+ "qwen2": apply_liger_kernel_to_qwen2,
+ "qwen3": apply_liger_kernel_to_qwen3,
+ "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
+ "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+ "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
+ "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+ "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+ "qwen3_next": apply_liger_kernel_to_qwen3_next,
+ "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+ "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+ "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+ "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
+ "smollm3": apply_liger_kernel_to_smollm3,
+ "phi3": apply_liger_kernel_to_phi3,
+ "paligemma": apply_liger_kernel_to_paligemma,
+ "falcon_h1": apply_liger_kernel_to_falcon_h1,
+ "smolvlm": apply_liger_kernel_to_smolvlm,
+ "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+ "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
+ }
+
+
+ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
+ """
+ Applies Liger kernels based on the specified model type. The custom
+ kernels for the specified model type will be applied with the provided
+ keyword arguments; any option that is not provided falls back to its default value.
+
+ ** Note: Calling _apply_liger_kernel() after model initialization
+ will not fully patch the model. This must be called before model initialization.
+ If the model has already been instantiated, use _apply_liger_kernel_to_instance() instead.
+
+ Args:
+ - model_type: the model type as defined in transformers/models/auto/modeling_auto.py
+ and specified in the model's config.json
+ - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
+ """
+ if not model_type:
+ logger.info("Model type was not provided. No Liger kernels will be applied.")
+ return
+
+ if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
+ logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
+ return
+
+ apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+ apply_fn_signature = inspect.signature(apply_fn)
+
+ # Filter out the keyword arguments that are not supported by the apply function
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
+
+ logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")
+
+ # Assume this is invoked pre-model initialization, so we only need to patch transformers code
+ apply_fn(**applicable_kwargs)
+
+
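For illustration, a sketch of the model-type dispatch described above (hypothetical arguments; this is a private helper that is normally invoked by the package's integration layers rather than directly by end users):

# Patch the transformers "llama" modeling code before any model is constructed.
# Keyword arguments not present in apply_liger_kernel_to_llama's signature are filtered out.
_apply_liger_kernel("llama", rope=True, rms_norm=True, fused_linear_cross_entropy=True)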
+ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
+ """
+ Applies Liger kernels to the provided model instance.
+
+ Args:
+ - model: the model instance to apply Liger kernels to
+ - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
+ """
+ model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
+
+ if not model_type:
+ logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
+ return
+
+ if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
+ logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
+ return
+
+ apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+ apply_fn_signature = inspect.signature(apply_fn)
+
+ # Filter out the keyword arguments that are not supported by the apply function
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
+ logger.info(
+ f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
+ )
+
+ apply_fn(model=model, **applicable_kwargs)
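
And, for completeness, a sketch of the instance path (placeholder checkpoint; since this helper is private, user code would typically go through the public apply_liger_kernel_to_* functions or the AutoLigerKernelForCausalLM wrapper instead):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder checkpoint
# Dispatches to apply_liger_kernel_to_llama(model=model, rms_norm=True) based on config.model_type.
_apply_liger_kernel_to_instance(model, rms_norm=True)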