liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
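The diff below covers liger_kernel/transformers/monkey_patch.py (entry 92 above). The main behavioral change visible in it is instance-level patching: when a `model=` argument is passed, the `apply_liger_kernel_to_*` functions rebind `forward` on the existing object via `MethodType` and patch already-instantiated RMSNorm/SwiGLU/LayerNorm modules, instead of only swapping classes on the `modeling_*` modules. A minimal usage sketch, added here for illustration (the checkpoint name and dtype are placeholders, not part of the diff):

    # Illustrative sketch only; assumes liger-kernel and a CUDA-capable PyTorch install.
    import torch
    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B",  # placeholder checkpoint
        torch_dtype=torch.bfloat16,
    )

    # Passing `model=` takes the instance-level path added in this release:
    # forward is rebound with MethodType and existing norm/MLP submodules are patched in place.
    apply_liger_kernel_to_llama(
        rope=True,
        rms_norm=True,
        swiglu=True,
        cross_entropy=False,
        fused_linear_cross_entropy=True,
        model=model,
    )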
@@ -1,43 +1,51 @@
 import inspect
 import logging
+
 from functools import partial
+from types import MethodType
 from typing import Callable
+from typing import Optional

 import transformers
+
 from packaging import version
 from transformers import PreTrainedModel

 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
+from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
-from liger_kernel.transformers.model.gemma import (
-    lce_forward_deprecated as gemma_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
+from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
+from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
-from liger_kernel.transformers.model.llama import (
-    lce_forward_deprecated as llama_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
+from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
+from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
-from liger_kernel.transformers.model.mixtral import (
-    lce_forward_deprecated as mixtral_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import (
-    lce_forward_deprecated as phi3_lce_forward_deprecated,
-)
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
-from liger_kernel.transformers.model.qwen2 import (
-    lce_forward_deprecated as qwen2_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
+from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.swiglu import (
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
+
+try:
+    import peft
+
+    PEFT_AVAILABLE = True
+except ImportError:
+    PEFT_AVAILABLE = False

 transformer_version = version.parse(transformers.__version__)

@@ -51,23 +59,161 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)


-def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"):
-    module.offset = offset
-    module.casting_mode = casting_mode
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
-    _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
-    _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.modules_to_save.default.offset = offset
+        module.modules_to_save.default.casting_mode = casting_mode
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.modules_to_save.default.in_place = in_place
+        module.modules_to_save.default.row_mode = row_mode
+        module.original_module.offset = offset
+        module.original_module.casting_mode = casting_mode
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.in_place = in_place
+        module.original_module.row_mode = row_mode
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
+    else:
+        module.offset = offset
+        module.casting_mode = casting_mode
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.in_place = in_place
+        module.row_mode = row_mode
+        _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+        _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)


 def _patch_layer_norm_module(module, eps=1e-6):
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+    # Check if the module is a PEFT ModulesToSaveWrapper
+    # If it is, we need to patch the modules_to_save.default and original_modules
+    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+        module.hidden_size = module.normalized_shape
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        module.modules_to_save.default.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        module.original_module.variance_epsilon = (
+            getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        )
+        module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+            module, "normalized_shape", None
+        )
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
+    else:
+        module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+        module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+        _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
+
+
+def _patch_swiglu_module(module, liger_module):
+    _bind_method_to_module(module, "forward", liger_module.forward)
+    _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)
+
+
+def _patch_geglu_module(module):
+    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+    _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
+
+
+def apply_liger_kernel_to_granite(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+
+
+
+    Debugging notes:
+    If LigerSwiGLUMLP is OK for Llama, it should be fine for Granite, but it's not.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
-    module.hidden_size = module.normalized_shape
-    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
-    _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+
+    from transformers.models.granite import modeling_granite
+    from transformers.models.granite.modeling_granite import GraniteModel
+
+    if swiglu:
+        modeling_granite.GraniteMLP = LigerSwiGLUMLP
+
+    if rms_norm:
+        modeling_granite.GraniteRMSNorm = LigerRMSNorm
+
+    if rope:
+        modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_granite.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")
+        # NOTE: Granite model `GraniteForCausalLM.forward` scales logits each
+        # call, so we can't sidestep logit materialization. A bit more work
+        # would be needed to add a scaling term to the `LigerFusedLinearCrossEntropyFunction`
+        # for the logit output.
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP)
+
+        # get the base model from the model instance
+        base_model: GraniteModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


 def apply_liger_kernel_to_llama(
@@ -94,11 +240,12 @@ def apply_liger_kernel_to_llama(
             loaded. Default is None.
     """

-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.llama import modeling_llama
+    from transformers.models.llama.modeling_llama import LlamaModel

     if rope:
         modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -106,42 +253,295 @@
         modeling_llama.LlamaRMSNorm = LigerRMSNorm
     if swiglu:
         modeling_llama.LlamaMLP = LigerSwiGLUMLP
+
     if cross_entropy:
-        modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)

-        if hasattr(model, "model"):
-            # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
-            base_model = model.model
-        elif hasattr(model, "transformer"):
-            # LlamaForQuestionAnswering uses "transformer" instead of "model"
-            base_model = model.transformer
+        # get the base model from the model instance
+        base_model: LlamaModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_smollm3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smollm3 import modeling_smollm3
+    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+    if rope:
+        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smollm3_lce_forward, model)
         else:
-            # Direct LlamaModel
-            base_model = model
+            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+        # get the base model from the model instance
+        base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_llava(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    model: PreTrainedModel = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llava models.
+    Due to the characteristics of LlaVa, the model must be passed to apply Liger-Kernel's patch to other models connected to LLaVa.
+    However, if an LM not supported by Liger-Kernel is connected to LLaVa, unexpected side effects may occur.
+    NOTE: Llava is not available in transformers<4.36.0
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llava import modeling_llava
+
+    if cross_entropy:
+        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+        modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+        else:  # if version < 4.49.0
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+
+    if model is not None:
+        text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = model.language_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        if vision_liger_fn:
+            accept_params = inspect.signature(vision_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+                )
+            vision_kwargs["model"] = model.vision_tower
+            vision_liger_fn(**vision_kwargs)
+        elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
+def apply_liger_kernel_to_llama4(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+        apply_liger_llama4_rope_full(modeling_llama4)
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    if decoder_layer.is_moe_layer:
+                        _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+                    else:
+                        _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -168,39 +568,47 @@ def apply_liger_kernel_to_mllama(
             loaded. Default is None.
     """

-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.mllama import modeling_mllama
-    from transformers.models.mllama.modeling_mllama import (
-        MllamaForCausalLM,
-        MllamaForConditionalGeneration,
-        MllamaTextModel,
-        MllamaVisionModel,
-    )
+    from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
+    from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
+    from transformers.models.mllama.modeling_mllama import MllamaTextModel
+    from transformers.models.mllama.modeling_mllama import MllamaVisionModel

     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
-    from liger_kernel.transformers.model.mllama import (
-        lce_forward_deprecated as mllama_lce_forward_deprecated,
-    )
+    from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated

     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
     if swiglu:
         modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
     if cross_entropy:
-        modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -209,13 +617,17 @@ def apply_liger_kernel_to_mllama(
         if isinstance(model, MllamaForConditionalGeneration):
             language_model: MllamaForCausalLM = model.language_model
             vision_model: MllamaVisionModel = model.vision_model
-            text_model: MllamaTextModel = language_model.model
+            if isinstance(language_model, MllamaForCausalLM):
+                text_model: MllamaTextModel = language_model.model
+            else:
+                text_model = language_model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
             vision_model = None
         elif isinstance(model, MllamaTextModel):
             text_model = model
             vision_model = None
+
         else:
             raise ValueError(f"Unsupported Mllama model type: {type(model)}")

@@ -224,9 +636,7 @@ def apply_liger_kernel_to_mllama(
                 _patch_rms_norm_module(text_model.norm)
             for decoder_layer in text_model.layers:
                 if swiglu:
-                    _bind_method_to_module(
-                        decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                    )
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
                 if rms_norm:
                     _patch_rms_norm_module(decoder_layer.input_layernorm)
                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -258,7 +668,7 @@ def apply_liger_kernel_to_mistral(
     Apply Liger kernels to replace original implementation in HuggingFace Mistral models

     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
@@ -270,11 +680,12 @@ def apply_liger_kernel_to_mistral(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.mistral import modeling_mistral
+    from transformers.models.mistral.modeling_mistral import MistralModel

     if rope:
         modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -283,7 +694,17 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        if transformer_version >= version.parse("4.49.0"):
+            if model is not None:
+                model.forward = MethodType(mistral_lce_forward, model)
+            else:
+                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        else:
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP

@@ -291,21 +712,15 @@ def apply_liger_kernel_to_mistral(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            # The case for MistralForCausalLM, MistralForTokenClassification for example
-            base_model = model.model
-        else:
-            # Direct MistralModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: MistralModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -335,24 +750,38 @@ def apply_liger_kernel_to_mixtral(
             loaded. Default is None.
     """

-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.mixtral import modeling_mixtral
+    from transformers.models.mixtral.modeling_mixtral import MixtralModel

     if rope:
         modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
         modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
     if cross_entropy:
-        modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -360,12 +789,8 @@ def apply_liger_kernel_to_mixtral(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            # The case for MixtralForCausalLM, MixtralForTokenClassification for example
-            base_model = model.model
-        else:
-            # Direct MixtralModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: MixtralModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -373,9 +798,7 @@ def apply_liger_kernel_to_mixtral(
         for decoder_layer in base_model.layers:
             if swiglu:
                 for expert in decoder_layer.block_sparse_moe.experts:
-                    _bind_method_to_module(
-                        expert, "forward", LigerBlockSparseTop2MLP.forward
-                    )
+                    _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -405,54 +828,57 @@ def apply_liger_kernel_to_gemma(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )

     from transformers.models.gemma import modeling_gemma
+    from transformers.models.gemma.modeling_gemma import GemmaModel

-    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(
-        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-    )
-    _patch_rms_norm_module_for_gemma = partial(
-        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
-    )
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
+    _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
         modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
     if cross_entropy:
-        modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
     if geglu:
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            # The case for GemmaForCausalLM, GemmaForTokenClassification for example
-            base_model = model.model
-        else:
-            # Direct GemmaModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: GemmaModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module_for_gemma(base_model.norm)

         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-                )
+                _patch_geglu_module(decoder_layer.mlp)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -460,7 +886,8 @@ def apply_liger_kernel_to_gemma(


 def apply_liger_kernel_to_gemma2(
     rope: bool = True,
-    cross_entropy: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     geglu: bool = True,
@@ -471,65 +898,1107 @@ def apply_liger_kernel_to_gemma2(
471
898
 
472
899
  Args:
473
900
  rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
474
- cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
475
- rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
476
- geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
901
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
902
+ fused_linear_cross_entropy (bool):
903
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
904
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
905
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
906
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
907
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
908
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
909
+ loaded. Default is None.
910
+ """
911
+ assert not (cross_entropy and fused_linear_cross_entropy), (
912
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
913
+ )
914
+
915
+ from transformers.models.gemma2 import modeling_gemma2
916
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
917
+
918
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
919
+
920
+ _patch_rms_norm_module_for_gemma2 = partial(
921
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
922
+ )
923
+
924
+ if rope:
925
+ modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
926
+ if rms_norm:
927
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
928
+ modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
929
+ if cross_entropy:
930
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
931
+ from transformers.loss.loss_utils import nn
932
+
933
+ nn.functional.cross_entropy = liger_cross_entropy
934
+ else:
935
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
936
+ modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
937
+ if fused_linear_cross_entropy:
938
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
939
+ if model is not None:
940
+ model.forward = MethodType(gemma2_lce_forward, model)
941
+ else:
942
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
943
+ else:
944
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
945
+ if model is not None:
946
+ model.forward = MethodType(gemma2_lce_forward_deprected, model)
947
+ else:
948
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
949
+ if geglu:
950
+ modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
951
+
952
+ if model is not None:
953
+ # The model instance already exists, so we need to additionally patch the
954
+ # instance variables that reference already-instantiated modules
955
+
956
+ # get the base model from the model instance
957
+ base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)
958
+
959
+ if rms_norm:
960
+ _patch_rms_norm_module_for_gemma2(base_model.norm)
961
+
962
+ for decoder_layer in base_model.layers:
963
+ if geglu:
964
+ _patch_geglu_module(decoder_layer.mlp)
965
+ if rms_norm:
966
+ _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
967
+ _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
968
+ _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
969
+ _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
970
+
971
+
972
+ def apply_liger_kernel_to_gemma3_text(
973
+ rope: bool = True,
974
+ cross_entropy: bool = False,
975
+ fused_linear_cross_entropy: bool = True,
976
+ rms_norm: bool = True,
977
+ geglu: bool = True,
978
+ model: PreTrainedModel = None,
979
+ ) -> None:
980
+ """
981
+ Apply Liger kernels to replace original implementation in HuggingFace Gemma3
982
+
983
+ Args:
984
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
985
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
986
+ fused_linear_cross_entropy (bool):
987
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
988
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
989
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
990
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
991
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
992
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
993
+ loaded. Default is None.
994
+ """
995
+ assert not (cross_entropy and fused_linear_cross_entropy), (
996
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
997
+ )
998
+
999
+ from transformers.models.gemma3 import modeling_gemma3
1000
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
1001
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
1002
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
1003
+
1004
+ from liger_kernel.transformers.model.gemma3 import causal_forward
1005
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
1006
+
1007
+ _patch_rms_norm_module_for_gemma3 = partial(
1008
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
1009
+ )
1010
+
1011
+ if rope:
1012
+ modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
1013
+
1014
+ if rms_norm:
1015
+ modeling_gemma3.Gemma3RMSNorm = LigerRMSNormForGemma3
1016
+
1017
+ if geglu:
1018
+ modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
1019
+
1020
+ # Handle loss function
1021
+ if cross_entropy:
1022
+ from transformers.loss.loss_utils import nn
1023
+
1024
+ nn.functional.cross_entropy = liger_cross_entropy
1025
+
1026
+ if fused_linear_cross_entropy:
1027
+ if model is not None:
1028
+ model.forward = MethodType(causal_forward, model)
1029
+ else:
1030
+ modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
1031
+
1032
+ if model is not None:
1033
+ # The model instance already exists, so we need to additionally patch the
1034
+ # instance variables that reference already-instantiated modules
1035
+
1036
+ if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
1037
+ # get the base model from the model instance
1038
+ base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
1039
+
1040
+ if rms_norm:
1041
+ _patch_rms_norm_module_for_gemma3(base_model.norm)
1042
+
1043
+ for decoder_layer in base_model.layers:
1044
+ decoder_layer: Gemma3DecoderLayer
1045
+ if geglu:
1046
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
1047
+ if rms_norm:
1048
+ _patch_rms_norm_module_for_gemma3(decoder_layer.input_layernorm)
1049
+ _patch_rms_norm_module_for_gemma3(decoder_layer.post_attention_layernorm)
1050
+ _patch_rms_norm_module_for_gemma3(decoder_layer.pre_feedforward_layernorm)
1051
+ _patch_rms_norm_module_for_gemma3(decoder_layer.post_feedforward_layernorm)
1052
+ _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.q_norm)
1053
+ _patch_rms_norm_module_for_gemma3(decoder_layer.self_attn.k_norm)
1054
+
1055
+ else:
1056
+ raise TypeError("The model must be Gemma3ForCausalLM.")
1057
+
1058
+
1059
+ def apply_liger_kernel_to_gemma3(
1060
+ rope: bool = True,
1061
+ cross_entropy: bool = False,
1062
+ fused_linear_cross_entropy: bool = True,
1063
+ layer_norm: bool = True,
1064
+ rms_norm: bool = True,
1065
+ geglu: bool = True,
1066
+ model: PreTrainedModel = None,
1067
+ ) -> None:
1068
+ """
1069
+ Apply Liger kernels to replace original implementation in HuggingFace Gemma3
1070
+
1071
+ Args:
1072
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1073
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1074
+ fused_linear_cross_entropy (bool):
1075
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1076
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1077
+ If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
1078
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
1079
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1080
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
1081
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1082
+ loaded. Default is None.
1083
+ """
1084
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1085
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1086
+ )
1087
+
1088
+ from transformers.models.gemma3 import modeling_gemma3
1089
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
1090
+ from transformers.models.siglip import modeling_siglip
1091
+ from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
1092
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
1093
+
1094
+ from liger_kernel.transformers.model.gemma3 import multimodal_forward
1095
+
1096
+ _patch_rms_norm_module_for_gemma3 = partial(
1097
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
1098
+ )
1099
+
1100
+ if layer_norm and model is None:
1101
+ modeling_siglip.nn.LayerNorm = LigerLayerNorm
1102
+
1103
+ apply_liger_kernel_to_gemma3_text(
1104
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
1105
+ )
1106
+
1107
+ if cross_entropy:
1108
+ modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
1109
+
1110
+ if fused_linear_cross_entropy:
1111
+ if model is not None:
1112
+ model.forward = MethodType(multimodal_forward, model)
1113
+ else:
1114
+ modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
1115
+
1116
+ if model is not None:
1117
+ # The model instance already exists, so we need to additionally patch the
1118
+ # instance variables that reference already-instantiated modules
1119
+
1120
+ if isinstance(model, Gemma3ForConditionalGeneration):
1121
+ if isinstance(model.vision_tower, SiglipVisionModel):
1122
+ vision_tower = model.vision_tower
1123
+
1124
+ _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
1125
+
1126
+ for layer in vision_tower.vision_model.encoder.layers:
1127
+ layer: SiglipEncoderLayer
1128
+ if layer_norm:
1129
+ _patch_layer_norm_module(layer.layer_norm1)
1130
+ _patch_layer_norm_module(layer.layer_norm2)
1131
+ else:
1132
+ raise TypeError("The vision tower must be SiglipVisionModel")
1133
+
1134
+ if rms_norm:
1135
+ _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
1136
+
1137
+ apply_liger_kernel_to_gemma3_text(
1138
+ rope=rope,
1139
+ cross_entropy=False,
1140
+ fused_linear_cross_entropy=False,
1141
+ rms_norm=rms_norm,
1142
+ geglu=geglu,
1143
+ model=model.language_model,
1144
+ )
1145
+
1146
+ else:
1147
+ raise TypeError("The model must be Gemma3ForConditionalGeneration.")
1148
+
1149
+
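Usage note (not part of the diff): the multimodal variant also swaps SigLIP's `nn.LayerNorm` at the class level, but only when no `model` instance is passed; with a loaded model it instead walks the vision tower and patches each encoder layer's norms. A hedged sketch of the instance path, with an illustrative checkpoint id:

```python
from transformers import Gemma3ForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3

model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")  # illustrative id
# Instance path: SigLIP LayerNorms and the Gemma3 text RMSNorms are patched on the
# existing modules; the class-level nn.LayerNorm swap is skipped because model is not None.
apply_liger_kernel_to_gemma3(model=model, layer_norm=True, rms_norm=True, geglu=True)
```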
1150
+ def apply_liger_kernel_to_paligemma(
1151
+ rope: bool = True,
1152
+ cross_entropy: bool = False,
1153
+ fused_linear_cross_entropy: bool = True,
1154
+ layer_norm: bool = True,
1155
+ rms_norm: bool = True,
1156
+ geglu: bool = True,
1157
+ model: PreTrainedModel = None,
1158
+ ) -> None:
1159
+ """
1160
+ Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
1161
+
1162
+ Args:
1163
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1164
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1165
+ fused_linear_cross_entropy (bool):
1166
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1167
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1168
+ If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
1169
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
1170
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1171
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
1172
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1173
+ loaded. Default is None.
1174
+ """
1175
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1176
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1177
+ )
1178
+
1179
+ # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
1180
+
1181
+ from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
1182
+ from transformers.models.gemma.modeling_gemma import GemmaModel
1183
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
1184
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
1185
+ from transformers.models.paligemma import modeling_paligemma
1186
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
1187
+ from transformers.models.siglip import modeling_siglip
1188
+ from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
1189
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
1190
+
1191
+ from liger_kernel.transformers.model.paligemma import lce_forward
1192
+ from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
1193
+
1194
+ # The vision_tower is a SiglipVisionModel
1195
+ if layer_norm and model is None:
1196
+ modeling_siglip.nn.LayerNorm = LigerLayerNorm
1197
+
1198
+ # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
1199
+ # The multi_modal_projector is Linear, nothing to do
1200
+
1201
+ # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
1202
+ apply_liger_kernel_to_gemma(
1203
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
1204
+ )
1205
+ apply_liger_kernel_to_gemma2(
1206
+ rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
1207
+ )
1208
+ # Handle loss function
1209
+ if cross_entropy:
1210
+ modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
1211
+ if fused_linear_cross_entropy:
1212
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1213
+ if model is not None:
1214
+ model.forward = MethodType(lce_forward, model)
1215
+ else:
1216
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
1217
+ else: # if version < 4.46.1
1218
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1219
+ if model is not None:
1220
+ model.forward = MethodType(lce_forward_deprecated, model)
1221
+ else:
1222
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
1223
+
1224
+ if model is not None:
1225
+ # The model instance already exists, so we need to additionally patch the
1226
+ # instance variables that reference already-instantiated modules
1227
+
1228
+ if not isinstance(model, PaliGemmaForConditionalGeneration):
1229
+ raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
1230
+
1231
+ vision_tower: SiglipVisionModel = model.vision_tower
1232
+
1233
+ _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
1234
+
1235
+ for layer in vision_tower.vision_model.encoder.layers:
1236
+ layer: SiglipEncoderLayer
1237
+ if layer_norm:
1238
+ _patch_layer_norm_module(layer.layer_norm1)
1239
+ _patch_layer_norm_module(layer.layer_norm2)
1240
+
1241
+ language_model = model.language_model
1242
+
1243
+ if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
1244
+ apply_liger_kernel_to_gemma(
1245
+ rope=rope,
1246
+ cross_entropy=False,
1247
+ fused_linear_cross_entropy=False,
1248
+ rms_norm=rms_norm,
1249
+ geglu=geglu,
1250
+ model=language_model,
1251
+ )
1252
+
1253
+ elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
1254
+ apply_liger_kernel_to_gemma2(
1255
+ rope=rope,
1256
+ cross_entropy=False,
1257
+ fused_linear_cross_entropy=False,
1258
+ rms_norm=rms_norm,
1259
+ geglu=geglu,
1260
+ model=language_model,
1261
+ )
1262
+ else:
1263
+ raise TypeError(
1264
+ "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
1265
+ )
1266
+
1267
+
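Usage note (not part of the diff): as in the other patch functions, `cross_entropy` and `fused_linear_cross_entropy` are mutually exclusive and the assert rejects passing both. A sketch of opting into the plain Liger cross entropy instead of the fused linear variant:

```python
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma

# Use LigerCrossEntropyLoss and keep the default forward (logits are materialized).
apply_liger_kernel_to_paligemma(cross_entropy=True, fused_linear_cross_entropy=False)

# This combination trips the assertion at the top of the function:
# apply_liger_kernel_to_paligemma(cross_entropy=True, fused_linear_cross_entropy=True)
```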
1268
+ def apply_liger_kernel_to_qwen2(
1269
+ rope: bool = True,
1270
+ cross_entropy: bool = False,
1271
+ fused_linear_cross_entropy: bool = True,
1272
+ rms_norm: bool = True,
1273
+ swiglu: bool = True,
1274
+ model: PreTrainedModel = None,
1275
+ ) -> None:
1276
+ """
1277
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
1278
+
1279
+ Args:
1280
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1281
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1282
+ fused_linear_cross_entropy (bool):
1283
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1284
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1285
+ If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
1286
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1287
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1288
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1289
+ loaded. Default is None.
1290
+ """
1291
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1292
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1293
+ )
1294
+
1295
+ from transformers.models.qwen2 import modeling_qwen2
1296
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
1297
+
1298
+ if rope:
1299
+ modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
1300
+ if rms_norm:
1301
+ modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
1302
+
1303
+ if cross_entropy:
1304
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1305
+ from transformers.loss.loss_utils import nn
1306
+
1307
+ nn.functional.cross_entropy = liger_cross_entropy
1308
+ else:
1309
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1310
+ modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
1311
+
1312
+ if fused_linear_cross_entropy:
1313
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1314
+ if model is not None:
1315
+ model.forward = MethodType(qwen2_lce_forward, model)
1316
+ else:
1317
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
1318
+ else: # if version < 4.46.1
1319
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1320
+ if model is not None:
1321
+ model.forward = MethodType(qwen2_lce_forward_deprecated, model)
1322
+ else:
1323
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
1324
+
1325
+ if swiglu:
1326
+ modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
1327
+
1328
+ if model is not None:
1329
+ # The model instance already exists, so we need to additionally patch the
1330
+ # instance variables that reference already-instantiated modules
1331
+
1332
+ # get the base model from the model instance
1333
+ base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)
1334
+
1335
+ if rms_norm:
1336
+ _patch_rms_norm_module(base_model.norm)
1337
+
1338
+ for decoder_layer in base_model.layers:
1339
+ if swiglu:
1340
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1341
+ if rms_norm:
1342
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1343
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1344
+
1345
+
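Usage note (not part of the diff): the class-level assignments (`Qwen2RMSNorm`, `Qwen2MLP`, `Qwen2ForCausalLM.forward`) only affect modules created after the patch, so the usual pattern is to call the function before `from_pretrained`; for a model already in memory, the `model` argument covers the instance-level patching shown above. A hedged sketch with an illustrative checkpoint id:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2

# Patch the Qwen2 classes before weights are loaded so every new module uses Liger kernels.
apply_liger_kernel_to_qwen2(rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")  # illustrative id
```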
1346
+ def apply_liger_kernel_to_qwen3(
1347
+ rope: bool = True,
1348
+ cross_entropy: bool = False,
1349
+ fused_linear_cross_entropy: bool = True,
1350
+ rms_norm: bool = True,
1351
+ swiglu: bool = True,
1352
+ model: PreTrainedModel = None,
1353
+ ) -> None:
1354
+ """
1355
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
1356
+ """
1357
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1358
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1359
+ )
1360
+
1361
+ from transformers.models.qwen3 import modeling_qwen3
1362
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
1363
+
1364
+ from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward
1365
+
1366
+ if rope:
1367
+ modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
1368
+
1369
+ if rms_norm:
1370
+ modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
1371
+
1372
+ if cross_entropy:
1373
+ from transformers.loss.loss_utils import nn
1374
+
1375
+ nn.functional.cross_entropy = liger_cross_entropy
1376
+
1377
+ if fused_linear_cross_entropy:
1378
+ if model is not None:
1379
+ model.forward = MethodType(qwen3_lce_forward, model)
1380
+ else:
1381
+ modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
1382
+
1383
+ if swiglu:
1384
+ modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
1385
+
1386
+ if model is not None:
1387
+ # The model instance already exists, so we need to additionally patch the
1388
+ # instance variables that reference already-instantiated modules
1389
+
1390
+ # get the base model from the model instance
1391
+ base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)
1392
+
1393
+ if rms_norm:
1394
+ _patch_rms_norm_module(base_model.norm)
1395
+ for decoder_layer in base_model.layers:
1396
+ if swiglu:
1397
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1398
+ if rms_norm:
1399
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1400
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1401
+
1402
+
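Usage note (not part of the diff): the `cross_entropy=True` branch above rebinds `nn.functional.cross_entropy` as imported by `transformers.loss.loss_utils`, so that substitution is process-wide rather than Qwen3-specific. A sketch of the instance-patching path, checkpoint id illustrative:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")  # illustrative id
# Instance path: base_model.norm and each decoder layer's RMSNorms / SwiGLU MLP are patched in place.
apply_liger_kernel_to_qwen3(model=model)
```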
1403
+ def apply_liger_kernel_to_qwen3_moe(
1404
+ rope: bool = True,
1405
+ cross_entropy: bool = False,
1406
+ fused_linear_cross_entropy: bool = True,
1407
+ rms_norm: bool = True,
1408
+ swiglu: bool = True,
1409
+ model: PreTrainedModel = None,
1410
+ ) -> None:
1411
+ """
1412
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3 MoE models.
1413
+ """
1414
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1415
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1416
+ )
1417
+
1418
+ from transformers.models.qwen3_moe import modeling_qwen3_moe
1419
+ from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
1420
+
1421
+ from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
1422
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
1423
+
1424
+ if rope:
1425
+ modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
1426
+
1427
+ if rms_norm:
1428
+ modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
1429
+
1430
+ if cross_entropy:
1431
+ from transformers.loss.loss_utils import nn
1432
+
1433
+ nn.functional.cross_entropy = liger_cross_entropy
1434
+
1435
+ if fused_linear_cross_entropy:
1436
+ if model is not None:
1437
+ model.forward = MethodType(qwen3_lce_forward, model)
1438
+ else:
1439
+ modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
1440
+
1441
+ if swiglu:
1442
+ modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
1443
+
1444
+ if model is not None:
1445
+ # The model instance already exists, so we need to additionally patch the
1446
+ # instance variables that reference already-instantiated modules
1447
+
1448
+ # get the base model from the model instance
1449
+ base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
1450
+
1451
+ if rms_norm:
1452
+ _patch_rms_norm_module(base_model.norm)
1453
+ for decoder_layer in base_model.layers:
1454
+ if swiglu:
1455
+ for mlp_expert in decoder_layer.mlp.experts:
1456
+ _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
1457
+ if rms_norm:
1458
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1459
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1460
+
1461
+
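Usage note (not part of the diff): for the MoE variant the SwiGLU patch is applied per expert inside each decoder layer's `mlp.experts`, using `LigerQwen3MoeSwiGLUMLP`, rather than to a single dense MLP. A hedged sketch of patching an already-loaded model, checkpoint id illustrative:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")  # illustrative id
apply_liger_kernel_to_qwen3_moe(model=model, rms_norm=True, swiglu=True)
# After this call, each decoder_layer.mlp.experts[i] runs the Liger SwiGLU kernel.
```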
1462
+ def apply_liger_kernel_to_qwen2_vl(
1463
+ rope: bool = True,
1464
+ cross_entropy: bool = False,
1465
+ fused_linear_cross_entropy: bool = True,
1466
+ rms_norm: bool = True,
1467
+ layer_norm: bool = True,
1468
+ swiglu: bool = True,
1469
+ model: PreTrainedModel = None,
1470
+ ) -> None:
1471
+ """
1472
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
1473
+ NOTE: Qwen2-VL is not supported in transformers<4.52.4
1474
+
1475
+ Args:
1476
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1477
+ fused_linear_cross_entropy (bool):
1478
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1479
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1480
+ If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
1481
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1482
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
1483
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1484
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1485
+ loaded. Default is None.
1486
+ """
1487
+ if transformer_version < version.parse("4.52.4"):
1488
+ logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
1489
+ return
1490
+
1491
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1492
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1493
+ )
1494
+
1495
+ from transformers.models.qwen2_vl import modeling_qwen2_vl
1496
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
1497
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
1498
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
1499
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
1500
+
1501
+ from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
1502
+
1503
+ if rope:
1504
+ modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
1505
+ if rms_norm:
1506
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
1507
+ modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
1508
+ if layer_norm and model is None:
1509
+ modeling_qwen2_vl.LayerNorm = LigerLayerNorm
1510
+ if cross_entropy:
1511
+ modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
1512
+ if fused_linear_cross_entropy:
1513
+ if model is not None:
1514
+ model.forward = MethodType(qwen2_vl_lce_forward, model)
1515
+ else:
1516
+ modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
1517
+ if swiglu:
1518
+ modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
1519
+
1520
+ if model is not None:
1521
+ # The model instance already exists, so we need to additionally patch the
1522
+ # instance variables that reference already-instantiated modules
1523
+
1524
+ if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
1525
+ # Note: language_model and visual properties can be accessed through the conditional class for BC.
1526
+ # Not sure if it is subject to changes in the future.
1527
+ # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
1528
+ text_model: Qwen2VLTextModel = model.language_model
1529
+ vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
1530
+ elif isinstance(model, Qwen2VLTextModel):
1531
+ text_model: Qwen2VLTextModel = model
1532
+ vision_model = None
1533
+ else:
1534
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
1535
+ raise TypeError(
1536
+ f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
1537
+ )
1538
+
1539
+ # Patch Qwen2VisionTransformerPretrainedModel
1540
+ if vision_model is not None:
1541
+ for vision_block in vision_model.blocks:
1542
+ if layer_norm:
1543
+ _patch_layer_norm_module(vision_block.norm1)
1544
+ _patch_layer_norm_module(vision_block.norm2)
1545
+
1546
+ # Patch Qwen2VLTextModel
1547
+ if text_model is not None:
1548
+ if rms_norm:
1549
+ _patch_rms_norm_module(text_model.norm)
1550
+ for decoder_layer in text_model.layers:
1551
+ if swiglu:
1552
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1553
+ if rms_norm:
1554
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1555
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1556
+
1557
+
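Usage note (not part of the diff): the transformers version guard above makes this patch a no-op (warning plus early return) on older installations, so calling it is safe even when the installed transformers is too old. A hedged sketch, model id illustrative:

```python
from transformers import Qwen2VLForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl

# On transformers < 4.52.4 this only logs a warning and returns without patching anything.
apply_liger_kernel_to_qwen2_vl(fused_linear_cross_entropy=True, rms_norm=True, layer_norm=True)
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")  # illustrative id
```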
1558
+ def apply_liger_kernel_to_qwen2_5_vl(
1559
+ rope: bool = True,
1560
+ cross_entropy: bool = False,
1561
+ fused_linear_cross_entropy: bool = True,
1562
+ rms_norm: bool = True,
1563
+ swiglu: bool = True,
1564
+ model: PreTrainedModel = None,
1565
+ ) -> None:
1566
+ """
1567
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
1568
+ NOTE: Qwen2.5-VL is not available in transformers<4.48.2
1569
+
1570
+ Args:
1571
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1572
+ fused_linear_cross_entropy (bool):
1573
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1574
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1575
+ If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
1576
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1577
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1578
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1579
+ loaded. Default is None.
1580
+ """
1581
+ if transformer_version < version.parse("4.52.4"):
1582
+ logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
1583
+ return
1584
+
1585
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1586
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1587
+ )
1588
+
1589
+ from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
1590
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
1591
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
1592
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
1593
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
1594
+
1595
+ from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
1596
+
1597
+ if rope:
1598
+ modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
1599
+ if rms_norm:
1600
+ modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
1601
+ if cross_entropy:
1602
+ modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
1603
+ if fused_linear_cross_entropy:
1604
+ if model is not None:
1605
+ model.forward = MethodType(qwen2_5_vl_lce_forward, model)
1606
+ else:
1607
+ modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
1608
+ if swiglu:
1609
+ modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
1610
+
1611
+ if model is not None:
1612
+ # The model instance already exists, so we need to additionally patch the
1613
+ # instance variables that reference already-instantiated modules
1614
+
1615
+ if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
1616
+ # Note: language_model and visual properties can be accessed through the conditional class for BC.
1617
+ # Not sure if it is subject to changes in the future.
1618
+ # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
1619
+ text_model: Qwen2_5_VLTextModel = model.language_model
1620
+ vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
1621
+ elif isinstance(model, Qwen2_5_VLTextModel):
1622
+ text_model: Qwen2_5_VLTextModel = model
1623
+ vision_model = None
1624
+ else:
1625
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
1626
+ raise TypeError(
1627
+ f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
1628
+ )
1629
+
1630
+ if vision_model is not None:
1631
+ # Patch Qwen2_5_VisionTransformerPretrainedModel
1632
+ for vision_block in model.visual.blocks:
1633
+ if rms_norm:
1634
+ _patch_rms_norm_module(vision_block.norm1)
1635
+ _patch_rms_norm_module(vision_block.norm2)
1636
+
1637
+ if text_model is not None:
1638
+ if rms_norm:
1639
+ _patch_rms_norm_module(text_model.norm)
1640
+ for decoder_layer in text_model.layers:
1641
+ if swiglu:
1642
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1643
+ if rms_norm:
1644
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1645
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1646
+
1647
+
1648
+ def apply_liger_kernel_to_qwen3_vl(
1649
+ rope: bool = True,
1650
+ cross_entropy: bool = False,
1651
+ fused_linear_cross_entropy: bool = True,
1652
+ rms_norm: bool = True,
1653
+ swiglu: bool = False,
1654
+ model: PreTrainedModel = None,
1655
+ ) -> None:
1656
+ """
1657
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
1658
+
1659
+ Args:
1660
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1661
+ fused_linear_cross_entropy (bool):
1662
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1663
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1664
+ If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
1665
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1666
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1667
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1668
+ loaded. Default is None.
1669
+ """
1670
+
1671
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1672
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1673
+ )
1674
+
1675
+ from transformers.models.qwen3_vl import modeling_qwen3_vl
1676
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
1677
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
1678
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
1679
+
1680
+ from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
1681
+
1682
+ if rope:
1683
+ modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
1684
+ modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
1685
+
1686
+ if rms_norm:
1687
+ modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
1688
+
1689
+ if cross_entropy:
1690
+ from transformers.loss.loss_utils import nn
1691
+
1692
+ nn.functional.cross_entropy = liger_cross_entropy
1693
+
1694
+ if fused_linear_cross_entropy:
1695
+ if model is not None:
1696
+ model.forward = MethodType(qwen3_vl_lce_forward, model)
1697
+ else:
1698
+ modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
1699
+
1700
+ if model is not None and rms_norm:
1701
+ if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
1702
+ text_model: Qwen3VLTextModel = model.language_model
1703
+ elif isinstance(model, Qwen3VLTextModel):
1704
+ text_model = model
1705
+ else:
1706
+ raise TypeError(
1707
+ f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
1708
+ )
1709
+
1710
+ _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
1711
+
1712
+ if text_model is not None:
1713
+ _patch_qwen3_vl_rms_norm(text_model.norm)
1714
+ for decoder_layer in text_model.layers:
1715
+ _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
1716
+ _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
1717
+ self_attn = getattr(decoder_layer, "self_attn", None)
1718
+ if self_attn is not None:
1719
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
1720
+ _patch_qwen3_vl_rms_norm(self_attn.q_norm)
1721
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
1722
+ _patch_qwen3_vl_rms_norm(self_attn.k_norm)
1723
+
1724
+
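Usage note (not part of the diff): for Qwen3-VL the instance path above only re-patches RMSNorm modules (including attention `q_norm`/`k_norm` when present); `swiglu` defaults to False and there is no instance-level MLP patching. A hedged sketch:

```python
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl

# Class-level patching (RoPE, RMSNorm class, fused-linear-CE forward) before loading a model:
apply_liger_kernel_to_qwen3_vl(rope=True, rms_norm=True, fused_linear_cross_entropy=True)

# With an existing Qwen3VLForConditionalGeneration / Qwen3VLModel / Qwen3VLTextModel
# instance, only the already-created RMSNorm modules are patched:
# apply_liger_kernel_to_qwen3_vl(model=model)
```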
1725
+ def apply_liger_kernel_to_qwen3_vl_moe(
1726
+ rope: bool = True,
1727
+ cross_entropy: bool = False,
1728
+ fused_linear_cross_entropy: bool = True,
1729
+ rms_norm: bool = True,
1730
+ swiglu: bool = False,
1731
+ model: PreTrainedModel = None,
1732
+ ) -> None:
1733
+ """
1734
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
1735
+
1736
+ Args:
1737
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1738
+ fused_linear_cross_entropy (bool):
1739
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1740
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1741
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1742
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1743
+ loaded. Default is None.
1744
+ """
1745
+
1746
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1747
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1748
+ )
1749
+
1750
+ from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
1751
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
1752
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
1753
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
1754
+
1755
+ from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
1756
+
1757
+ if rope:
1758
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
1759
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
1760
+
1761
+ if rms_norm:
1762
+ modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
1763
+
1764
+ if cross_entropy:
1765
+ from transformers.loss.loss_utils import nn
1766
+
1767
+ nn.functional.cross_entropy = liger_cross_entropy
1768
+
1769
+ if fused_linear_cross_entropy:
1770
+ if model is not None:
1771
+ model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
1772
+ else:
1773
+ modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
1774
+
1775
+ if model is not None and rms_norm:
1776
+ if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
1777
+ text_model: Qwen3VLMoeTextModel = model.language_model
1778
+ elif isinstance(model, Qwen3VLMoeTextModel):
1779
+ text_model = model
1780
+ else:
1781
+ raise TypeError(
1782
+ f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
1783
+ )
1784
+
1785
+ _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
1786
+
1787
+ if text_model is not None:
1788
+ _patch_qwen3_vl_moe_rms_norm(text_model.norm)
1789
+ for decoder_layer in text_model.layers:
1790
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
1791
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
1792
+ self_attn = getattr(decoder_layer, "self_attn", None)
1793
+ if self_attn is not None:
1794
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
1795
+ _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
1796
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
1797
+ _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
1798
+
1799
+
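Usage note (not part of the diff): the MoE vision-language patcher mirrors `apply_liger_kernel_to_qwen3_vl`; with a loaded model, only RMSNorm modules in the text stack (layer norms plus `q_norm`/`k_norm`) are re-patched. A hedged sketch of the class-level path:

```python
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe

# Patch the Qwen3-VL MoE classes before constructing a model; rope, rms_norm and the
# fused-linear-CE forward are replaced at the module level.
apply_liger_kernel_to_qwen3_vl_moe(rope=True, rms_norm=True, fused_linear_cross_entropy=True)
```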
1800
+ def apply_liger_kernel_to_phi3(
1801
+ rope: bool = True,
1802
+ cross_entropy: bool = False,
1803
+ fused_linear_cross_entropy: bool = True,
1804
+ rms_norm: bool = True,
1805
+ swiglu: bool = True,
1806
+ model: PreTrainedModel = None,
1807
+ ) -> None:
1808
+ """
1809
+ Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
1810
+
1811
+ Args:
1812
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1813
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1814
+ fused_linear_cross_entropy (bool):
1815
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1816
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1817
+ If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
1818
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1819
+ swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
1820
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1821
+ loaded. Default is None.
1822
+ """
1823
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1824
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1825
+ )
1826
+
1827
+ from transformers.models.phi3 import modeling_phi3
1828
+ from transformers.models.phi3.modeling_phi3 import Phi3Model
1829
+
1830
+ if rope:
1831
+ modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
1832
+ if rms_norm:
1833
+ modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
1834
+ if swiglu:
1835
+ modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
1836
+ if cross_entropy:
1837
+ from transformers.loss.loss_utils import nn
1838
+
1839
+ nn.functional.cross_entropy = liger_cross_entropy
1840
+ if fused_linear_cross_entropy:
1841
+ if model is not None:
1842
+ model.forward = MethodType(phi3_lce_forward, model)
1843
+ else:
1844
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1845
+
1846
+ if model is not None:
1847
+ # The model instance already exists, so we need to additionally patch the
1848
+ # instance variables that reference already-instantiated modules
1849
+
1850
+ # get the base model from the model instance
1851
+ base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
1852
+
1853
+ if rms_norm:
1854
+ _patch_rms_norm_module(base_model.norm)
1855
+
1856
+ for decoder_layer in base_model.layers:
1857
+ if swiglu:
1858
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
1859
+ if rms_norm:
1860
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1861
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1862
+
1863
+
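Usage note (not part of the diff): Phi3's MLP fuses the gate and up projections into a single `gate_up_proj`, so the MLP is swapped for `LigerPhi3SwiGLUMLP` rather than the standard `LigerSwiGLUMLP`. A hedged usage sketch, model id illustrative:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3

apply_liger_kernel_to_phi3(rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")  # illustrative id
```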
1864
+ def apply_liger_kernel_to_olmo2(
1865
+ rope: bool = True,
1866
+ cross_entropy: bool = False,
1867
+ fused_linear_cross_entropy: bool = True,
1868
+ rms_norm: bool = True,
1869
+ swiglu: bool = True,
1870
+ model: PreTrainedModel = None,
1871
+ ) -> None:
1872
+ """
1873
+ Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
1874
+
1875
+ Args:
1876
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1877
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1878
+ fused_linear_cross_entropy (bool):
1879
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1880
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1881
+ If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
1882
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1883
+ swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
1884
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1885
+ loaded. Default is None.
1886
+ """
1887
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1888
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1889
+ )
1890
+
1891
+ from transformers.models.olmo2 import modeling_olmo2
1892
+ from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
1893
+
1894
+ from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
1895
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
1896
+
1897
+ if rope:
1898
+ modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
1899
+ if rms_norm:
1900
+ modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
1901
+ if swiglu:
1902
+ modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
1903
+ if cross_entropy:
1904
+ from transformers.loss.loss_utils import nn
1905
+
1906
+ nn.functional.cross_entropy = liger_cross_entropy
1907
+ if fused_linear_cross_entropy:
1908
+ if model is not None:
1909
+ model.forward = MethodType(olmo2_lce_forward, model)
1910
+ else:
1911
+ modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
1912
+
1913
+ if model is not None:
1914
+ # The model instance already exists, so we need to additionally patch the
1915
+ # instance variables that reference already-instantiated modules
1916
+
1917
+ # get the base model from the model instance
1918
+ base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
1919
+
1920
+ if rms_norm:
1921
+ _patch_rms_norm_module(base_model.norm)
1922
+
1923
+ for decoder_layer in base_model.layers:
1924
+ if swiglu:
1925
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1926
+ if rms_norm:
1927
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
1928
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
1929
+
1930
+
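Usage note (not part of the diff): OLMo2 places its norms after the attention and feed-forward blocks, and the instance patch above applies them with `in_place=False`; otherwise the call pattern matches the other causal-LM patchers. A hedged sketch, checkpoint id illustrative:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2

model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")  # illustrative id
apply_liger_kernel_to_olmo2(model=model, rms_norm=True, swiglu=True)
```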
1931
+ def apply_liger_kernel_to_glm4(
1932
+ rope: bool = False,
1933
+ cross_entropy: bool = False,
1934
+ fused_linear_cross_entropy: bool = True,
1935
+ rms_norm: bool = True,
1936
+ swiglu: bool = True,
1937
+ model: PreTrainedModel = None,
1938
+ ) -> None:
1939
+ """
1940
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
1941
+
1942
+ Args:
1943
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
1944
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1945
+ fused_linear_cross_entropy (bool):
1946
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1947
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1948
+ If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
1949
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1950
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
477
1951
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
478
1952
  loaded. Default is None.
479
1953
  """
480
- from transformers.models.gemma2 import modeling_gemma2
481
-
482
- LigerRMSNormForGemma2 = partial(
483
- LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
484
- )
485
- _patch_rms_norm_module_for_gemma2 = partial(
486
- _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
1954
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1955
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
487
1956
  )
488
1957
 
1958
+ from transformers.models.glm4 import modeling_glm4
1959
+ from transformers.models.glm4.modeling_glm4 import Glm4Model
1960
+
1961
+ from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
1962
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
1963
+
489
1964
  if rope:
490
- modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
1965
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
491
1966
  if rms_norm:
492
- # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
493
- modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
1967
+ modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
1968
+ if swiglu:
1969
+ modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
494
1970
  if cross_entropy:
495
- modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
496
- if geglu:
497
- modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
1971
+ from transformers.loss.loss_utils import nn
1972
+
1973
+ nn.functional.cross_entropy = liger_cross_entropy
1974
+ if fused_linear_cross_entropy:
1975
+ if model is not None:
1976
+ model.forward = MethodType(glm4_lce_forward, model)
1977
+ else:
1978
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
498
1979
 
499
1980
  if model is not None:
500
1981
  # The model instance already exists, so we need to additionally patch the
501
1982
  # instance variables that reference already-instantiated modules
502
1983
 
503
- if hasattr(model, "model"):
504
- # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example
505
- base_model = model.model
506
- else:
507
- # Direct Gemma2Model
508
- base_model = model
1984
+ # get the base model from the model instance
1985
+ base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
509
1986
 
510
1987
  if rms_norm:
511
- _patch_rms_norm_module_for_gemma2(base_model.norm)
1988
+ _patch_rms_norm_module(base_model.norm, in_place=False)
512
1989
 
513
1990
  for decoder_layer in base_model.layers:
514
- if geglu:
515
- _bind_method_to_module(
516
- decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
517
- )
1991
+ if swiglu:
1992
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
518
1993
  if rms_norm:
519
- _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
520
- _patch_rms_norm_module_for_gemma2(
521
- decoder_layer.post_attention_layernorm
522
- )
523
- _patch_rms_norm_module_for_gemma2(
524
- decoder_layer.pre_feedforward_layernorm
525
- )
526
- _patch_rms_norm_module_for_gemma2(
527
- decoder_layer.post_feedforward_layernorm
528
- )
1994
+ _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
1995
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
1996
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
1997
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
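Usage note (not part of the diff): `rope` defaults to False here, and passing `rope=True` raises `NotImplementedError` because no Liger rotary kernel is wired up for GLM-4. A hedged sketch, checkpoint id illustrative:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4

# Leave rope at its False default; RMSNorm, SwiGLU and the fused-linear-CE forward are patched.
apply_liger_kernel_to_glm4(rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("THUDM/glm-4-9b-chat-hf")  # illustrative id
```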
529
1998
 
530
1999
 
531
- def apply_liger_kernel_to_qwen2(
532
- rope: bool = True,
2000
+ def apply_liger_kernel_to_glm4v(
2001
+ rope: bool = False,
533
2002
  cross_entropy: bool = False,
534
2003
  fused_linear_cross_entropy: bool = True,
535
2004
  rms_norm: bool = True,
@@ -537,150 +2006,469 @@ def apply_liger_kernel_to_qwen2(
537
2006
  model: PreTrainedModel = None,
538
2007
  ) -> None:
539
2008
  """
540
- Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
2009
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
541
2010
 
542
2011
  Args:
543
- rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2012
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
544
2013
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
545
2014
  fused_linear_cross_entropy (bool):
546
2015
  Whether to apply Liger's fused linear cross entropy loss. Default is True.
547
2016
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
548
2017
If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
549
2018
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
550
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
2019
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
551
2020
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
552
2021
  loaded. Default is None.
553
2022
  """
554
- assert not (
555
- cross_entropy and fused_linear_cross_entropy
556
- ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
2023
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2024
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2025
+ )
557
2026
 
558
- from transformers.models.qwen2 import modeling_qwen2
2027
+ from transformers.models.glm4v import modeling_glm4v
2028
+ from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
2029
+ from transformers.models.glm4v.modeling_glm4v import Glm4vModel
2030
+ from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
2031
+ from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
2032
+
2033
+ from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
2034
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
559
2035
 
560
2036
  if rope:
561
- modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
2037
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
562
2038
  if rms_norm:
563
- modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
2039
+ modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
564
2040
  if cross_entropy:
565
- modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
2041
+ from transformers.loss.loss_utils import nn
566
2042
 
567
- # import pdb; pdb.set_trace()
2043
+ nn.functional.cross_entropy = liger_cross_entropy
568
2044
  if fused_linear_cross_entropy:
569
-
570
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
571
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
572
- else: # if version < 4.46.1
573
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
574
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
575
-
576
- if swiglu:
577
- modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
2045
+ if model is not None:
2046
+ model.forward = MethodType(glm4v_lce_forward, model)
2047
+ else:
2048
+ modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
578
2049
 
579
2050
  if model is not None:
580
2051
  # The model instance already exists, so we need to additionally patch the
581
2052
  # instance variables that reference already-instantiated modules
582
-
583
- if hasattr(model, "model"):
584
- # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example
585
- base_model = model.model
2053
+ if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
2054
+ # Note: language_model and visual properties can be accessed through the conditional class for BC.
2055
+ # Not sure if it is subject to changes in the future.
2056
+ # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
2057
+ text_model: Glm4vTextModel = model.language_model
2058
+ vision_model: Glm4vVisionModel = model.visual
2059
+ elif isinstance(model, Glm4vTextModel):
2060
+ text_model: Glm4vTextModel = model
2061
+ vision_model = None
586
2062
  else:
587
- # Direct Qwen2Model
588
- base_model = model
2063
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
2064
+ raise TypeError(
2065
+ f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
2066
+ )
589
2067
 
590
- if rms_norm:
591
- _patch_rms_norm_module(base_model.norm)
2068
+ if vision_model is not None:
2069
+ for vision_block in vision_model.blocks:
2070
+ if rms_norm:
2071
+ _patch_rms_norm_module(vision_block.norm1)
2072
+ _patch_rms_norm_module(vision_block.norm2)
2073
+ if swiglu:
2074
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
592
2075
 
593
- for decoder_layer in base_model.layers:
594
- if swiglu:
595
- _bind_method_to_module(
596
- decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
597
- )
2076
+ if text_model is not None:
598
2077
  if rms_norm:
599
- _patch_rms_norm_module(decoder_layer.input_layernorm)
600
- _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
601
- print("Applied Liger kernels to Qwen2")
2078
+ _patch_rms_norm_module(text_model.norm)
2079
+ for decoder_layer in text_model.layers:
2080
+ if swiglu:
2081
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
2082
+ if rms_norm:
2083
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2084
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2085
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
2086
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
602
2087
 
603
2088
 
604
- def apply_liger_kernel_to_qwen2_vl(
2089
+ def apply_liger_kernel_to_glm4v_moe(
2090
+ rope: bool = False,
605
2091
  cross_entropy: bool = False,
606
2092
  fused_linear_cross_entropy: bool = True,
607
2093
  rms_norm: bool = True,
608
- layer_norm: bool = True,
609
2094
  swiglu: bool = True,
610
2095
  model: PreTrainedModel = None,
611
2096
  ) -> None:
612
2097
  """
613
- Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
614
- NOTE: Qwen2-VL is not available in transformers<4.45.0
2098
+ Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
615
2099
 
616
2100
  Args:
2101
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
617
2102
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
618
2103
  fused_linear_cross_entropy (bool):
619
2104
  Whether to apply Liger's fused linear cross entropy loss. Default is True.
620
2105
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
621
2106
If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
622
2107
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
623
- layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
624
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
2108
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
625
2109
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
626
2110
  loaded. Default is None.
627
2111
  """
628
- assert not (
629
- cross_entropy and fused_linear_cross_entropy
630
- ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
2112
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2113
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2114
+ )
631
2115
 
632
- from transformers.models.qwen2_vl import modeling_qwen2_vl
2116
+ from transformers.models.glm4v_moe import modeling_glm4v_moe
2117
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
2118
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
2119
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
2120
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
2121
+
2122
+ from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
2123
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2124
+
2125
+ if rope:
2126
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2127
+ if rms_norm:
2128
+ modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
2129
+ modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
2130
+ if cross_entropy:
2131
+ from transformers.loss.loss_utils import nn
2132
+
2133
+ nn.functional.cross_entropy = liger_cross_entropy
2134
+ if fused_linear_cross_entropy:
2135
+ if model is not None:
2136
+ model.forward = MethodType(glm4v_moe_lce_forward, model)
2137
+ else:
2138
+ modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
2139
+
2140
+ if model is not None:
2141
+ # The model instance already exists, so we need to additionally patch the
2142
+ # instance variables that reference already-instantiated modules
2143
+ if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
2144
+ # Note: language_model and visual properties can be accessed through the conditional class for BC.
2145
+ # Not sure if it is subject to changes in the future.
2146
+ # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
2147
+ text_model: Glm4vMoeTextModel = model.language_model
2148
+ vision_model: Glm4vMoeVisionModel = model.visual
2149
+ Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
2150
+ elif isinstance(model, Glm4vMoeTextModel):
2151
+ text_model: Glm4vMoeTextModel = model
2152
+ vision_model = None
2153
+ else:
2154
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
2155
+ raise TypeError(
2156
+ f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
2157
+ )
2158
+
2159
+ if vision_model is not None:
2160
+ _patch_rms_norm_module(vision_model.post_conv_layernorm)
2161
+ _patch_rms_norm_module(vision_model.post_layernorm)
2162
+ for vision_block in vision_model.blocks:
2163
+ if rms_norm:
2164
+ _patch_rms_norm_module(vision_block.norm1)
2165
+ _patch_rms_norm_module(vision_block.norm2)
2166
+ if swiglu:
2167
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
2168
+
2169
+ if text_model is not None:
2170
+ if rms_norm:
2171
+ _patch_rms_norm_module(text_model.norm)
2172
+ for decoder_layer in text_model.layers:
2173
+ if swiglu:
2174
+ decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2175
+ if rms_norm:
2176
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2177
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2178
+ if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
2179
+ experts = getattr(decoder_layer.mlp, "experts", None)
2180
+ if experts is not None:
2181
+ for expert in experts:
2182
+ _patch_swiglu_module(expert, LigerSwiGLUMLP)
2183
+ if decoder_layer.mlp.shared_experts is not None:
2184
+ _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
2185
+ for decoder_layer in text_model.layers:
2186
+ if rms_norm:
2187
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2188
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
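Usage note (not part of the diff): when a GLM-4V MoE instance is passed, the loops above patch both the routed experts and the `shared_experts` module of each MoE layer with `LigerSwiGLUMLP`, in addition to the vision blocks and text norms. A hedged sketch of the instance path; the checkpoint id is illustrative and the class import follows the path used in this diff:

```python
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe

model = Glm4vMoeForConditionalGeneration.from_pretrained("zai-org/GLM-4.5V")  # illustrative id
# rope must stay at its False default; experts, shared_experts and norms are patched in place.
apply_liger_kernel_to_glm4v_moe(model=model, rms_norm=True, swiglu=True)
```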
633
2189
 
634
- from liger_kernel.transformers.model.qwen2_vl import (
635
- lce_forward as qwen2_vl_lce_forward,
2190
+
2191
+ def apply_liger_kernel_to_internvl(
2192
+ cross_entropy: bool = False,
2193
+ fused_linear_cross_entropy: bool = True,
2194
+ rms_norm: bool = True,
2195
+ layer_norm: bool = True,
2196
+ model: Optional[PreTrainedModel] = None,
2197
+ **kwargs,
2198
+ ) -> None:
2199
+ """
2200
+ Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
2201
+ Because of how InternVL composes submodels, the model instance must be passed so that Liger-Kernel's patch can also be applied to the language model connected to InternVL.
2202
+ However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
2203
+ NOTE: InternVL is not available in transformers<4.52.1
2204
+
2205
+ Args:
2206
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2207
+ fused_linear_cross_entropy (bool):
2208
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2209
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2210
+ If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
2211
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2212
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
2213
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2214
+ loaded. Default is None.
2215
+ """
2216
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2217
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
636
2218
  )
2219
+ import torch.nn as torch_nn
2220
+
2221
+ from transformers.models.internvl import modeling_internvl
2222
+ from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
2223
+ from transformers.models.internvl.modeling_internvl import InternVLModel
2224
+ from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
2225
+ from transformers.models.internvl.modeling_internvl import InternVLVisionModel
2226
+ from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
2227
+
2228
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
2229
+ from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
2230
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
2231
+
2232
+ if layer_norm and model is None:
2233
+ modeling_internvl.nn.LayerNorm = LigerLayerNorm
2234
+
2235
+ if cross_entropy:
2236
+ logger.info("Apply liger cross entropy")
637
2237
 
638
- # TODO: Support Qwen2-VL's multimodal RoPE implementation
2238
+ from transformers.loss.loss_utils import nn
639
2239
 
2240
+ nn.functional.cross_entropy = liger_cross_entropy
2241
+ if fused_linear_cross_entropy:
2242
+ modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
640
2243
  if rms_norm:
641
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
642
- modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
643
- if layer_norm:
644
- modeling_qwen2_vl.LayerNorm = LigerLayerNorm
2244
+ modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
2245
+
2246
+ if model is not None:
2247
+ # The model instance already exists, so we need to additionally patch the
2248
+ # instance variables that reference already-instantiated modules
2249
+ if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
2250
+ # NOTE: The language_model and vision_tower attributes can be accessed through the conditional generation class as well.
2251
+ text_model = model.language_model
2252
+ vision_model: InternVLVisionModel = model.vision_tower
2253
+ else:
2254
+ raise TypeError(
2255
+ f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
2256
+ )
2257
+
2258
+ text_model_name = model.config.text_config.model_type
2259
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
2260
+
2261
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
2262
+ if text_liger_fn:
2263
+ accept_params = inspect.signature(text_liger_fn).parameters
2264
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
2265
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
2266
+
2267
+ if remain_params:
2268
+ logger.warning(
2269
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
2270
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
2271
+ )
2272
+ text_kwargs["model"] = text_model
2273
+ text_liger_fn(**text_kwargs)
2274
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
2275
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
2276
+
2277
+ # Patch vision model RMSNorm layers
2278
+ if rms_norm:
2279
+ for encoder_layer in vision_model.encoder.layer:
2280
+ encoder_layer: InternVLVisionLayer
2281
+ if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
2282
+ _patch_rms_norm_module(encoder_layer.attention.q_norm)
2283
+ if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
2284
+ _patch_rms_norm_module(encoder_layer.attention.k_norm)
2285
+
2286
+ # Patch vision model LayerNorm layers
2287
+ if layer_norm:
2288
+ # Patch layernorm
2289
+ if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
2290
+ _patch_layer_norm_module(vision_model.layernorm)
2291
+
2292
+ # Patch encoder layers
2293
+ for encoder_layer in vision_model.encoder.layer:
2294
+ encoder_layer: InternVLVisionLayer
2295
+ if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
2296
+ _patch_layer_norm_module(encoder_layer.layernorm_before)
2297
+ if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
2298
+ _patch_layer_norm_module(encoder_layer.layernorm_after)
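A minimal sketch of the instance path described in the docstring above (not part of the diff; transformers>=4.52.1 is assumed and the checkpoint id is a placeholder):

# Hypothetical sketch: patching a loaded InternVL model also dispatches the
# connected language model to its own apply function via MODEL_TYPE_TO_APPLY_LIGER_FN.
from transformers import AutoModelForImageTextToText
from liger_kernel.transformers import apply_liger_kernel_to_internvl

model = AutoModelForImageTextToText.from_pretrained("OpenGVLab/InternVL3-1B-hf")  # placeholder id
apply_liger_kernel_to_internvl(
    fused_linear_cross_entropy=True,  # swaps in the LCE forward at the class level
    rms_norm=True,                    # also forwarded to the connected text model
    layer_norm=True,                  # patches the vision tower's LayerNorm modules
    model=model,
)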
2299
+
2300
+
2301
+ def apply_liger_kernel_to_smolvlm(
2302
+ cross_entropy: bool = False,
2303
+ fused_linear_cross_entropy: bool = True,
2304
+ rms_norm: bool = True,
2305
+ layer_norm: bool = True,
2306
+ model: Optional[PreTrainedModel] = None,
2307
+ **kwargs,
2308
+ ) -> None:
2309
+ """
2310
+ Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
2311
+ Due to the structure of SmolVLM, the model instance must be passed in so that Liger-Kernel's patches can also be applied to the language model connected to it.
2312
+ However, if the connected LM is not supported by Liger-Kernel, unexpected side effects may occur.
2313
+ NOTE: SmolVLM is not available in transformers<4.50.0
2314
+
2315
+ Args:
2316
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2317
+ fused_linear_cross_entropy (bool):
2318
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2319
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2320
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2321
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2322
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
2323
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2324
+ loaded. Default is None.
2325
+ """
2326
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2327
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2328
+ )
2329
+
2330
+ from transformers.models.smolvlm import modeling_smolvlm
2331
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
2332
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
2333
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
2334
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
2335
+
2336
+ from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
2337
+
2338
+ # Patch LayerNorm for vision model if model is not provided (pre-initialization)
2339
+ if layer_norm and model is None:
2340
+ modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
2341
+
645
2342
  if cross_entropy:
646
- modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
2343
+ logger.info("Apply liger cross entropy")
2344
+
2345
+ from transformers.loss.loss_utils import nn
2346
+
2347
+ nn.functional.cross_entropy = liger_cross_entropy
647
2348
  if fused_linear_cross_entropy:
648
- modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
649
- if swiglu:
650
- modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
2349
+ if model is not None:
2350
+ model.forward = MethodType(smolvlm_lce_forward, model)
2351
+ else:
2352
+ modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
2353
+ if rms_norm:
2354
+ modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
651
2355
 
652
2356
  if model is not None:
653
2357
  # The model instance already exists, so we need to additionally patch the
654
2358
  # instance variables that reference already-instantiated modules
2359
+ if isinstance(model, SmolVLMForConditionalGeneration):
2360
+ text_model = model.model.text_model
2361
+ vision_model: SmolVLMVisionTransformer = model.model.vision_model
2362
+ elif isinstance(model, SmolVLMModel):
2363
+ text_model = model.text_model
2364
+ vision_model: SmolVLMVisionTransformer = model.vision_model
2365
+ else:
2366
+ raise TypeError(
2367
+ f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
2368
+ )
2369
+
2370
+ text_model_name = model.config.text_config.model_type
2371
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
2372
+
2373
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
2374
+ if text_liger_fn:
2375
+ accept_params = inspect.signature(text_liger_fn).parameters
2376
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
2377
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
2378
+
2379
+ if remain_params:
2380
+ logger.warning(
2381
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
2382
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
2383
+ )
2384
+ text_kwargs["model"] = text_model
2385
+ text_liger_fn(**text_kwargs)
2386
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
2387
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
2388
+
2389
+ # Patch vision model LayerNorm layers
2390
+ if layer_norm:
2391
+ # Patch post_layernorm
2392
+ _patch_layer_norm_module(vision_model.post_layernorm)
2393
+
2394
+ # Patch encoder layers
2395
+ for encoder_layer in vision_model.encoder.layers:
2396
+ encoder_layer: SmolVLMEncoderLayer
2397
+ _patch_layer_norm_module(encoder_layer.layer_norm1)
2398
+ _patch_layer_norm_module(encoder_layer.layer_norm2)
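In contrast to the instance path above, SmolVLM can also be patched before instantiation; a hedged sketch (transformers>=4.50.0 assumed, placeholder checkpoint id):

# Hypothetical sketch: class-level patching before the model is created, so the
# Liger LayerNorm and the LCE forward are picked up at construction time.
from transformers import AutoModelForImageTextToText
from liger_kernel.transformers import apply_liger_kernel_to_smolvlm

apply_liger_kernel_to_smolvlm(fused_linear_cross_entropy=True, rms_norm=True, layer_norm=True)
model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")  # placeholder id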
2399
+
2400
+
2401
+ def apply_liger_kernel_to_falcon_h1(
2402
+ rope: bool = True,
2403
+ cross_entropy: bool = False,
2404
+ fused_linear_cross_entropy: bool = True,
2405
+ rms_norm: bool = True,
2406
+ swiglu: bool = False,
2407
+ model: PreTrainedModel = None,
2408
+ ) -> None:
2409
+ """
2410
+ Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models.
2411
+ Args:
2412
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2413
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2414
+ fused_linear_cross_entropy (bool):
2415
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2416
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2417
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2418
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2419
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
2420
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2421
+ loaded. Default is None.
2422
+ """
2423
+
2424
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2425
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2426
+ )
2427
+
2428
+ from transformers.models.falcon_h1 import modeling_falcon_h1
2429
+ from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
2430
+
2431
+ if rope:
2432
+ logger.info("Apply liger rotary pos emb.")
2433
+ modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
2434
+ if rms_norm:
2435
+ logger.info("Apply liger RMSNorm")
2436
+ modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
2437
+ if swiglu:
2438
+ logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
655
2439
 
656
- if hasattr(model, "model"):
657
- # The case for Qwen2VLForConditionalGeneration.
658
- base_model = model.model
2440
+ if cross_entropy:
2441
+ logger.info("Apply liger cross entropy")
2442
+ from transformers.loss.loss_utils import nn
2443
+
2444
+ nn.functional.cross_entropy = liger_cross_entropy
2445
+
2446
+ if fused_linear_cross_entropy:
2447
+ if model is not None:
2448
+ model.forward = MethodType(falcon_h1_lce_forward, model)
659
2449
  else:
660
- # Direct Qwen2VLModel
661
- base_model = model
2450
+ modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
662
2451
 
663
- if hasattr(model, "visual"):
664
- # Patch Qwen2VisionTransformerPretrainedModel
665
- for vision_block in model.visual.blocks:
666
- if layer_norm:
667
- _patch_layer_norm_module(vision_block.norm1)
668
- _patch_layer_norm_module(vision_block.norm2)
2452
+ if model is not None:
2453
+ # The model instance already exists, so we need to additionally patch the
2454
+ # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
2455
+
2456
+ # get the base model from the model instance
2457
+ base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
669
2458
 
670
2459
  if rms_norm:
671
- _patch_rms_norm_module(base_model.norm)
2460
+ _patch_rms_norm_module(base_model.final_layernorm)
2461
+
672
2462
  for decoder_layer in base_model.layers:
673
2463
  if swiglu:
674
- _bind_method_to_module(
675
- decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
676
- )
2464
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
677
2465
  if rms_norm:
678
2466
  _patch_rms_norm_module(decoder_layer.input_layernorm)
679
- _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2467
+ _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
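A hedged usage sketch for the Falcon-H1 patch above (placeholder checkpoint id; note that `swiglu=True` only logs a warning for this architecture):

# Hypothetical sketch: patch a loaded Falcon-H1 model; RMSNorm is applied to
# final_layernorm plus each layer's input_layernorm and pre_ff_layernorm.
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1

model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-1.5B-Instruct")  # placeholder id
apply_liger_kernel_to_falcon_h1(rope=True, rms_norm=True, fused_linear_cross_entropy=True, model=model)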
680
2468
 
681
2469
 
682
- def apply_liger_kernel_to_phi3(
683
- rope: bool = True,
2470
+ def apply_liger_kernel_to_qwen3_next(
2471
+ rope: bool = False,
684
2472
  cross_entropy: bool = False,
685
2473
  fused_linear_cross_entropy: bool = True,
686
2474
  rms_norm: bool = True,
@@ -688,77 +2476,125 @@ def apply_liger_kernel_to_phi3(
688
2476
  model: PreTrainedModel = None,
689
2477
  ) -> None:
690
2478
  """
691
- Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
2479
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-Next models.
692
2480
 
693
2481
  Args:
694
- rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2482
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
695
2483
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
696
2484
  fused_linear_cross_entropy (bool):
697
2485
  Whether to apply Liger's fused linear cross entropy loss. Default is True.
698
2486
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
699
2487
  If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
700
2488
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
701
- swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
2489
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
702
2490
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
703
2491
  loaded. Default is None.
704
2492
  """
705
- assert not (
706
- cross_entropy and fused_linear_cross_entropy
707
- ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
2493
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2494
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2495
+ )
708
2496
 
709
- from transformers.models.phi3 import modeling_phi3
2497
+ from transformers.models.qwen3_next import modeling_qwen3_next
2498
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
2499
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
2500
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
2501
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
2502
+
2503
+ from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
2504
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
2505
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
710
2506
 
711
2507
  if rope:
712
- modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
2508
+ # It might encounter NaN issues
2509
+ # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
2510
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
713
2511
  if rms_norm:
714
- modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
715
- if swiglu:
716
- modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
2512
+ modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
717
2513
  if cross_entropy:
718
- modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
2514
+ from transformers.loss.loss_utils import nn
2515
+
2516
+ nn.functional.cross_entropy = liger_cross_entropy
719
2517
  if fused_linear_cross_entropy:
720
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
721
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
722
- else: # if version < 4.46.1
723
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
724
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
2518
+ if model is not None:
2519
+ if isinstance(model, Qwen3NextForCausalLM):
2520
+ model.forward = MethodType(qwen3_next_lce_forward, model)
2521
+ else:
2522
+ raise TypeError(
2523
+ f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
2524
+ )
2525
+ else:
2526
+ modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
2527
+ if swiglu:
2528
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
2529
+ modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
725
2530
 
726
2531
  if model is not None:
727
2532
  # The model instance already exists, so we need to additionally patch the
728
2533
  # instance variables that reference already-instantiated modules
729
-
730
- if hasattr(model, "model"):
731
- # The case for Phi3ForCausalLM, Phi3ForTokenClassification for example
732
- base_model = model.model
2534
+ if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
2535
+ base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
733
2536
  else:
734
- # Direct Phi3Model
735
- base_model = model
2537
+ raise TypeError(
2538
+ f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
2539
+ )
736
2540
 
737
2541
  if rms_norm:
738
2542
  _patch_rms_norm_module(base_model.norm)
739
2543
 
740
2544
  for decoder_layer in base_model.layers:
741
- if swiglu:
742
- _bind_method_to_module(
743
- decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
744
- )
745
2545
  if rms_norm:
746
2546
  _patch_rms_norm_module(decoder_layer.input_layernorm)
747
2547
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
748
2548
 
2549
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
2550
+ if swiglu:
2551
+ if isinstance(decoder_layer.mlp, Qwen3NextMLP):
2552
+ _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
2553
+ if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
2554
+ _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
2555
+ experts = getattr(decoder_layer.mlp, "experts", None)
2556
+ if experts is not None:
2557
+ for expert in experts:
2558
+ _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
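A hedged sketch for the Qwen3-Next patch (placeholder checkpoint id; `rope` must stay False because the Liger RoPE patch raises NotImplementedError above):

# Hypothetical sketch: patch a loaded Qwen3-Next model; SwiGLU is applied to dense
# Qwen3NextMLP layers and to the experts and shared expert of sparse MoE blocks.
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")  # placeholder id
apply_liger_kernel_to_qwen3_next(rope=False, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True, model=model)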
2559
+
749
2560
 
750
2561
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
751
2562
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
752
2563
  "gemma": apply_liger_kernel_to_gemma,
753
2564
  "gemma2": apply_liger_kernel_to_gemma2,
2565
+ "gemma3_text": apply_liger_kernel_to_gemma3_text,
2566
+ "gemma3": apply_liger_kernel_to_gemma3,
2567
+ "glm4": apply_liger_kernel_to_glm4,
2568
+ "glm4v": apply_liger_kernel_to_glm4v,
2569
+ "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
2570
+ "internvl": apply_liger_kernel_to_internvl,
754
2571
  "llama": apply_liger_kernel_to_llama,
2572
+ "llama4_text": apply_liger_kernel_to_llama4,
2573
+ "llama4": apply_liger_kernel_to_llama4,
2574
+ "llava": apply_liger_kernel_to_llava,
2575
+ "granite": apply_liger_kernel_to_granite,
755
2576
  "mllama": apply_liger_kernel_to_mllama,
756
2577
  "mllama_text_model": apply_liger_kernel_to_mllama,
757
2578
  "mistral": apply_liger_kernel_to_mistral,
758
2579
  "mixtral": apply_liger_kernel_to_mixtral,
2580
+ "olmo2": apply_liger_kernel_to_olmo2,
759
2581
  "qwen2": apply_liger_kernel_to_qwen2,
2582
+ "qwen3": apply_liger_kernel_to_qwen3,
2583
+ "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
760
2584
  "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
2585
+ "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
2586
+ "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
2587
+ "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
2588
+ "qwen3_next": apply_liger_kernel_to_qwen3_next,
2589
+ "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
2590
+ "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
2591
+ "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
2592
+ "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
2593
+ "smollm3": apply_liger_kernel_to_smollm3,
761
2594
  "phi3": apply_liger_kernel_to_phi3,
2595
+ "paligemma": apply_liger_kernel_to_paligemma,
2596
+ "falcon_h1": apply_liger_kernel_to_falcon_h1,
2597
+ "smolvlm": apply_liger_kernel_to_smolvlm,
762
2598
  }
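For illustration only, the registry is a plain dispatch table keyed by the Hugging Face `config.model_type` string; the helpers below add signature-based kwarg filtering on top of this direct lookup:

# Hypothetical sketch: resolve and call a patch function by model_type.
patch_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get("falcon_h1")
if patch_fn is not None:
    patch_fn(rms_norm=True, fused_linear_cross_entropy=True)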
763
2599
 
764
2600
 
@@ -782,24 +2618,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
782
2618
  return
783
2619
 
784
2620
  if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
785
- logger.info(
786
- f"There are currently no Liger kernels supported for model type: {model_type}."
787
- )
2621
+ logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
788
2622
  return
789
2623
 
790
2624
  apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
791
2625
  apply_fn_signature = inspect.signature(apply_fn)
792
2626
 
793
2627
  # Filter out the keyword arguments that are not supported by the apply function
794
- applicable_kwargs = {
795
- key: value
796
- for key, value in kwargs.items()
797
- if key in apply_fn_signature.parameters
798
- }
2628
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
799
2629
 
800
- logger.info(
801
- f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
802
- )
2630
+ logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")
803
2631
 
804
2632
  # Assume this is invoked pre-model initialization, so we only need to patch transformers code
805
2633
  apply_fn(**applicable_kwargs)
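The signature-based kwarg filtering used above is a small reusable idiom; a self-contained sketch with an illustrative stand-in function (names are hypothetical):

import inspect

def apply_example(rope: bool = True, rms_norm: bool = True) -> None:
    # Stand-in for an apply_liger_kernel_to_* function.
    print(f"rope={rope}, rms_norm={rms_norm}")

kwargs = {"rope": False, "rms_norm": True, "swiglu": True}  # "swiglu" is not in the signature
accepted = inspect.signature(apply_example).parameters
apply_example(**{k: v for k, v in kwargs.items() if k in accepted})  # -> rope=False, rms_norm=True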
@@ -813,32 +2641,21 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
813
2641
  - model: the model instance to apply Liger kernels to
814
2642
  - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
815
2643
  """
816
- model_type = getattr(model, "config", None) and getattr(
817
- model.config, "model_type", None
818
- )
2644
+ model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
819
2645
 
820
2646
  if not model_type:
821
- logger.info(
822
- "Model type could not be determined from model config. No Liger kernels will be applied."
823
- )
2647
+ logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
824
2648
  return
825
2649
 
826
2650
  if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
827
- logger.info(
828
- f"There are currently no Liger kernels supported for model type: {model_type}."
829
- )
2651
+ logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
830
2652
  return
831
2653
 
832
2654
  apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
833
-
834
2655
  apply_fn_signature = inspect.signature(apply_fn)
835
2656
 
836
2657
  # Filter out the keyword arguments that are not supported by the apply function
837
- applicable_kwargs = {
838
- key: value
839
- for key, value in kwargs.items()
840
- if key in apply_fn_signature.parameters
841
- }
2658
+ applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
842
2659
  logger.info(
843
2660
  f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
844
2661
  )