liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +25 -9
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +124 -64
  17. liger_kernel/ops/dyt.py +115 -180
  18. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  19. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  20. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  21. liger_kernel/ops/geglu.py +3 -2
  22. liger_kernel/ops/group_norm.py +2 -1
  23. liger_kernel/ops/grpo_loss.py +312 -0
  24. liger_kernel/ops/jsd.py +2 -1
  25. liger_kernel/ops/kl_div.py +13 -6
  26. liger_kernel/ops/layer_norm.py +146 -78
  27. liger_kernel/ops/llama4_rope.py +225 -0
  28. liger_kernel/ops/multi_token_attention.py +207 -0
  29. liger_kernel/ops/poly_norm.py +390 -0
  30. liger_kernel/ops/rms_norm.py +283 -56
  31. liger_kernel/ops/rope.py +1 -1
  32. liger_kernel/ops/softmax.py +201 -0
  33. liger_kernel/ops/sparsemax.py +179 -0
  34. liger_kernel/ops/swiglu.py +1 -1
  35. liger_kernel/ops/tiled_mlp.py +136 -0
  36. liger_kernel/ops/utils.py +2 -0
  37. liger_kernel/transformers/__init__.py +205 -19
  38. liger_kernel/transformers/cross_entropy.py +9 -4
  39. liger_kernel/transformers/dyt.py +6 -4
  40. liger_kernel/transformers/experimental/__init__.py +5 -0
  41. liger_kernel/transformers/experimental/embedding.py +1 -1
  42. liger_kernel/transformers/fsdp.py +55 -0
  43. liger_kernel/transformers/functional.py +122 -20
  44. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  45. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  46. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  47. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  48. liger_kernel/transformers/geglu.py +1 -1
  49. liger_kernel/transformers/group_norm.py +1 -1
  50. liger_kernel/transformers/grpo_loss.py +153 -0
  51. liger_kernel/transformers/jsd.py +1 -1
  52. liger_kernel/transformers/kl_div.py +1 -1
  53. liger_kernel/transformers/layer_norm.py +1 -1
  54. liger_kernel/transformers/llama4_rope.py +93 -0
  55. liger_kernel/transformers/model/falcon_h1.py +122 -0
  56. liger_kernel/transformers/model/gemma.py +50 -25
  57. liger_kernel/transformers/model/gemma2.py +55 -23
  58. liger_kernel/transformers/model/gemma3.py +117 -120
  59. liger_kernel/transformers/model/glm4.py +141 -0
  60. liger_kernel/transformers/model/glm4v.py +163 -0
  61. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  62. liger_kernel/transformers/model/gpt_oss.py +211 -0
  63. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  64. liger_kernel/transformers/model/internvl.py +157 -0
  65. liger_kernel/transformers/model/llama.py +102 -25
  66. liger_kernel/transformers/model/llama4.py +121 -0
  67. liger_kernel/transformers/model/llava.py +111 -136
  68. liger_kernel/transformers/model/loss_utils.py +50 -12
  69. liger_kernel/transformers/model/mistral.py +36 -23
  70. liger_kernel/transformers/model/mixtral.py +45 -25
  71. liger_kernel/transformers/model/mllama.py +39 -22
  72. liger_kernel/transformers/model/olmo2.py +40 -20
  73. liger_kernel/transformers/model/olmo3.py +142 -0
  74. liger_kernel/transformers/model/output_classes.py +147 -0
  75. liger_kernel/transformers/model/paligemma.py +50 -14
  76. liger_kernel/transformers/model/phi3.py +47 -177
  77. liger_kernel/transformers/model/qwen2.py +48 -21
  78. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  79. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  80. liger_kernel/transformers/model/qwen3.py +136 -0
  81. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  82. liger_kernel/transformers/model/qwen3_next.py +146 -0
  83. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  84. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  85. liger_kernel/transformers/model/smollm3.py +199 -0
  86. liger_kernel/transformers/model/smolvlm.py +158 -0
  87. liger_kernel/transformers/monkey_patch.py +1678 -160
  88. liger_kernel/transformers/multi_token_attention.py +64 -0
  89. liger_kernel/transformers/poly_norm.py +42 -0
  90. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  91. liger_kernel/transformers/rms_norm.py +48 -5
  92. liger_kernel/transformers/rope.py +45 -1
  93. liger_kernel/transformers/softmax.py +12 -0
  94. liger_kernel/transformers/sparsemax.py +16 -0
  95. liger_kernel/transformers/swiglu.py +39 -1
  96. liger_kernel/transformers/tiled_mlp.py +133 -0
  97. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  98. liger_kernel/transformers/tvd.py +1 -1
  99. liger_kernel/utils.py +36 -0
  100. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
  101. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  102. liger_kernel/transformers/gema3_rms.py +0 -8
  103. liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
  104. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
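A recurring change across the monkey_patch.py updates below is that each apply_liger_kernel_to_* function now accepts an already-loaded model and binds the Liger forward to that instance via MethodType, instead of only overriding the class-level forward. A minimal usage sketch (illustrative only; the checkpoint name and flag values are placeholders, not taken from this diff):

    # Hypothetical usage of the instance-level patching path added in this release.
    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers import apply_liger_kernel_to_llama

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")  # placeholder checkpoint

    # With model= given, forward is rebound on the instance via MethodType and the
    # already-instantiated RMSNorm / SwiGLU submodules are patched in place.
    apply_liger_kernel_to_llama(
        rope=True,
        rms_norm=True,
        swiglu=True,
        fused_linear_cross_entropy=True,
        model=model,
    )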
@@ -2,7 +2,9 @@ import inspect
  import logging

  from functools import partial
+ from types import MethodType
  from typing import Callable
+ from typing import Optional

  import transformers

@@ -13,10 +15,12 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
  from liger_kernel.transformers.functional import liger_cross_entropy
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
  from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
  from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+ from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
  from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -25,16 +29,24 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_f
  from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
  from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
- from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
  from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
  from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

+ try:
+     import peft
+
+     PEFT_AVAILABLE = True
+ except ImportError:
+     PEFT_AVAILABLE = False
+
  transformer_version = version.parse(transformers.__version__)

  logger = logging.getLogger(__name__)
@@ -47,33 +59,82 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
      module.__dict__[method_name] = new_method.__get__(module, module.__class__)


- def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
-     module.offset = offset
-     module.casting_mode = casting_mode
-     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-     module.in_place = in_place
-     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
-     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-     module.__class__.__name__ = LigerRMSNorm.__name__
+ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
+     # Check if the module is a PEFT ModulesToSaveWrapper
+     # If it is, we need to patch the modules_to_save.default and original_modules
+     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+         module.modules_to_save.default.offset = offset
+         module.modules_to_save.default.casting_mode = casting_mode
+         module.modules_to_save.default.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.modules_to_save.default.in_place = in_place
+         module.modules_to_save.default.row_mode = row_mode
+         module.original_module.offset = offset
+         module.original_module.casting_mode = casting_mode
+         module.original_module.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.in_place = in_place
+         module.original_module.row_mode = row_mode
+         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+         _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+         _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
+     else:
+         module.offset = offset
+         module.casting_mode = casting_mode
+         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         module.in_place = in_place
+         module.row_mode = row_mode
+         _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+         _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)


  def _patch_layer_norm_module(module, eps=1e-6):
-     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-     module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
-
-     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
-     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-     module.__class__.__name__ = LigerLayerNorm.__name__
+     # Check if the module is a PEFT ModulesToSaveWrapper
+     # If it is, we need to patch the modules_to_save.default and original_modules
+     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+         module.hidden_size = module.normalized_shape
+         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+         module.modules_to_save.default.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+             module, "normalized_shape", None
+         )
+         module.original_module.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+             module, "normalized_shape", None
+         )
+         _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+         _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+         _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+         _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+         _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
+     else:
+         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+         _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)


  def _patch_swiglu_module(module, liger_module):
      _bind_method_to_module(module, "forward", liger_module.forward)
-     module.__class__.__name__ = liger_module.__name__
+     _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)


  def _patch_geglu_module(module):
      _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
-     module.__class__.__name__ = LigerGEGLUMLP.__name__
+     _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)


  def apply_liger_kernel_to_granite(
@@ -204,10 +265,16 @@ def apply_liger_kernel_to_llama(

      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+             if model is not None:
+                 model.forward = MethodType(llama_lce_forward, model)
+             else:
+                 modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
          else: # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(llama_lce_forward_deprecated, model)
+             else:
+                 modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -227,6 +294,77 @@ def apply_liger_kernel_to_llama(
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_smollm3(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+         loaded. Default is None.
+     """
+
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.smollm3 import modeling_smollm3
+     from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+     if rope:
+         modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+     if rms_norm:
+         modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+     if swiglu:
+         modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+     if cross_entropy:
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             from transformers.loss.loss_utils import nn
+
+             nn.functional.cross_entropy = liger_cross_entropy
+         else:
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+     if fused_linear_cross_entropy:
+         if model is not None:
+             model.forward = MethodType(smollm3_lce_forward, model)
+         else:
+             modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+         # get the base model from the model instance
+         base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_llava(
      cross_entropy: bool = False,
      fused_linear_cross_entropy: bool = True,
@@ -261,13 +399,20 @@ def apply_liger_kernel_to_llava(
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
              modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         if transformer_version >= version.parse("4.49.0"):
-             modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+         if transformer_version >= version.parse("4.52.0"):
+             if model is not None:
+                 model.forward = MethodType(llava_lce_forward, model)
+             else:
+                 modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+         elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+             if model is not None:
+                 model.forward = MethodType(llava_lce_forward_deprecated, model)
+             else:
+                 modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
          else: # if version < 4.49.0
              logger.warning(
-                 "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+                 "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
              )
-             modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated

      if model is not None:
          text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
@@ -306,6 +451,97 @@ def apply_liger_kernel_to_llava(
              logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


+ def apply_liger_kernel_to_llama4(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+     layer_norm: bool = True,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+         loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.llama4 import modeling_llama4
+     from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+     from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+     from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+     from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+     from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+     if rope:
+         from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+         apply_liger_llama4_rope_full(modeling_llama4)
+     if rms_norm:
+         modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+     if swiglu:
+         modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+     if cross_entropy:
+         modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+     if fused_linear_cross_entropy:
+         modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+         if isinstance(model, Llama4ForConditionalGeneration):
+             language_model: Llama4ForCausalLM = model.language_model
+             vision_model: Llama4VisionModel = model.vision_model
+             text_model: Llama4TextModel = language_model.model
+         elif isinstance(model, Llama4ForCausalLM):
+             text_model = model.model
+             vision_model = None
+         elif isinstance(model, Llama4TextModel):
+             text_model = model
+             vision_model = None
+
+         else:
+             raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+         if text_model:
+             if rms_norm:
+                 _patch_rms_norm_module(text_model.norm)
+             for decoder_layer in text_model.layers:
+                 if swiglu:
+                     if decoder_layer.is_moe_layer:
+                         _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+                     else:
+                         _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                 if rms_norm:
+                     _patch_rms_norm_module(decoder_layer.input_layernorm)
+                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+         if vision_model:
+             _patch_layer_norm_module(vision_model.layernorm_pre)
+             _patch_layer_norm_module(vision_model.layernorm_post)
+
+             for layer in vision_model.model.layers:
+                 if layer_norm:
+                     _patch_layer_norm_module(layer.input_layernorm)
+                     _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_mllama(
      rope: bool = True,
      cross_entropy: bool = False,
@@ -347,7 +583,7 @@ def apply_liger_kernel_to_mllama(

      if rope:
          modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-     if layer_norm:
+     if layer_norm and model is None:
          modeling_mllama.nn.LayerNorm = LigerLayerNorm
      if rms_norm:
          modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -363,10 +599,16 @@ def apply_liger_kernel_to_mllama(
              modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+             if model is not None:
+                 model.forward = MethodType(mllama_lce_forward, model)
+             else:
+                 modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
          else: # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(mllama_lce_forward_deprecated, model)
+             else:
+                 modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -375,13 +617,17 @@ def apply_liger_kernel_to_mllama(
          if isinstance(model, MllamaForConditionalGeneration):
              language_model: MllamaForCausalLM = model.language_model
              vision_model: MllamaVisionModel = model.vision_model
-             text_model: MllamaTextModel = language_model.model
+             if isinstance(language_model, MllamaForCausalLM):
+                 text_model: MllamaTextModel = language_model.model
+             else:
+                 text_model = language_model
          elif isinstance(model, MllamaForCausalLM):
              text_model = model.model
              vision_model = None
          elif isinstance(model, MllamaTextModel):
              text_model = model
              vision_model = None
+
          else:
              raise ValueError(f"Unsupported Mllama model type: {type(model)}")

@@ -448,7 +694,17 @@ def apply_liger_kernel_to_mistral(
      if cross_entropy:
          modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+         if transformer_version >= version.parse("4.49.0"):
+             if model is not None:
+                 model.forward = MethodType(mistral_lce_forward, model)
+             else:
+                 modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+         else:
+             logger.warning(
+                 "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+             )
+             logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
      if swiglu:
          modeling_mistral.MistralMLP = LigerSwiGLUMLP

@@ -516,10 +772,16 @@ def apply_liger_kernel_to_mixtral(

      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+             if model is not None:
+                 model.forward = MethodType(mixtral_lce_forward, model)
+             else:
+                 modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
          else: # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+             else:
+                 modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
      if swiglu:
          modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -573,8 +835,8 @@ def apply_liger_kernel_to_gemma(
      from transformers.models.gemma import modeling_gemma
      from transformers.models.gemma.modeling_gemma import GemmaModel

-     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-     LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
      _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

      if rope:
@@ -593,10 +855,16 @@ def apply_liger_kernel_to_gemma(
          modeling_gemma.GemmaMLP = LigerGEGLUMLP
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+             if model is not None:
+                 model.forward = MethodType(gemma_lce_forward, model)
+             else:
+                 modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
          else: # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(gemma_lce_forward_deprecated, model)
+             else:
+                 modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -647,7 +915,8 @@ def apply_liger_kernel_to_gemma2(
      from transformers.models.gemma2 import modeling_gemma2
      from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

-     LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
      _patch_rms_norm_module_for_gemma2 = partial(
          _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
      )
@@ -667,10 +936,16 @@ def apply_liger_kernel_to_gemma2(
              modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+             if model is not None:
+                 model.forward = MethodType(gemma2_lce_forward, model)
+             else:
+                 modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
          else:
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+             if model is not None:
+                 model.forward = MethodType(gemma2_lce_forward_deprected, model)
+             else:
+                 modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
      if geglu:
          modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -724,9 +999,10 @@ def apply_liger_kernel_to_gemma3_text(
      from transformers.models.gemma3 import modeling_gemma3
      from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
      from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+     from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel

-     from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
      from liger_kernel.transformers.model.gemma3 import causal_forward
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3

      _patch_rms_norm_module_for_gemma3 = partial(
          _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
@@ -748,15 +1024,18 @@ def apply_liger_kernel_to_gemma3_text(
              nn.functional.cross_entropy = liger_cross_entropy

      if fused_linear_cross_entropy:
-         modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+         if model is not None:
+             model.forward = MethodType(causal_forward, model)
+         else:
+             modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         if isinstance(model, Gemma3ForCausalLM):
+         if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
              # get the base model from the model instance
-             base_model = model.model
+             base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model

              if rms_norm:
                  _patch_rms_norm_module_for_gemma3(base_model.norm)
@@ -818,7 +1097,7 @@ def apply_liger_kernel_to_gemma3(
          _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
      )

-     if layer_norm:
+     if layer_norm and model is None:
          modeling_siglip.nn.LayerNorm = LigerLayerNorm

      apply_liger_kernel_to_gemma3_text(
@@ -829,7 +1108,10 @@ def apply_liger_kernel_to_gemma3(
          modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss

      if fused_linear_cross_entropy:
-         modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+         if model is not None:
+             model.forward = MethodType(multimodal_forward, model)
+         else:
+             modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -897,7 +1179,9 @@ def apply_liger_kernel_to_paligemma(
      # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']

      from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+     from transformers.models.gemma.modeling_gemma import GemmaModel
      from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
      from transformers.models.paligemma import modeling_paligemma
      from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
      from transformers.models.siglip import modeling_siglip
@@ -908,7 +1192,7 @@ def apply_liger_kernel_to_paligemma(
      from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

      # The vision_tower is a SiglipVisionModel
-     if layer_norm:
+     if layer_norm and model is None:
          modeling_siglip.nn.LayerNorm = LigerLayerNorm

      # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -926,10 +1210,16 @@ def apply_liger_kernel_to_paligemma(
              modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+             if model is not None:
+                 model.forward = MethodType(lce_forward, model)
+             else:
+                 modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
          else: # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(lce_forward_deprecated, model)
+             else:
+                 modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
@@ -950,7 +1240,7 @@ def apply_liger_kernel_to_paligemma(

          language_model = model.language_model

-         if isinstance(language_model, GemmaForCausalLM):
+         if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
              apply_liger_kernel_to_gemma(
                  rope=rope,
                  cross_entropy=False,
@@ -960,7 +1250,7 @@ def apply_liger_kernel_to_paligemma(
                  model=language_model,
              )

-         elif isinstance(language_model, Gemma2ForCausalLM):
+         elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
              apply_liger_kernel_to_gemma2(
                  rope=rope,
                  cross_entropy=False,
@@ -1021,10 +1311,16 @@ def apply_liger_kernel_to_qwen2(

      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-             modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+             if model is not None:
+                 model.forward = MethodType(qwen2_lce_forward, model)
+             else:
+                 modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
          else: # if version < 4.46.1
              logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-             modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+             if model is not None:
+                 model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+             else:
+                 modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

      if swiglu:
          modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1045,70 +1341,54 @@ def apply_liger_kernel_to_qwen2(
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
-     print("Applied Liger kernels to Qwen2")


- def apply_liger_kernel_to_qwen2_vl(
+ def apply_liger_kernel_to_qwen3(
      rope: bool = True,
      cross_entropy: bool = False,
      fused_linear_cross_entropy: bool = True,
      rms_norm: bool = True,
-     layer_norm: bool = True,
      swiglu: bool = True,
      model: PreTrainedModel = None,
  ) -> None:
      """
-     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-     NOTE: Qwen2-VL is not available in transformers<4.45.0
-
-     Args:
-         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
-         fused_linear_cross_entropy (bool):
-             Whether to apply Liger's fused linear cross entropy loss. Default is True.
-             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
-             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
-         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-         layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
-         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
-         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
-         loaded. Default is None.
+     Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
      """
      assert not (cross_entropy and fused_linear_cross_entropy), (
          "cross_entropy and fused_linear_cross_entropy cannot both be True."
      )

-     from transformers.models.qwen2_vl import modeling_qwen2_vl
-     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+     from transformers.models.qwen3 import modeling_qwen3
+     from transformers.models.qwen3.modeling_qwen3 import Qwen3Model

-     from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
+     from liger_kernel.transformers.model.qwen3 import lce_forward as qwen3_lce_forward

      if rope:
-         modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+         modeling_qwen3.apply_rotary_pos_emb = liger_rotary_pos_emb
+
      if rms_norm:
-         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
-         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
-     if layer_norm:
-         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
+         modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
+
      if cross_entropy:
-         modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+
      if fused_linear_cross_entropy:
-         modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+         if model is not None:
+             model.forward = MethodType(qwen3_lce_forward, model)
+         else:
+             modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+
      if swiglu:
-         modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
+         modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

          # get the base model from the model instance
-         base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)
-
-         if hasattr(model, "visual"):
-             # Patch Qwen2VisionTransformerPretrainedModel
-             for vision_block in model.visual.blocks:
-                 if layer_norm:
-                     _patch_layer_norm_module(vision_block.norm1)
-                     _patch_layer_norm_module(vision_block.norm2)
+         base_model: Qwen3Model = getattr(model, model.base_model_prefix, model)

          if rms_norm:
              _patch_rms_norm_module(base_model.norm)
@@ -1120,7 +1400,7 @@ def apply_liger_kernel_to_qwen2_vl(
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


- def apply_liger_kernel_to_qwen2_5_vl(
+ def apply_liger_kernel_to_qwen3_moe(
      rope: bool = True,
      cross_entropy: bool = False,
      fused_linear_cross_entropy: bool = True,
@@ -1129,74 +1409,69 @@ def apply_liger_kernel_to_qwen2_5_vl(
      model: PreTrainedModel = None,
  ) -> None:
      """
-     Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
-     NOTE: Qwen2.5-VL is not available in transformers<4.48.2
-
-     Args:
-         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
-         fused_linear_cross_entropy (bool):
-             Whether to apply Liger's fused linear cross entropy loss. Default is True.
-             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
-             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
-         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
-         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
-         loaded. Default is None.
+     Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
      """
      assert not (cross_entropy and fused_linear_cross_entropy), (
          "cross_entropy and fused_linear_cross_entropy cannot both be True."
      )

-     from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
-     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+     from transformers.models.qwen3_moe import modeling_qwen3_moe
+     from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel

-     from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
+     from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
+     from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

      if rope:
-         modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+         modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
      if rms_norm:
-         modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
+         modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
+
      if cross_entropy:
-         modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+
      if fused_linear_cross_entropy:
-         modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+         if model is not None:
+             model.forward = MethodType(qwen3_lce_forward, model)
+         else:
+             modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+
      if swiglu:
-         modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
+         modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP

      if model is not None:
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

          # get the base model from the model instance
-         base_model: Qwen2_5_VLModel = getattr(model, model.base_model_prefix, model)
-
-         if hasattr(model, "visual"):
-             # Patch Qwen2_5_VisionTransformerPretrainedModel
-             for vision_block in model.visual.blocks:
-                 if rms_norm:
-                     _patch_rms_norm_module(vision_block.norm1)
-                     _patch_rms_norm_module(vision_block.norm2)
+         base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)

          if rms_norm:
              _patch_rms_norm_module(base_model.norm)
          for decoder_layer in base_model.layers:
              if swiglu:
-                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                 for mlp_expert in decoder_layer.mlp.experts:
+                     _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


- def apply_liger_kernel_to_phi3(
+ def apply_liger_kernel_to_gpt_oss(
      rope: bool = True,
      cross_entropy: bool = False,
      fused_linear_cross_entropy: bool = True,
      rms_norm: bool = True,
-     swiglu: bool = True,
+     swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation
      model: PreTrainedModel = None,
  ) -> None:
      """
-     Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
+     Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+     NOTE: GPT-OSS is supported in transformers >= 4.55.0
+     NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+     implementation with clamping and MXFP4 quantization.

      Args:
          rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -1206,58 +1481,1135 @@ def apply_liger_kernel_to_phi3(
1206
1481
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1207
1482
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
1208
1483
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1209
- swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
1484
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1485
+ Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
1210
1486
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1211
- loaded. Default is None.
1487
+ loaded. Default is None.
1212
1488
  """
1489
+ if version.parse(transformers.__version__) < version.parse("4.55.0"):
1490
+ logger.warning("GPT-OSS support requires transformers >= 4.55.0")
1491
+ return
1492
+
1213
1493
  assert not (cross_entropy and fused_linear_cross_entropy), (
1214
1494
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
1215
1495
  )
1216
1496
 
1217
- from transformers.models.phi3 import modeling_phi3
1218
- from transformers.models.phi3.modeling_phi3 import Phi3Model
1497
+ from transformers.models.gpt_oss import modeling_gpt_oss
1498
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
1219
1499
 
1220
1500
  if rope:
1221
- modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
1501
+ modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
1502
+
1222
1503
  if rms_norm:
1223
- modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
1224
- if swiglu:
1225
- modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
1504
+ modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
1505
+
1226
1506
  if cross_entropy:
1227
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1228
- from transformers.loss.loss_utils import nn
1507
+ from transformers.loss.loss_utils import nn
1508
+
1509
+ nn.functional.cross_entropy = liger_cross_entropy
1229
1510
 
1230
- nn.functional.cross_entropy = liger_cross_entropy
1231
- else:
1232
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1233
- modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
1234
1511
  if fused_linear_cross_entropy:
1235
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1236
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1237
- else: # if version < 4.46.1
1238
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1239
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
1512
+ if model is not None:
1513
+ model.forward = MethodType(gpt_oss_lce_forward, model)
1514
+ else:
1515
+ modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
1516
+
1517
+ # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
1518
+ # with clamping (swiglu_limit=7.0) and MXFP4 quantization
1519
+
1520
+ if model is not None:
1521
+ # The model instance already exists, so we need to additionally patch the
1522
+ # instance variables that reference already-instantiated modules
1523
+
1524
+ # get the base model from the model instance
1525
+ base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
1526
+
1527
+ if rms_norm:
1528
+ _patch_rms_norm_module(base_model.norm)
1529
+ for decoder_layer in base_model.layers:
1530
+ if rms_norm:
1531
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1532
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1533
+
1534
+
1535
+ def apply_liger_kernel_to_qwen2_vl(
1536
+ rope: bool = True,
1537
+ cross_entropy: bool = False,
1538
+ fused_linear_cross_entropy: bool = True,
1539
+ rms_norm: bool = True,
1540
+ layer_norm: bool = True,
1541
+ swiglu: bool = True,
1542
+ model: PreTrainedModel = None,
1543
+ ) -> None:
1544
+ """
1545
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
1546
+ NOTE: Qwen2-VL is not supported in transformers<4.52.4
1547
+
1548
+ Args:
1549
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1550
+ fused_linear_cross_entropy (bool):
1551
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1552
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1553
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
1554
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1555
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
1556
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1557
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1558
+ loaded. Default is None.
1559
+ """
1560
+ if transformer_version < version.parse("4.52.4"):
1561
+ logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
1562
+ return
1563
+
1564
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1565
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1566
+ )
1567
+
1568
+ from transformers.models.qwen2_vl import modeling_qwen2_vl
1569
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
1570
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
1571
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
1572
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
1573
+
1574
+ from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
1575
+
1576
+ if rope:
1577
+ modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
1578
+ if rms_norm:
1579
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
1580
+ modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
1581
+ if layer_norm and model is None:
1582
+ modeling_qwen2_vl.LayerNorm = LigerLayerNorm
1583
+ if cross_entropy:
1584
+ modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
1585
+ if fused_linear_cross_entropy:
1586
+ if model is not None:
1587
+ model.forward = MethodType(qwen2_vl_lce_forward, model)
1588
+ else:
1589
+ modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
1590
+ if swiglu:
1591
+ modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
1592
+
1593
+ if model is not None:
1594
+ # The model instance already exists, so we need to additionally patch the
1595
+ # instance variables that reference already-instantiated modules
1596
+
1597
+ if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
1598
+ # Note: language_model and visual properties can be accessed throught conditional class for BC.
1599
+ # Not sure if it is subject to changes in the future.
1600
+ # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
1601
+ text_model: Qwen2VLTextModel = model.language_model
1602
+ vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
1603
+ elif isinstance(model, Qwen2VLTextModel):
1604
+ text_model: Qwen2VLTextModel = model
1605
+ vision_model = None
1606
+ else:
1607
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
1608
+ raise TypeError(
1609
+ f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
1610
+ )
1611
+
1612
+ # Patch Qwen2VisionTransformerPretrainedModel
1613
+ if vision_model is not None:
1614
+ for vision_block in vision_model.blocks:
1615
+ if layer_norm:
1616
+ _patch_layer_norm_module(vision_block.norm1)
1617
+ _patch_layer_norm_module(vision_block.norm2)
1618
+
1619
+ # Patch Qwen2VisionTextModel
1620
+ if text_model is not None:
1621
+ if rms_norm:
1622
+ _patch_rms_norm_module(text_model.norm)
1623
+ for decoder_layer in text_model.layers:
1624
+ if swiglu:
1625
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1626
+ if rms_norm:
1627
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1628
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1629
+
1630
+
1631
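A minimal usage sketch for the Qwen2-VL patch above (illustrative, not part of the diff). It assumes the function is imported from liger_kernel.transformers.monkey_patch, where it is defined, and the checkpoint id is only an example; module-level patching must happen before the model is constructed so the swapped classes take effect at init time.

from transformers import Qwen2VLForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl

# Swap the modeling classes (RMSNorm, LayerNorm, SwiGLU, LCE forward) before loading.
apply_liger_kernel_to_qwen2_vl(fused_linear_cross_entropy=True, rms_norm=True, layer_norm=True, swiglu=True)

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # example checkpoint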
+ def apply_liger_kernel_to_qwen2_5_vl(
1632
+ rope: bool = True,
1633
+ cross_entropy: bool = False,
1634
+ fused_linear_cross_entropy: bool = True,
1635
+ rms_norm: bool = True,
1636
+ swiglu: bool = True,
1637
+ model: PreTrainedModel = None,
1638
+ ) -> None:
1639
+ """
1640
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
1641
+ NOTE: Qwen2.5-VL is not available in transformers<4.48.2
1642
+
1643
+ Args:
1644
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1645
+ fused_linear_cross_entropy (bool):
1646
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1647
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1648
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1649
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1650
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
1651
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1652
+ loaded. Default is None.
1653
+ """
1654
+ if transformer_version < version.parse("4.52.4"):
1655
+ logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
1656
+ return
1657
+
1658
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1659
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1660
+ )
1661
+
1662
+ from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
1663
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
1664
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
1665
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
1666
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
1667
+
1668
+ from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
1669
+
1670
+ if rope:
1671
+ modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
1672
+ if rms_norm:
1673
+ modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
1674
+ if cross_entropy:
1675
+ modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
1676
+ if fused_linear_cross_entropy:
1677
+ if model is not None:
1678
+ model.forward = MethodType(qwen2_5_vl_lce_forward, model)
1679
+ else:
1680
+ modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
1681
+ if swiglu:
1682
+ modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
1683
+
1684
+ if model is not None:
1685
+ # The model instance already exists, so we need to additionally patch the
1686
+ # instance variables that reference already-instantiated modules
1687
+
1688
+ if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
1689
+ # Note: language_model and visual properties can be accessed through the conditional class for BC.
1690
+ # This may be subject to change in future transformers releases.
1691
+ # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
1692
+ text_model: Qwen2_5_VLTextModel = model.language_model
1693
+ vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
1694
+ elif isinstance(model, Qwen2_5_VLTextModel):
1695
+ text_model: Qwen2_5_VLTextModel = model
1696
+ vision_model = None
1697
+ else:
1698
+ # Note: Patching only the vision model is currently not supported. Feel free to raise an issue if needed.
1699
+ raise TypeError(
1700
+ f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
1701
+ )
1702
+
1703
+ if vision_model is not None:
1704
+ # Patch Qwen2_5_VisionTransformerPretrainedModel
1705
+ for vision_block in vision_model.blocks:
1706
+ if rms_norm:
1707
+ _patch_rms_norm_module(vision_block.norm1)
1708
+ _patch_rms_norm_module(vision_block.norm2)
1709
+
1710
+ if text_model is not None:
1711
+ if rms_norm:
1712
+ _patch_rms_norm_module(text_model.norm)
1713
+ for decoder_layer in text_model.layers:
1714
+ if swiglu:
1715
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1716
+ if rms_norm:
1717
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1718
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1719
+
1720
+
1721
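The post-load path for Qwen2.5-VL, sketched under the same assumptions (import path and checkpoint id are illustrative): when `model` is passed, `forward` is rebound with MethodType and the already-instantiated RMSNorm and MLP submodules are patched in place.

from transformers import Qwen2_5_VLForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl

model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")  # example checkpoint
apply_liger_kernel_to_qwen2_5_vl(model=model)  # patches the text and vision submodules of this instance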
+ def apply_liger_kernel_to_qwen3_vl(
1722
+ rope: bool = True,
1723
+ cross_entropy: bool = False,
1724
+ fused_linear_cross_entropy: bool = True,
1725
+ rms_norm: bool = True,
1726
+ swiglu: bool = False,
1727
+ model: PreTrainedModel = None,
1728
+ ) -> None:
1729
+ """
1730
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
1731
+
1732
+ Args:
1733
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1734
+ fused_linear_cross_entropy (bool):
1735
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1736
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1737
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1738
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1739
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1740
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1741
+ loaded. Default is None.
1742
+ """
1743
+
1744
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1745
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1746
+ )
1747
+
1748
+ from transformers.models.qwen3_vl import modeling_qwen3_vl
1749
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
1750
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
1751
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
1752
+
1753
+ from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
1754
+
1755
+ if rope:
1756
+ modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
1757
+ modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
1758
+
1759
+ if rms_norm:
1760
+ modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
1761
+
1762
+ if cross_entropy:
1763
+ from transformers.loss.loss_utils import nn
1764
+
1765
+ nn.functional.cross_entropy = liger_cross_entropy
1766
+
1767
+ if fused_linear_cross_entropy:
1768
+ if model is not None:
1769
+ model.forward = MethodType(qwen3_vl_lce_forward, model)
1770
+ else:
1771
+ modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
1772
+
1773
+ if model is not None and rms_norm:
1774
+ if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
1775
+ text_model: Qwen3VLTextModel = model.language_model
1776
+ elif isinstance(model, Qwen3VLTextModel):
1777
+ text_model = model
1778
+ else:
1779
+ raise TypeError(
1780
+ f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
1781
+ )
1782
+
1783
+ _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
1784
+
1785
+ if text_model is not None:
1786
+ _patch_qwen3_vl_rms_norm(text_model.norm)
1787
+ for decoder_layer in text_model.layers:
1788
+ _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
1789
+ _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
1790
+ self_attn = getattr(decoder_layer, "self_attn", None)
1791
+ if self_attn is not None:
1792
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
1793
+ _patch_qwen3_vl_rms_norm(self_attn.q_norm)
1794
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
1795
+ _patch_qwen3_vl_rms_norm(self_attn.k_norm)
1796
+
1797
+
1798
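A sketch for Qwen3-VL under the same assumptions; importing Qwen3VLForConditionalGeneration from the transformers top level assumes a release that ships Qwen3-VL, and the checkpoint id is illustrative. On an existing instance only RMSNorm modules are patched, including the attention q_norm/k_norm, while rope and the fused-linear-cross-entropy forward are replaced at module or method level.

from transformers import Qwen3VLForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl

model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")  # illustrative checkpoint id
apply_liger_kernel_to_qwen3_vl(model=model, rms_norm=True)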
+ def apply_liger_kernel_to_qwen3_vl_moe(
1799
+ rope: bool = True,
1800
+ cross_entropy: bool = False,
1801
+ fused_linear_cross_entropy: bool = True,
1802
+ rms_norm: bool = True,
1803
+ swiglu: bool = False,
1804
+ model: PreTrainedModel = None,
1805
+ ) -> None:
1806
+ """
1807
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
1808
+
1809
+ Args:
1810
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1811
+ fused_linear_cross_entropy (bool):
1812
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1813
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1814
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1815
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1816
+ loaded. Default is None.
1817
+ """
1818
+
1819
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1820
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1821
+ )
1822
+
1823
+ from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
1824
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
1825
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
1826
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
1827
+
1828
+ from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
1829
+
1830
+ if rope:
1831
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
1832
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
1833
+
1834
+ if rms_norm:
1835
+ modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
1836
+
1837
+ if cross_entropy:
1838
+ from transformers.loss.loss_utils import nn
1839
+
1840
+ nn.functional.cross_entropy = liger_cross_entropy
1841
+
1842
+ if fused_linear_cross_entropy:
1843
+ if model is not None:
1844
+ model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
1845
+ else:
1846
+ modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
1847
+
1848
+ if model is not None and rms_norm:
1849
+ if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
1850
+ text_model: Qwen3VLMoeTextModel = model.language_model
1851
+ elif isinstance(model, Qwen3VLMoeTextModel):
1852
+ text_model = model
1853
+ else:
1854
+ raise TypeError(
1855
+ f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
1856
+ )
1857
+
1858
+ _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
1859
+
1860
+ if text_model is not None:
1861
+ _patch_qwen3_vl_moe_rms_norm(text_model.norm)
1862
+ for decoder_layer in text_model.layers:
1863
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
1864
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
1865
+ self_attn = getattr(decoder_layer, "self_attn", None)
1866
+ if self_attn is not None:
1867
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
1868
+ _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
1869
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
1870
+ _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
1871
+
1872
+
1873
+ def apply_liger_kernel_to_phi3(
1874
+ rope: bool = True,
1875
+ cross_entropy: bool = False,
1876
+ fused_linear_cross_entropy: bool = True,
1877
+ rms_norm: bool = True,
1878
+ swiglu: bool = True,
1879
+ model: PreTrainedModel = None,
1880
+ ) -> None:
1881
+ """
1882
+ Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
1883
+
1884
+ Args:
1885
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1886
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1887
+ fused_linear_cross_entropy (bool):
1888
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1889
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1890
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1891
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1892
+ swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
1893
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1894
+ loaded. Default is None.
1895
+ """
1896
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1897
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1898
+ )
1899
+
1900
+ from transformers.models.phi3 import modeling_phi3
1901
+ from transformers.models.phi3.modeling_phi3 import Phi3Model
1902
+
1903
+ if rope:
1904
+ modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
1905
+ if rms_norm:
1906
+ modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
1907
+ if swiglu:
1908
+ modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
1909
+ if cross_entropy:
1910
+ from transformers.loss.loss_utils import nn
1911
+
1912
+ nn.functional.cross_entropy = liger_cross_entropy
1913
+ if fused_linear_cross_entropy:
1914
+ if model is not None:
1915
+ model.forward = MethodType(phi3_lce_forward, model)
1916
+ else:
1917
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1918
+
1919
+ if model is not None:
1920
+ # The model instance already exists, so we need to additionally patch the
1921
+ # instance variables that reference already-instantiated modules
1922
+
1923
+ # get the base model from the model instance
1924
+ base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
1925
+
1926
+ if rms_norm:
1927
+ _patch_rms_norm_module(base_model.norm)
1928
+
1929
+ for decoder_layer in base_model.layers:
1930
+ if swiglu:
1931
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
1932
+ if rms_norm:
1933
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
1934
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1935
+
1936
+
1937
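A short sketch for the Phi3 patch above, again with an illustrative checkpoint id and the same import-path assumption; classes are swapped at module level before loading so the instance-level patching branch is not needed.

from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3

apply_liger_kernel_to_phi3(rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")  # example checkpoint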
+ def apply_liger_kernel_to_olmo2(
1938
+ rope: bool = True,
1939
+ cross_entropy: bool = False,
1940
+ fused_linear_cross_entropy: bool = True,
1941
+ rms_norm: bool = True,
1942
+ swiglu: bool = True,
1943
+ model: PreTrainedModel = None,
1944
+ ) -> None:
1945
+ """
1946
+ Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
1947
+
1948
+ Args:
1949
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1950
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1951
+ fused_linear_cross_entropy (bool):
1952
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1953
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1954
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1955
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1956
+ swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
1957
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1958
+ loaded. Default is None.
1959
+ """
1960
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1961
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1962
+ )
1963
+
1964
+ from transformers.models.olmo2 import modeling_olmo2
1965
+ from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
1966
+
1967
+ from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
1968
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
1969
+
1970
+ if rope:
1971
+ modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
1972
+ if rms_norm:
1973
+ modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
1974
+ if swiglu:
1975
+ modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
1976
+ if cross_entropy:
1977
+ from transformers.loss.loss_utils import nn
1978
+
1979
+ nn.functional.cross_entropy = liger_cross_entropy
1980
+ if fused_linear_cross_entropy:
1981
+ if model is not None:
1982
+ model.forward = MethodType(olmo2_lce_forward, model)
1983
+ else:
1984
+ modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
1985
+
1986
+ if model is not None:
1987
+ # The model instance already exists, so we need to additionally patch the
1988
+ # instance variables that reference already-instantiated modules
1989
+
1990
+ # get the base model from the model instance
1991
+ base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
1992
+
1993
+ if rms_norm:
1994
+ _patch_rms_norm_module(base_model.norm)
1995
+
1996
+ for decoder_layer in base_model.layers:
1997
+ if swiglu:
1998
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1999
+ if rms_norm:
2000
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2001
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2002
+
2003
+
2004
+ def apply_liger_kernel_to_olmo3(
2005
+ rope: bool = True,
2006
+ cross_entropy: bool = False,
2007
+ fused_linear_cross_entropy: bool = True,
2008
+ rms_norm: bool = True,
2009
+ swiglu: bool = True,
2010
+ model: PreTrainedModel = None,
2011
+ ) -> None:
2012
+ """
2013
+ Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
2014
+
2015
+ Args:
2016
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2017
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2018
+ fused_linear_cross_entropy (bool):
2019
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2020
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2021
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2022
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2023
+ swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
2024
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2025
+ loaded. Default is None.
2026
+ """
2027
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2028
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2029
+ )
2030
+
2031
+ from transformers.models.olmo3 import modeling_olmo3
2032
+ from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
2033
+
2034
+ from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
2035
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
2036
+
2037
+ # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
2038
+ if rope:
2039
+ modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
2040
+ if rms_norm:
2041
+ modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2 # same as olmo2
2042
+ if swiglu:
2043
+ modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
2044
+ if cross_entropy:
2045
+ from transformers.loss.loss_utils import nn
2046
+
2047
+ nn.functional.cross_entropy = liger_cross_entropy
2048
+ if fused_linear_cross_entropy:
2049
+ if model is not None:
2050
+ model.forward = MethodType(olmo3_lce_forward, model)
2051
+ else:
2052
+ modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
2053
+
2054
+ if model is not None:
2055
+ # The model instance already exists, so we need to additionally patch the
2056
+ # instance variables that reference already-instantiated modules
2057
+
2058
+ # get the base model from the model instance
2059
+ base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
2060
+
2061
+ if rms_norm:
2062
+ _patch_rms_norm_module(base_model.norm)
2063
+
2064
+ for decoder_layer in base_model.layers:
2065
+ if swiglu:
2066
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2067
+ if rms_norm:
2068
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2069
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2070
+
2071
+
2072
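For OLMo2 and OLMo3 the instance-patching path is sketched below; the checkpoint id is hypothetical and the import path is assumed as above. As in the code, the post-attention and post-feedforward norms are patched with in_place=False.

from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3

model = AutoModelForCausalLM.from_pretrained("allenai/Olmo-3-7B")  # hypothetical checkpoint id
apply_liger_kernel_to_olmo3(model=model)  # patches norms (in_place=False) and SwiGLU MLPs on this instance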
+ def apply_liger_kernel_to_glm4(
2073
+ rope: bool = False,
2074
+ cross_entropy: bool = False,
2075
+ fused_linear_cross_entropy: bool = True,
2076
+ rms_norm: bool = True,
2077
+ swiglu: bool = True,
2078
+ model: PreTrainedModel = None,
2079
+ ) -> None:
2080
+ """
2081
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
2082
+
2083
+ Args:
2084
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2085
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2086
+ fused_linear_cross_entropy (bool):
2087
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2088
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2089
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2090
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2091
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
2092
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2093
+ loaded. Default is None.
2094
+ """
2095
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2096
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2097
+ )
2098
+
2099
+ from transformers.models.glm4 import modeling_glm4
2100
+ from transformers.models.glm4.modeling_glm4 import Glm4Model
2101
+
2102
+ from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
2103
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2104
+
2105
+ if rope:
2106
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2107
+ if rms_norm:
2108
+ modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
2109
+ if swiglu:
2110
+ modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
2111
+ if cross_entropy:
2112
+ from transformers.loss.loss_utils import nn
2113
+
2114
+ nn.functional.cross_entropy = liger_cross_entropy
2115
+ if fused_linear_cross_entropy:
2116
+ if model is not None:
2117
+ model.forward = MethodType(glm4_lce_forward, model)
2118
+ else:
2119
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
2120
+
2121
+ if model is not None:
2122
+ # The model instance already exists, so we need to additionally patch the
2123
+ # instance variables that reference already-instantiated modules
2124
+
2125
+ # get the base model from the model instance
2126
+ base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
2127
+
2128
+ if rms_norm:
2129
+ _patch_rms_norm_module(base_model.norm, in_place=False)
2130
+
2131
+ for decoder_layer in base_model.layers:
2132
+ if swiglu:
2133
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
2134
+ if rms_norm:
2135
+ _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
2136
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2137
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
2138
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
2139
+
2140
+
2141
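Usage sketch for GLM-4 (same assumptions): `rope` must stay at its default of False because the patch raises NotImplementedError when it is enabled, and the MLP is replaced with LigerPhi3SwiGLUMLP as shown above.

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4

# Module-level patching before the model is loaded; rope is left at its default of False.
apply_liger_kernel_to_glm4(fused_linear_cross_entropy=True, rms_norm=True, swiglu=True)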
+ def apply_liger_kernel_to_glm4v(
2142
+ rope: bool = False,
2143
+ cross_entropy: bool = False,
2144
+ fused_linear_cross_entropy: bool = True,
2145
+ rms_norm: bool = True,
2146
+ swiglu: bool = True,
2147
+ model: PreTrainedModel = None,
2148
+ ) -> None:
2149
+ """
2150
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
2151
+
2152
+ Args:
2153
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2154
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2155
+ fused_linear_cross_entropy (bool):
2156
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2157
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2158
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2159
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2160
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
2161
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2162
+ loaded. Default is None.
2163
+ """
2164
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2165
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2166
+ )
2167
+
2168
+ from transformers.models.glm4v import modeling_glm4v
2169
+ from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
2170
+ from transformers.models.glm4v.modeling_glm4v import Glm4vModel
2171
+ from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
2172
+ from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
2173
+
2174
+ from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
2175
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2176
+
2177
+ if rope:
2178
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2179
+ if rms_norm:
2180
+ modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
2181
+ if cross_entropy:
2182
+ from transformers.loss.loss_utils import nn
2183
+
2184
+ nn.functional.cross_entropy = liger_cross_entropy
2185
+ if fused_linear_cross_entropy:
2186
+ if model is not None:
2187
+ model.forward = MethodType(glm4v_lce_forward, model)
2188
+ else:
2189
+ modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
2190
+
2191
+ if model is not None:
2192
+ # The model instance already exists, so we need to additionally patch the
2193
+ # instance variables that reference already-instantiated modules
2194
+ if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
2195
+ # Note: language_model and visual properties can be accessed through the conditional class for BC.
2196
+ # This may be subject to change in future transformers releases.
2197
+ # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
2198
+ text_model: Glm4vTextModel = model.language_model
2199
+ vision_model: Glm4vVisionModel = model.visual
2200
+ elif isinstance(model, Glm4vTextModel):
2201
+ text_model: Glm4vTextModel = model
2202
+ vision_model = None
2203
+ else:
2204
+ # Note: Patching only the vision model is currently not supported. Feel free to raise an issue if needed.
2205
+ raise TypeError(
2206
+ f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
2207
+ )
2208
+
2209
+ if vision_model is not None:
2210
+ for vision_block in vision_model.blocks:
2211
+ if rms_norm:
2212
+ _patch_rms_norm_module(vision_block.norm1)
2213
+ _patch_rms_norm_module(vision_block.norm2)
2214
+ if swiglu:
2215
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
2216
+
2217
+ if text_model is not None:
2218
+ if rms_norm:
2219
+ _patch_rms_norm_module(text_model.norm)
2220
+ for decoder_layer in text_model.layers:
2221
+ if swiglu:
2222
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
2223
+ if rms_norm:
2224
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2225
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2226
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
2227
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
2228
+
2229
+
2230
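A sketch of GLM-4V instance patching; the checkpoint id is illustrative and importing Glm4vForConditionalGeneration from the transformers top level assumes a release that ships GLM-4V. Both the vision blocks and the text decoder layers of the passed instance are patched.

from transformers import Glm4vForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v

model = Glm4vForConditionalGeneration.from_pretrained("zai-org/GLM-4.1V-9B-Thinking")  # illustrative checkpoint id
apply_liger_kernel_to_glm4v(model=model)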
+ def apply_liger_kernel_to_glm4v_moe(
2231
+ rope: bool = False,
2232
+ cross_entropy: bool = False,
2233
+ fused_linear_cross_entropy: bool = True,
2234
+ rms_norm: bool = True,
2235
+ swiglu: bool = True,
2236
+ model: PreTrainedModel = None,
2237
+ ) -> None:
2238
+ """
2239
+ Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
2240
+
2241
+ Args:
2242
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2243
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2244
+ fused_linear_cross_entropy (bool):
2245
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2246
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2247
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2248
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2249
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
2250
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2251
+ loaded. Default is None.
2252
+ """
2253
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2254
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2255
+ )
2256
+
2257
+ from transformers.models.glm4v_moe import modeling_glm4v_moe
2258
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
2259
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
2260
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
2261
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
2262
+
2263
+ from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
2264
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2265
+
2266
+ if rope:
2267
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2268
+ if rms_norm:
2269
+ modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
2270
+ modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
2271
+ if cross_entropy:
2272
+ from transformers.loss.loss_utils import nn
2273
+
2274
+ nn.functional.cross_entropy = liger_cross_entropy
2275
+ if fused_linear_cross_entropy:
2276
+ if model is not None:
2277
+ model.forward = MethodType(glm4v_moe_lce_forward, model)
2278
+ else:
2279
+ modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
2280
+
2281
+ if model is not None:
2282
+ # The model instance already exists, so we need to additionally patch the
2283
+ # instance variables that reference already-instantiated modules
2284
+ if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
2285
+ # Note: language_model and visual properties can be accessed through the conditional class for BC.
2286
+ # This may be subject to change in future transformers releases.
2287
+ # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
2288
+ text_model: Glm4vMoeTextModel = model.language_model
2289
+ vision_model: Glm4vMoeVisionModel = model.visual
2290
+ Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
2291
+ elif isinstance(model, Glm4vMoeTextModel):
2292
+ text_model: Glm4vMoeTextModel = model
2293
+ vision_model = None
2294
+ else:
2295
+ # Note: Patching only the vision model is currently not supported. Feel free to raise an issue if needed.
2296
+ raise TypeError(
2297
+ f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
2298
+ )
2299
+
2300
+ if vision_model is not None:
2301
+ _patch_rms_norm_module(vision_model.post_conv_layernorm)
2302
+ _patch_rms_norm_module(vision_model.post_layernorm)
2303
+ for vision_block in vision_model.blocks:
2304
+ if rms_norm:
2305
+ _patch_rms_norm_module(vision_block.norm1)
2306
+ _patch_rms_norm_module(vision_block.norm2)
2307
+ if swiglu:
2308
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
2309
+
2310
+ if text_model is not None:
2311
+ if rms_norm:
2312
+ _patch_rms_norm_module(text_model.norm)
2313
+ for decoder_layer in text_model.layers:
2314
+ if swiglu:
2315
+ decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2316
+ if rms_norm:
2317
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2318
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2319
+ if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
2320
+ experts = getattr(decoder_layer.mlp, "experts", None)
2321
+ if experts is not None:
2322
+ for expert in experts:
2323
+ _patch_swiglu_module(expert, LigerSwiGLUMLP)
2324
+ if decoder_layer.mlp.shared_experts is not None:
2325
+ _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
2326
+ for decoder_layer in text_model.layers:
2327
+ if rms_norm:
2328
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2329
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2330
+
2331
+
2332
+ def apply_liger_kernel_to_internvl(
2333
+ cross_entropy: bool = False,
2334
+ fused_linear_cross_entropy: bool = True,
2335
+ rms_norm: bool = True,
2336
+ layer_norm: bool = True,
2337
+ model: Optional[PreTrainedModel] = None,
2338
+ **kwargs,
2339
+ ) -> None:
2340
+ """
2341
+ Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
2342
+ Due to the characteristics of InternVL, the model instance must be passed so that Liger-Kernel can also patch the language model connected to InternVL.
2343
+ However, if a language model not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
2344
+ NOTE: InternVL is not available in transformers<4.52.1
2345
+
2346
+ Args:
2347
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2348
+ fused_linear_cross_entropy (bool):
2349
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2350
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2351
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2352
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2353
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
2354
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2355
+ loaded. Default is None.
2356
+ """
2357
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2358
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2359
+ )
2360
+ import torch.nn as torch_nn
2361
+
2362
+ from transformers.models.internvl import modeling_internvl
2363
+ from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
2364
+ from transformers.models.internvl.modeling_internvl import InternVLModel
2365
+ from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
2366
+ from transformers.models.internvl.modeling_internvl import InternVLVisionModel
2367
+ from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
2368
+
2369
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
2370
+ from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
2371
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
2372
+
2373
+ if layer_norm and model is None:
2374
+ modeling_internvl.nn.LayerNorm = LigerLayerNorm
2375
+
2376
+ if cross_entropy:
2377
+ logger.info("Apply liger cross entropy")
2378
+
2379
+ from transformers.loss.loss_utils import nn
2380
+
2381
+ nn.functional.cross_entropy = liger_cross_entropy
2382
+ if fused_linear_cross_entropy:
2383
+ modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
2384
+ if rms_norm:
2385
+ modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
2386
+
2387
+ if model is not None:
2388
+ # The model instance already exists, so we need to additionally patch the
2389
+ # instance variables that reference already-instantiated modules
2390
+ if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
2391
+ # NOTE: language_model and visual properties can be accessed through the conditional class.
2392
+ text_model = model.language_model
2393
+ vision_model: InternVLVisionModel = model.vision_tower
2394
+ else:
2395
+ raise TypeError(
2396
+ f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
2397
+ )
2398
+
2399
+ text_model_name = model.config.text_config.model_type
2400
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
2401
+
2402
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
2403
+ if text_liger_fn:
2404
+ accept_params = inspect.signature(text_liger_fn).parameters
2405
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
2406
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
2407
+
2408
+ if remain_params:
2409
+ logger.warning(
2410
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
2411
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
2412
+ )
2413
+ text_kwargs["model"] = text_model
2414
+ text_liger_fn(**text_kwargs)
2415
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
2416
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
2417
+
2418
+ # Patch vision model RMSNorm layers
2419
+ if rms_norm:
2420
+ for encoder_layer in vision_model.encoder.layer:
2421
+ encoder_layer: InternVLVisionLayer
2422
+ if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
2423
+ _patch_rms_norm_module(encoder_layer.attention.q_norm)
2424
+ if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
2425
+ _patch_rms_norm_module(encoder_layer.attention.k_norm)
2426
+
2427
+ # Patch vision model LayerNorm layers
2428
+ if layer_norm:
2429
+ # Patch layernorm
2430
+ if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
2431
+ _patch_layer_norm_module(vision_model.layernorm)
2432
+
2433
+ # Patch encoder layers
2434
+ for encoder_layer in vision_model.encoder.layer:
2435
+ encoder_layer: InternVLVisionLayer
2436
+ if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
2437
+ _patch_layer_norm_module(encoder_layer.layernorm_before)
2438
+ if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
2439
+ _patch_layer_norm_module(encoder_layer.layernorm_after)
2440
+
2441
+
2442
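Sketch for InternVL (checkpoint id and auto class are illustrative): the loaded instance must be passed so that the connected language model can be dispatched to its own apply_liger_kernel_to_* function via MODEL_TYPE_TO_APPLY_LIGER_FN, while the vision tower's norm layers are patched directly.

from transformers import AutoModelForImageTextToText

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl

model = AutoModelForImageTextToText.from_pretrained("OpenGVLab/InternVL3-1B-hf")  # illustrative checkpoint id
apply_liger_kernel_to_internvl(model=model, rms_norm=True, layer_norm=True)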
+ def apply_liger_kernel_to_smolvlm(
2443
+ cross_entropy: bool = False,
2444
+ fused_linear_cross_entropy: bool = True,
2445
+ rms_norm: bool = True,
2446
+ layer_norm: bool = True,
2447
+ model: Optional[PreTrainedModel] = None,
2448
+ **kwargs,
2449
+ ) -> None:
2450
+ """
2451
+ Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
2452
+ Due to the characteristics of SmolVLM, the model instance must be passed so that Liger-Kernel can also patch the language model connected to SmolVLM.
2453
+ However, if a language model not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
2454
+ NOTE: SmolVLM is not available in transformers<4.50.0
2455
+
2456
+ Args:
2457
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2458
+ fused_linear_cross_entropy (bool):
2459
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2460
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2461
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2462
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2463
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
2464
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2465
+ loaded. Default is None.
2466
+ """
2467
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2468
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2469
+ )
2470
+
2471
+ from transformers.models.smolvlm import modeling_smolvlm
2472
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
2473
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
2474
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
2475
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
2476
+
2477
+ from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
2478
+
2479
+ # Patch LayerNorm for vision model if model is not provided (pre-initialization)
2480
+ if layer_norm and model is None:
2481
+ modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
2482
+
2483
+ if cross_entropy:
2484
+ logger.info("Apply liger cross entropy")
2485
+
2486
+ from transformers.loss.loss_utils import nn
2487
+
2488
+ nn.functional.cross_entropy = liger_cross_entropy
2489
+ if fused_linear_cross_entropy:
2490
+ if model is not None:
2491
+ model.forward = MethodType(smolvlm_lce_forward, model)
2492
+ else:
2493
+ modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
2494
+ if rms_norm:
2495
+ modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
1240
2496
 
1241
2497
  if model is not None:
1242
2498
  # The model instance already exists, so we need to additionally patch the
1243
2499
  # instance variables that reference already-instantiated modules
2500
+ if isinstance(model, SmolVLMForConditionalGeneration):
2501
+ text_model = model.model.text_model
2502
+ vision_model: SmolVLMVisionTransformer = model.model.vision_model
2503
+ elif isinstance(model, SmolVLMModel):
2504
+ text_model = model.text_model
2505
+ vision_model: SmolVLMVisionTransformer = model.vision_model
2506
+ else:
2507
+ raise TypeError(
2508
+ f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
2509
+ )
2510
+
2511
+ text_model_name = model.config.text_config.model_type
2512
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
2513
+
2514
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
2515
+ if text_liger_fn:
2516
+ accept_params = inspect.signature(text_liger_fn).parameters
2517
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
2518
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
2519
+
2520
+ if remain_params:
2521
+ logger.warning(
2522
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
2523
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
2524
+ )
2525
+ text_kwargs["model"] = text_model
2526
+ text_liger_fn(**text_kwargs)
2527
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
2528
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
2529
+
2530
+ # Patch vision model LayerNorm layers
2531
+ if layer_norm:
2532
+ # Patch post_layernorm
2533
+ _patch_layer_norm_module(vision_model.post_layernorm)
2534
+
2535
+ # Patch encoder layers
2536
+ for encoder_layer in vision_model.encoder.layers:
2537
+ encoder_layer: SmolVLMEncoderLayer
2538
+ _patch_layer_norm_module(encoder_layer.layer_norm1)
2539
+ _patch_layer_norm_module(encoder_layer.layer_norm2)
2540
+
2541
+
2542
+ def apply_liger_kernel_to_falcon_h1(
2543
+ rope: bool = True,
2544
+ cross_entropy: bool = False,
2545
+ fused_linear_cross_entropy: bool = True,
2546
+ rms_norm: bool = True,
2547
+ swiglu: bool = False,
2548
+ model: PreTrainedModel = None,
2549
+ ) -> None:
2550
+ """
2551
+ Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models.
2552
+ Args:
2553
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2554
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2555
+ fused_linear_cross_entropy (bool):
2556
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2557
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2558
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2559
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2560
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
2561
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2562
+ loaded. Default is None.
2563
+ """
2564
+
2565
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2566
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2567
+ )
2568
+
2569
+ from transformers.models.falcon_h1 import modeling_falcon_h1
2570
+ from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
2571
+
2572
+ if rope:
2573
+ logger.info("Apply liger rotary pos emb.")
2574
+ modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
2575
+ if rms_norm:
2576
+ logger.info("Apply liger RMSNorm")
2577
+ modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
2578
+ if swiglu:
2579
+ logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
2580
+
2581
+ if cross_entropy:
2582
+ logger.info("Apply liger cross entropy")
2583
+ from transformers.loss.loss_utils import nn
2584
+
2585
+ nn.functional.cross_entropy = liger_cross_entropy
2586
+
2587
+ if fused_linear_cross_entropy:
2588
+ if model is not None:
2589
+ model.forward = MethodType(falcon_h1_lce_forward, model)
2590
+ else:
2591
+ modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
2592
+
2593
+ if model is not None:
2594
+ # The model instance already exists, so we need to additionally patch the
2595
+ # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
1244
2596
 
1245
2597
  # get the base model from the model instance
1246
- base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
2598
+ base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
1247
2599
 
1248
2600
  if rms_norm:
1249
- _patch_rms_norm_module(base_model.norm)
2601
+ _patch_rms_norm_module(base_model.final_layernorm)
1250
2602
 
1251
2603
  for decoder_layer in base_model.layers:
1252
2604
  if swiglu:
1253
- _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
2605
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
1254
2606
  if rms_norm:
1255
2607
  _patch_rms_norm_module(decoder_layer.input_layernorm)
1256
- _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2608
+ _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
1257
2609
 
1258
2610
 
1259
- def apply_liger_kernel_to_olmo2(
1260
- rope: bool = True,
2611
+ def apply_liger_kernel_to_qwen3_next(
2612
+ rope: bool = False,
1261
2613
  cross_entropy: bool = False,
1262
2614
  fused_linear_cross_entropy: bool = True,
1263
2615
  rms_norm: bool = True,
@@ -1265,17 +2617,17 @@ def apply_liger_kernel_to_olmo2(
1265
2617
  model: PreTrainedModel = None,
1266
2618
  ) -> None:
1267
2619
  """
1268
- Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
2620
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-Next models.
1269
2621
 
1270
2622
  Args:
1271
- rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2623
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
1272
2624
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1273
2625
  fused_linear_cross_entropy (bool):
1274
2626
  Whether to apply Liger's fused linear cross entropy loss. Default is True.
1275
2627
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1276
2628
  If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1277
2629
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1278
- swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
2630
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
1279
2631
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1280
2632
  loaded. Default is None.
1281
2633
  """
@@ -1283,40 +2635,185 @@ def apply_liger_kernel_to_olmo2(
1283
2635
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
1284
2636
  )
1285
2637
 
1286
- from transformers.models.olmo2 import modeling_olmo2
1287
- from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
2638
+ from transformers.models.qwen3_next import modeling_qwen3_next
2639
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
2640
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
2641
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
2642
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
1288
2643
 
1289
- from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
2644
+ from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
2645
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
2646
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
1290
2647
 
1291
2648
  if rope:
1292
- modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
2649
+ # It might encounter NaN issues
2650
+ # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
2651
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
1293
2652
  if rms_norm:
1294
- modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
2653
+ modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
2654
+ if cross_entropy:
2655
+ from transformers.loss.loss_utils import nn
2656
+
2657
+ nn.functional.cross_entropy = liger_cross_entropy
2658
+ if fused_linear_cross_entropy:
2659
+ if model is not None:
2660
+ if isinstance(model, Qwen3NextForCausalLM):
2661
+ model.forward = MethodType(qwen3_next_lce_forward, model)
2662
+ else:
2663
+ raise TypeError(
2664
+ f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
2665
+ )
2666
+ else:
2667
+ modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
1295
2668
  if swiglu:
1296
- modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
2669
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
2670
+ modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
2671
+
2672
+ if model is not None:
2673
+ # The model instance already exists, so we need to additionally patch the
2674
+ # instance variables that reference already-instantiated modules
2675
+ if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
2676
+ base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
2677
+ else:
2678
+ raise TypeError(
2679
+ f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
2680
+ )
2681
+
2682
+ if rms_norm:
2683
+ _patch_rms_norm_module(base_model.norm)
2684
+
2685
+ for decoder_layer in base_model.layers:
2686
+ if rms_norm:
2687
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2688
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2689
+
2690
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
2691
+ if swiglu:
2692
+ if isinstance(decoder_layer.mlp, Qwen3NextMLP):
2693
+ _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
2694
+ if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
2695
+ _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
2696
+ experts = getattr(decoder_layer.mlp, "experts", None)
2697
+ if experts is not None:
2698
+ for expert in experts:
2699
+ _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
2700
+
2701
+
2702
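Sketch for Qwen3-Next (same assumptions, checkpoint id illustrative): rope stays at its default of False since enabling it raises NotImplementedError, and the fused-linear-cross-entropy forward is only rebound when the passed instance is a Qwen3NextForCausalLM.

from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")  # illustrative checkpoint id
apply_liger_kernel_to_qwen3_next(model=model)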
+ def apply_liger_kernel_to_hunyuan_v1_dense(
2703
+ rope: bool = True,
2704
+ cross_entropy: bool = False,
2705
+ fused_linear_cross_entropy: bool = True,
2706
+ rms_norm: bool = True,
2707
+ swiglu: bool = True,
2708
+ model: PreTrainedModel = None,
2709
+ ) -> None:
2710
+ """
2711
+ Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
2712
+ """
2713
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2714
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2715
+ )
2716
+
2717
+ from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
2718
+ from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
2719
+
2720
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
2721
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
2722
+
2723
+ if rope:
2724
+ modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
2725
+
2726
+ if rms_norm:
2727
+ modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
2728
+
1297
2729
  if cross_entropy:
1298
2730
  from transformers.loss.loss_utils import nn
1299
2731
 
1300
2732
  nn.functional.cross_entropy = liger_cross_entropy
2733
+
1301
2734
  if fused_linear_cross_entropy:
1302
- modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
2735
+ if model is not None:
2736
+ model.forward = MethodType(hunyuan_v1_lce_forward, model)
2737
+ else:
2738
+ modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
2739
+
2740
+ if swiglu:
2741
+ modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
1303
2742
 
1304
2743
  if model is not None:
1305
2744
  # The model instance already exists, so we need to additionally patch the
1306
2745
  # instance variables that reference already-instantiated modules
1307
2746
 
1308
2747
  # get the base model from the model instance
1309
- base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
2748
+ base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
1310
2749
 
1311
2750
  if rms_norm:
1312
2751
  _patch_rms_norm_module(base_model.norm)
2752
+ for decoder_layer in base_model.layers:
2753
+ if swiglu:
2754
+ _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
2755
+ if rms_norm:
2756
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2757
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2758
+
2759
+
2760
+ def apply_liger_kernel_to_hunyuan_v1_moe(
2761
+ rope: bool = True,
2762
+ cross_entropy: bool = False,
2763
+ fused_linear_cross_entropy: bool = True,
2764
+ rms_norm: bool = True,
2765
+ swiglu: bool = True,
2766
+ model: PreTrainedModel = None,
2767
+ ) -> None:
2768
+ """
2769
+ Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.
2770
+ """
2771
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2772
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2773
+ )
2774
+
2775
+ from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
2776
+ from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
2777
+
2778
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
2779
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
2780
+
2781
+ if rope:
2782
+ modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
2783
+
2784
+ if rms_norm:
2785
+ modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
2786
+
2787
+ if cross_entropy:
2788
+ from transformers.loss.loss_utils import nn
2789
+
2790
+ nn.functional.cross_entropy = liger_cross_entropy
2791
+
2792
+ if fused_linear_cross_entropy:
2793
+ if model is not None:
2794
+ model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
2795
+ else:
2796
+ modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
1313
2797
 
2798
+ if swiglu:
2799
+ modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
2800
+
2801
+ if model is not None:
2802
+ # The model instance already exists, so we need to additionally patch the
2803
+ # instance variables that reference already-instantiated modules
2804
+
2805
+ # get the base model from the model instance
2806
+ base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
2807
+
2808
+ if rms_norm:
2809
+ _patch_rms_norm_module(base_model.norm)
1314
2810
  for decoder_layer in base_model.layers:
1315
2811
  if swiglu:
1316
- _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2812
+ for mlp_expert in decoder_layer.mlp.experts:
2813
+ _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
1317
2814
  if rms_norm:
1318
- _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
1319
- _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2815
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2816
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
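A minimal sketch of the class-level path for the MoE variant, where patching happens before the model is instantiated so newly constructed modules already use the Liger classes; the re-export location and checkpoint path are assumptions:

from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_moe  # assumed re-export

# Class-level path: with model=None, the classes in modeling_hunyuan_v1_moe are swapped at the
# module level, so a model loaded afterwards is built from the Liger modules and the
# per-expert instance walk is skipped.
apply_liger_kernel_to_hunyuan_v1_moe(rope=True, rms_norm=True, swiglu=True)
model = AutoModelForCausalLM.from_pretrained("path/to/hunyuan-v1-moe-checkpoint")  # placeholder path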
 
 
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
@@ -1325,7 +2822,14 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma2": apply_liger_kernel_to_gemma2,
     "gemma3_text": apply_liger_kernel_to_gemma3_text,
     "gemma3": apply_liger_kernel_to_gemma3,
+    "glm4": apply_liger_kernel_to_glm4,
+    "glm4v": apply_liger_kernel_to_glm4v,
+    "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+    "gpt_oss": apply_liger_kernel_to_gpt_oss,
+    "internvl": apply_liger_kernel_to_internvl,
     "llama": apply_liger_kernel_to_llama,
+    "llama4_text": apply_liger_kernel_to_llama4,
+    "llama4": apply_liger_kernel_to_llama4,
     "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
@@ -1333,11 +2837,26 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "olmo2": apply_liger_kernel_to_olmo2,
+    "olmo3": apply_liger_kernel_to_olmo3,
     "qwen2": apply_liger_kernel_to_qwen2,
+    "qwen3": apply_liger_kernel_to_qwen3,
+    "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen3_next": apply_liger_kernel_to_qwen3_next,
+    "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+    "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
+    "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
+    "falcon_h1": apply_liger_kernel_to_falcon_h1,
+    "smolvlm": apply_liger_kernel_to_smolvlm,
+    "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+    "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
 }
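The mapping above is keyed by config.model_type; a hedged sketch of a direct lookup, assuming a loaded model whose config reports one of the new keys:

# Direct dispatch through the registry added above (model assumed to be a loaded
# HunYuan V1 dense instance, i.e. model.config.model_type == "hunyuan_v1_dense").
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model.config.model_type]
apply_fn(rms_norm=True, swiglu=True, model=model)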
 
 
@@ -1395,7 +2914,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
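The hunk ends at the filtering comment; a plausible completion of that step, using only names visible in the hunk (the applicable_kwargs variable and the closing call are illustrative assumptions, not taken from this diff):

    # Keep only the kwargs accepted by the selected apply function, then dispatch with the
    # model instance so instance-level patching runs.
    applicable_kwargs = {k: v for k, v in kwargs.items() if k in apply_fn_signature.parameters}
    apply_fn(model=model, **applicable_kwargs)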