liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +44 -13
  9. liger_kernel/ops/__init__.py +141 -0
  10. liger_kernel/ops/backends/README.md +151 -0
  11. liger_kernel/ops/backends/__init__.py +13 -0
  12. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  13. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  15. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  16. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  17. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  18. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  19. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  20. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  21. liger_kernel/ops/backends/registry.py +61 -0
  22. liger_kernel/ops/cross_entropy.py +130 -64
  23. liger_kernel/ops/dyt.py +5 -4
  24. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  25. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  26. liger_kernel/ops/geglu.py +6 -4
  27. liger_kernel/ops/group_norm.py +7 -7
  28. liger_kernel/ops/grpo_loss.py +3 -1
  29. liger_kernel/ops/kl_div.py +8 -11
  30. liger_kernel/ops/layer_norm.py +135 -80
  31. liger_kernel/ops/llama4_rope.py +225 -0
  32. liger_kernel/ops/poly_norm.py +390 -0
  33. liger_kernel/ops/rms_norm.py +148 -71
  34. liger_kernel/ops/rope.py +1 -1
  35. liger_kernel/ops/swiglu.py +1 -1
  36. liger_kernel/ops/tiled_mlp.py +136 -0
  37. liger_kernel/ops/utils.py +14 -0
  38. liger_kernel/transformers/__init__.py +65 -0
  39. liger_kernel/transformers/auto_model.py +21 -0
  40. liger_kernel/transformers/cross_entropy.py +9 -4
  41. liger_kernel/transformers/dyt.py +1 -1
  42. liger_kernel/transformers/experimental/__init__.py +5 -0
  43. liger_kernel/transformers/experimental/embedding.py +1 -1
  44. liger_kernel/transformers/functional.py +56 -24
  45. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  46. liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
  47. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  48. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  49. liger_kernel/transformers/geglu.py +1 -1
  50. liger_kernel/transformers/group_norm.py +1 -1
  51. liger_kernel/transformers/grpo_loss.py +57 -2
  52. liger_kernel/transformers/jsd.py +1 -1
  53. liger_kernel/transformers/kl_div.py +1 -1
  54. liger_kernel/transformers/layer_norm.py +1 -1
  55. liger_kernel/transformers/llama4_rope.py +93 -0
  56. liger_kernel/transformers/model/exaone4.py +136 -0
  57. liger_kernel/transformers/model/falcon_h1.py +122 -0
  58. liger_kernel/transformers/model/gemma.py +28 -8
  59. liger_kernel/transformers/model/gemma2.py +34 -11
  60. liger_kernel/transformers/model/gemma3.py +102 -112
  61. liger_kernel/transformers/model/glm4.py +18 -5
  62. liger_kernel/transformers/model/glm4v.py +163 -0
  63. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  64. liger_kernel/transformers/model/gpt_oss.py +211 -0
  65. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  66. liger_kernel/transformers/model/internvl.py +157 -0
  67. liger_kernel/transformers/model/llama.py +26 -7
  68. liger_kernel/transformers/model/llama4.py +121 -0
  69. liger_kernel/transformers/model/llava.py +18 -6
  70. liger_kernel/transformers/model/loss_utils.py +34 -3
  71. liger_kernel/transformers/model/mistral.py +17 -10
  72. liger_kernel/transformers/model/mixtral.py +24 -9
  73. liger_kernel/transformers/model/mllama.py +18 -7
  74. liger_kernel/transformers/model/olmo2.py +18 -5
  75. liger_kernel/transformers/model/olmo3.py +142 -0
  76. liger_kernel/transformers/model/output_classes.py +147 -0
  77. liger_kernel/transformers/model/paligemma.py +42 -5
  78. liger_kernel/transformers/model/phi3.py +24 -159
  79. liger_kernel/transformers/model/qwen2.py +26 -4
  80. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  81. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  82. liger_kernel/transformers/model/qwen3.py +22 -6
  83. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  84. liger_kernel/transformers/model/qwen3_next.py +146 -0
  85. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  86. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  87. liger_kernel/transformers/model/smollm3.py +199 -0
  88. liger_kernel/transformers/model/smolvlm.py +158 -0
  89. liger_kernel/transformers/monkey_patch.py +1423 -100
  90. liger_kernel/transformers/multi_token_attention.py +2 -2
  91. liger_kernel/transformers/poly_norm.py +42 -0
  92. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  93. liger_kernel/transformers/rms_norm.py +15 -5
  94. liger_kernel/transformers/rope.py +45 -1
  95. liger_kernel/transformers/softmax.py +1 -1
  96. liger_kernel/transformers/sparsemax.py +1 -1
  97. liger_kernel/transformers/swiglu.py +18 -1
  98. liger_kernel/transformers/tiled_mlp.py +125 -0
  99. liger_kernel/transformers/tvd.py +1 -1
  100. liger_kernel/utils.py +52 -0
  101. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
  102. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  103. liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
  104. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,9 @@ import inspect
  import logging

  from functools import partial
+ from types import MethodType
  from typing import Callable
+ from typing import Optional

  import transformers

@@ -13,10 +15,12 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
  from liger_kernel.transformers.functional import liger_cross_entropy
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
  from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
  from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+ from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
  from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -25,12 +29,13 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_f
  from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
  from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
- from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
  from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
  from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -54,7 +59,7 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
  module.__dict__[method_name] = new_method.__get__(module, module.__class__)


- def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
+ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
  # Check if the module is a PEFT ModulesToSaveWrapper
  # If it is, we need to patch the modules_to_save.default and original_modules
  if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
@@ -64,26 +69,29 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
  getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
  )
  module.modules_to_save.default.in_place = in_place
+ module.modules_to_save.default.row_mode = row_mode
  module.original_module.offset = offset
  module.original_module.casting_mode = casting_mode
  module.original_module.variance_epsilon = (
  getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
  )
  module.original_module.in_place = in_place
+ module.original_module.row_mode = row_mode
  _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
  _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
  _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
  _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
- module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
- module.original_module.__class__.__name__ = LigerRMSNorm.__name__
+ _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+ _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
  else:
  module.offset = offset
  module.casting_mode = casting_mode
  module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
  module.in_place = in_place
+ module.row_mode = row_mode
  _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
  _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
- module.__class__.__name__ = LigerRMSNorm.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)


  def _patch_layer_norm_module(module, eps=1e-6):
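
For context on the _get_name changes in the hunk above: the old code renamed module.__class__.__name__, which mutates the shared class object, so every instance of that class, patched or not, starts reporting the Liger name in repr(). Binding a _get_name override into the one instance's __dict__ (the same descriptor trick _bind_method_to_module uses for forward) keeps the rename local to the patched module. A minimal sketch of the difference, not taken from the package:

    # Sketch only: per-instance _get_name override vs. renaming the shared class.
    import torch.nn as nn

    norm = nn.LayerNorm(8)    # the module being patched
    other = nn.LayerNorm(8)   # an untouched sibling instance

    # Instance-level binding, as _bind_method_to_module does via __get__:
    norm.__dict__["_get_name"] = (lambda self: "LigerLayerNorm").__get__(norm, nn.LayerNorm)

    print(norm)   # repr reports LigerLayerNorm(...) for the patched instance only
    print(other)  # the sibling still reports LayerNorm(...)
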
@@ -105,28 +113,28 @@ def _patch_layer_norm_module(module, eps=1e-6):
  module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
  module, "normalized_shape", None
  )
- _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
- _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
- _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
- _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
- module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
- module.original_module.__class__.__name__ = LigerLayerNorm.__name__
+ _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+ _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+ _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+ _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+ _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+ _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
  else:
  module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
  module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
  _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
  _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
- module.__class__.__name__ = LigerLayerNorm.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)


  def _patch_swiglu_module(module, liger_module):
  _bind_method_to_module(module, "forward", liger_module.forward)
- module.__class__.__name__ = liger_module.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)


  def _patch_geglu_module(module):
  _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
- module.__class__.__name__ = LigerGEGLUMLP.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)


  def apply_liger_kernel_to_granite(
@@ -257,10 +265,16 @@ def apply_liger_kernel_to_llama(

  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+ if model is not None:
+ model.forward = MethodType(llama_lce_forward, model)
+ else:
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(llama_lce_forward_deprecated, model)
+ else:
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
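
The pattern introduced above, and repeated for the other models below, distinguishes two cases: when a model instance is passed in, the Liger forward is bound to that single object with types.MethodType; otherwise the class attribute is replaced and every model created afterwards picks it up. A rough sketch of the difference, using a stand-in class rather than the real LlamaForCausalLM:

    # Sketch only: Toy stands in for a *ForCausalLM class, liger_forward for an lce_forward.
    from types import MethodType

    class Toy:
        def forward(self):
            return "original"

    def liger_forward(self):
        return "liger"

    a, b = Toy(), Toy()

    a.forward = MethodType(liger_forward, a)   # instance-level: only a is patched
    assert a.forward() == "liger" and b.forward() == "original"

    Toy.forward = liger_forward                # class-level: b and future instances change too
    assert b.forward() == "liger"
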
@@ -280,6 +294,77 @@ def apply_liger_kernel_to_llama(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_smollm3(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.smollm3 import modeling_smollm3
+ from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+ if rope:
+ modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if rms_norm:
+ modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+ if swiglu:
+ modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+ if cross_entropy:
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ else:
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+ modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(smollm3_lce_forward, model)
+ else:
+ modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+ # get the base model from the model instance
+ base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_llava(
  cross_entropy: bool = False,
  fused_linear_cross_entropy: bool = True,
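
A hypothetical usage of the new SmolLM3 entry point added in the hunk above, assuming it is re-exported from liger_kernel.transformers like the other apply_* helpers (otherwise import it from liger_kernel.transformers.monkey_patch); the checkpoint id is a placeholder:

    # Hypothetical usage; checkpoint name and import path are assumptions, not from the diff.
    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers import apply_liger_kernel_to_smollm3

    model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")
    # Patching an already-loaded instance swaps its RMSNorm and MLP submodules in place
    # and binds the fused-linear-cross-entropy forward to this object only.
    apply_liger_kernel_to_smollm3(model=model)
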
@@ -315,9 +400,15 @@ def apply_liger_kernel_to_llava(
  modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse("4.52.0"):
- modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+ if model is not None:
+ model.forward = MethodType(llava_lce_forward, model)
+ else:
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
  elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
- modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(llava_lce_forward_deprecated, model)
+ else:
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
  else: # if version < 4.49.0
  logger.warning(
  "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
@@ -339,7 +430,7 @@ def apply_liger_kernel_to_llava(
  f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
  f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
  )
- text_kwargs["model"] = model.language_model
+ text_kwargs["model"] = model.model.language_model
  text_liger_fn(**text_kwargs)
  elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
  logger.warning(f"{text_model_name} is not supported by Liger kernel.")
@@ -354,12 +445,103 @@ def apply_liger_kernel_to_llava(
  f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
  f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
  )
- vision_kwargs["model"] = model.vision_tower
+ vision_kwargs["model"] = model.model.vision_tower
  vision_liger_fn(**vision_kwargs)
  elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
  logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


+ def apply_liger_kernel_to_llama4(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ layer_norm: bool = True,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.llama4 import modeling_llama4
+ from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+ from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+ from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+ from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+ from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+ if rope:
+ from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+ apply_liger_llama4_rope_full(modeling_llama4)
+ if rms_norm:
+ modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+ if swiglu:
+ modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+ if cross_entropy:
+ modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+ if fused_linear_cross_entropy:
+ modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, Llama4ForConditionalGeneration):
+ language_model: Llama4ForCausalLM = model.language_model
+ vision_model: Llama4VisionModel = model.vision_model
+ text_model: Llama4TextModel = language_model.model
+ elif isinstance(model, Llama4ForCausalLM):
+ text_model = model.model
+ vision_model = None
+ elif isinstance(model, Llama4TextModel):
+ text_model = model
+ vision_model = None
+
+ else:
+ raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+ if text_model:
+ if rms_norm:
+ _patch_rms_norm_module(text_model.norm)
+ for decoder_layer in text_model.layers:
+ if swiglu:
+ if decoder_layer.is_moe_layer:
+ _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+ else:
+ _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+ if vision_model:
+ _patch_layer_norm_module(vision_model.layernorm_pre)
+ _patch_layer_norm_module(vision_model.layernorm_post)
+
+ for layer in vision_model.model.layers:
+ if layer_norm:
+ _patch_layer_norm_module(layer.input_layernorm)
+ _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_mllama(
  rope: bool = True,
  cross_entropy: bool = False,
@@ -401,7 +583,7 @@ def apply_liger_kernel_to_mllama(

  if rope:
  modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
- if layer_norm:
+ if layer_norm and model is None:
  modeling_mllama.nn.LayerNorm = LigerLayerNorm
  if rms_norm:
  modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -417,19 +599,28 @@ def apply_liger_kernel_to_mllama(
  modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+ if model is not None:
+ model.forward = MethodType(mllama_lce_forward, model)
+ else:
+ modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(mllama_lce_forward_deprecated, model)
+ else:
+ modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules

  if isinstance(model, MllamaForConditionalGeneration):
- language_model: MllamaForCausalLM = model.language_model
- vision_model: MllamaVisionModel = model.vision_model
- text_model: MllamaTextModel = language_model
+ language_model: MllamaForCausalLM = model.model.language_model
+ vision_model: MllamaVisionModel = model.model.vision_model
+ if isinstance(language_model, MllamaForCausalLM):
+ text_model: MllamaTextModel = language_model.model
+ else:
+ text_model = language_model
  elif isinstance(model, MllamaForCausalLM):
  text_model = model.model
  vision_model = None
@@ -503,7 +694,17 @@ def apply_liger_kernel_to_mistral(
  if cross_entropy:
  modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
- modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+ if transformer_version >= version.parse("4.49.0"):
+ if model is not None:
+ model.forward = MethodType(mistral_lce_forward, model)
+ else:
+ modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+ else:
+ logger.warning(
+ "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+ )
+ logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
  if swiglu:
  modeling_mistral.MistralMLP = LigerSwiGLUMLP

@@ -571,10 +772,16 @@ def apply_liger_kernel_to_mixtral(

  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+ if model is not None:
+ model.forward = MethodType(mixtral_lce_forward, model)
+ else:
+ modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+ else:
+ modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
  if swiglu:
  modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -648,10 +855,16 @@ def apply_liger_kernel_to_gemma(
  modeling_gemma.GemmaMLP = LigerGEGLUMLP
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+ if model is not None:
+ model.forward = MethodType(gemma_lce_forward, model)
+ else:
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(gemma_lce_forward_deprecated, model)
+ else:
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -723,10 +936,16 @@ def apply_liger_kernel_to_gemma2(
  modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+ if model is not None:
+ model.forward = MethodType(gemma2_lce_forward, model)
+ else:
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
  else:
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+ if model is not None:
+ model.forward = MethodType(gemma2_lce_forward_deprected, model)
+ else:
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
  if geglu:
  modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -805,7 +1024,10 @@ def apply_liger_kernel_to_gemma3_text(
  nn.functional.cross_entropy = liger_cross_entropy

  if fused_linear_cross_entropy:
- modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+ if model is not None:
+ model.forward = MethodType(causal_forward, model)
+ else:
+ modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -875,7 +1097,7 @@ def apply_liger_kernel_to_gemma3(
  _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
  )

- if layer_norm:
+ if layer_norm and model is None:
  modeling_siglip.nn.LayerNorm = LigerLayerNorm

  apply_liger_kernel_to_gemma3_text(
@@ -886,15 +1108,18 @@ def apply_liger_kernel_to_gemma3(
  modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss

  if fused_linear_cross_entropy:
- modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+ if model is not None:
+ model.forward = MethodType(multimodal_forward, model)
+ else:
+ modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules

  if isinstance(model, Gemma3ForConditionalGeneration):
- if isinstance(model.vision_tower, SiglipVisionModel):
- vision_tower = model.vision_tower
+ if isinstance(model.model.vision_tower, SiglipVisionModel):
+ vision_tower = model.model.vision_tower

  _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -907,7 +1132,7 @@ def apply_liger_kernel_to_gemma3(
  raise TypeError("The vision tower must be SiglipVisionModel")

  if rms_norm:
- _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+ _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)

  apply_liger_kernel_to_gemma3_text(
  rope=rope,
@@ -915,7 +1140,7 @@ def apply_liger_kernel_to_gemma3(
  fused_linear_cross_entropy=False,
  rms_norm=rms_norm,
  geglu=geglu,
- model=model.language_model,
+ model=model.model.language_model,
  )

  else:
@@ -954,7 +1179,9 @@ def apply_liger_kernel_to_paligemma(
  # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']

  from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+ from transformers.models.gemma.modeling_gemma import GemmaModel
  from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
  from transformers.models.paligemma import modeling_paligemma
  from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
  from transformers.models.siglip import modeling_siglip
@@ -965,7 +1192,7 @@ def apply_liger_kernel_to_paligemma(
  from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

  # The vision_tower is a SiglipVisionModel
- if layer_norm:
+ if layer_norm and model is None:
  modeling_siglip.nn.LayerNorm = LigerLayerNorm

  # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -983,10 +1210,16 @@ def apply_liger_kernel_to_paligemma(
  modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+ if model is not None:
+ model.forward = MethodType(lce_forward, model)
+ else:
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(lce_forward_deprecated, model)
+ else:
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -995,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
  if not isinstance(model, PaliGemmaForConditionalGeneration):
  raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

- vision_tower: SiglipVisionModel = model.vision_tower
+ vision_tower: SiglipVisionModel = model.model.vision_tower

  _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -1005,9 +1238,9 @@ def apply_liger_kernel_to_paligemma(
  _patch_layer_norm_module(layer.layer_norm1)
  _patch_layer_norm_module(layer.layer_norm2)

- language_model = model.language_model
+ language_model = model.model.language_model

- if isinstance(language_model, GemmaForCausalLM):
+ if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
  apply_liger_kernel_to_gemma(
  rope=rope,
  cross_entropy=False,
@@ -1017,7 +1250,7 @@ def apply_liger_kernel_to_paligemma(
  model=language_model,
  )

- elif isinstance(language_model, Gemma2ForCausalLM):
+ elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
  apply_liger_kernel_to_gemma2(
  rope=rope,
  cross_entropy=False,
@@ -1078,10 +1311,16 @@ def apply_liger_kernel_to_qwen2(

  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen2_lce_forward, model)
+ else:
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+ else:
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

  if swiglu:
  modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1102,7 +1341,6 @@ def apply_liger_kernel_to_qwen2(
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
- print("Applied Liger kernels to Qwen2")


  def apply_liger_kernel_to_qwen3(
@@ -1137,7 +1375,10 @@ def apply_liger_kernel_to_qwen3(
  nn.functional.cross_entropy = liger_cross_entropy

  if fused_linear_cross_entropy:
- modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen3_lce_forward, model)
+ else:
+ modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward

  if swiglu:
  modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1192,7 +1433,10 @@ def apply_liger_kernel_to_qwen3_moe(
  nn.functional.cross_entropy = liger_cross_entropy

  if fused_linear_cross_entropy:
- modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen3_lce_forward, model)
+ else:
+ modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward

  if swiglu:
  modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
@@ -1208,7 +1452,81 @@ def apply_liger_kernel_to_qwen3_moe(
  _patch_rms_norm_module(base_model.norm)
  for decoder_layer in base_model.layers:
  if swiglu:
- _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+ for mlp_expert in decoder_layer.mlp.experts:
+ _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_gpt_oss(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+ NOTE: GPT-OSS is supported in transformers >= 4.55.0
+ NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+ implementation with clamping and MXFP4 quantization.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ if version.parse(transformers.__version__) < version.parse("4.55.0"):
+ logger.warning("GPT-OSS support requires transformers >= 4.55.0")
+ return
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.gpt_oss import modeling_gpt_oss
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
+
+ if rope:
+ modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(gpt_oss_lce_forward, model)
+ else:
+ modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
+
+ # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
+ # with clamping (swiglu_limit=7.0) and MXFP4 quantization
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
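
The GPT-OSS helper above refuses to patch anything when the installed transformers is older than 4.55.0. An illustrative sketch of that guard, using packaging.version the same way the function does; the helper name is made up for the example:

    # Illustrative only; mirrors the version gate in apply_liger_kernel_to_gpt_oss.
    import transformers
    from packaging import version

    def gpt_oss_patchable() -> bool:
        # GPT-OSS modeling code ships with transformers >= 4.55.0 (per the docstring above)
        return version.parse(transformers.__version__) >= version.parse("4.55.0")

    if not gpt_oss_patchable():
        print("GPT-OSS support requires transformers >= 4.55.0; skipping patch.")
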
@@ -1260,23 +1578,25 @@ def apply_liger_kernel_to_qwen2_vl(
  if rms_norm:
  # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
  modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
- if layer_norm:
+ if layer_norm and model is None:
  modeling_qwen2_vl.LayerNorm = LigerLayerNorm
  if cross_entropy:
  modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
- modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen2_vl_lce_forward, model)
+ else:
+ modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
  if swiglu:
  modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
-
- if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
- # Note: language_model and visual properties can be accessed throught conditional class for BC.
- # Not sure if it is subject to changes in the future.
- # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+ if isinstance(model, Qwen2VLForConditionalGeneration):
+ text_model: Qwen2VLTextModel = model.model.language_model
+ vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
+ elif isinstance(model, Qwen2VLModel):
  text_model: Qwen2VLTextModel = model.language_model
  vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
  elif isinstance(model, Qwen2VLTextModel):
@@ -1353,18 +1673,20 @@ def apply_liger_kernel_to_qwen2_5_vl(
  if cross_entropy:
  modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
- modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+ else:
+ modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
  if swiglu:
  modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
-
- if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
- # Note: language_model and visual properties can be accessed throught conditional class for BC.
- # Not sure if it is subject to changes in the future.
- # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+ if isinstance(model, Qwen2_5_VLForConditionalGeneration):
+ text_model: Qwen2_5_VLTextModel = model.model.language_model
+ vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
+ elif isinstance(model, Qwen2_5_VLModel):
  text_model: Qwen2_5_VLTextModel = model.language_model
  vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
  elif isinstance(model, Qwen2_5_VLTextModel):
@@ -1378,7 +1700,7 @@ def apply_liger_kernel_to_qwen2_5_vl(

  if vision_model is not None:
  # Patch Qwen2_5_VisionTransformerPretrainedModel
- for vision_block in model.visual.blocks:
+ for vision_block in vision_model.blocks:
  if rms_norm:
  _patch_rms_norm_module(vision_block.norm1)
  _patch_rms_norm_module(vision_block.norm2)
@@ -1394,69 +1716,220 @@ def apply_liger_kernel_to_qwen2_5_vl(
1394
1716
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
1395
1717
 
1396
1718
 
1397
- def apply_liger_kernel_to_phi3(
1719
+ def apply_liger_kernel_to_qwen3_vl(
1398
1720
  rope: bool = True,
1399
1721
  cross_entropy: bool = False,
1400
1722
  fused_linear_cross_entropy: bool = True,
1401
1723
  rms_norm: bool = True,
1402
- swiglu: bool = True,
1724
+ swiglu: bool = False,
1403
1725
  model: PreTrainedModel = None,
1404
1726
  ) -> None:
1405
1727
  """
1406
- Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
1728
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
1407
1729
 
1408
1730
  Args:
1409
- rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1410
1731
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1411
1732
  fused_linear_cross_entropy (bool):
1412
1733
  Whether to apply Liger's fused linear cross entropy loss. Default is True.
1413
1734
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1414
1735
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
1415
1736
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1416
- swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
1737
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1417
1738
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1418
1739
  loaded. Default is None.
1419
1740
  """
1741
+
1420
1742
  assert not (cross_entropy and fused_linear_cross_entropy), (
1421
1743
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
1422
1744
  )
1423
1745
 
1424
- from transformers.models.phi3 import modeling_phi3
1425
- from transformers.models.phi3.modeling_phi3 import Phi3Model
1746
+ from transformers.models.qwen3_vl import modeling_qwen3_vl
1747
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
1748
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
1749
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
1750
+
1751
+ from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
1426
1752
 
1427
1753
  if rope:
1428
- modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
1754
+ modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
1755
+ modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
1756
+
1429
1757
  if rms_norm:
1430
- modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
1431
- if swiglu:
1432
- modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
1758
+ modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
1759
+
1433
1760
  if cross_entropy:
1434
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1435
- from transformers.loss.loss_utils import nn
1761
+ from transformers.loss.loss_utils import nn
1762
+
1763
+ nn.functional.cross_entropy = liger_cross_entropy
1436
1764
 
1437
- nn.functional.cross_entropy = liger_cross_entropy
1438
- else:
1439
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1440
- modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
1441
1765
  if fused_linear_cross_entropy:
1442
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
1443
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1444
- else: # if version < 4.46.1
1445
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
1446
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
1766
+ if model is not None:
1767
+ model.forward = MethodType(qwen3_vl_lce_forward, model)
1768
+ else:
1769
+ modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
1770
+
1771
+ if model is not None and rms_norm:
1772
+ if isinstance(model, Qwen3VLForConditionalGeneration):
1773
+ text_model: Qwen3VLTextModel = model.model.language_model
1774
+ elif isinstance(model, Qwen3VLModel):
1775
+ text_model: Qwen3VLTextModel = model.language_model
1776
+ elif isinstance(model, Qwen3VLTextModel):
1777
+ text_model = model
1778
+ else:
1779
+ raise TypeError(
1780
+ f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
1781
+ )
1447
1782
 
1448
- if model is not None:
1449
- # The model instance already exists, so we need to additionally patch the
1450
- # instance variables that reference already-instantiated modules
1783
+ _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
1451
1784
 
1452
- # get the base model from the model instance
1453
- base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
1785
+ if text_model is not None:
1786
+ _patch_qwen3_vl_rms_norm(text_model.norm)
1787
+ for decoder_layer in text_model.layers:
1788
+ _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
1789
+ _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
1790
+ self_attn = getattr(decoder_layer, "self_attn", None)
1791
+ if self_attn is not None:
1792
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
1793
+ _patch_qwen3_vl_rms_norm(self_attn.q_norm)
1794
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
1795
+ _patch_qwen3_vl_rms_norm(self_attn.k_norm)
1454
1796
 
1455
- if rms_norm:
1456
- _patch_rms_norm_module(base_model.norm)
1457
1797
 
1458
- for decoder_layer in base_model.layers:
1459
- if swiglu:
1798
+ def apply_liger_kernel_to_qwen3_vl_moe(
1799
+ rope: bool = True,
1800
+ cross_entropy: bool = False,
1801
+ fused_linear_cross_entropy: bool = True,
1802
+ rms_norm: bool = True,
1803
+ swiglu: bool = False,
1804
+ model: PreTrainedModel = None,
1805
+ ) -> None:
1806
+ """
1807
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
1808
+
1809
+ Args:
1810
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1811
+ fused_linear_cross_entropy (bool):
1812
+ Whether to apply Liger's fused linear cross entropy loss. Default is False.
1813
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1814
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
1815
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1816
+ loaded. Default is None.
1817
+ """
1818
+
1819
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1820
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1821
+ )
1822
+
1823
+ from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
1824
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
1825
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
1826
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
1827
+
1828
+ from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
1829
+
1830
+ if rope:
1831
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
1832
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
1833
+
1834
+ if rms_norm:
1835
+ modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
1836
+
1837
+ if cross_entropy:
1838
+ from transformers.loss.loss_utils import nn
1839
+
1840
+ nn.functional.cross_entropy = liger_cross_entropy
1841
+
1842
+ if fused_linear_cross_entropy:
1843
+ if model is not None:
1844
+ model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
1845
+ else:
1846
+ modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
1847
+
1848
+ if model is not None and rms_norm:
1849
+ if isinstance(model, Qwen3VLMoeForConditionalGeneration):
1850
+ text_model: Qwen3VLMoeTextModel = model.model.language_model
1851
+ elif isinstance(model, Qwen3VLMoeModel):
1852
+ text_model: Qwen3VLMoeTextModel = model.language_model
1853
+ elif isinstance(model, Qwen3VLMoeTextModel):
1854
+ text_model = model
1855
+ else:
1856
+ raise TypeError(
1857
+ f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
1858
+ )
1859
+
1860
+ _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
1861
+
1862
+ if text_model is not None:
1863
+ _patch_qwen3_vl_moe_rms_norm(text_model.norm)
1864
+ for decoder_layer in text_model.layers:
1865
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
1866
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
1867
+ self_attn = getattr(decoder_layer, "self_attn", None)
1868
+ if self_attn is not None:
1869
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
1870
+ _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
1871
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
1872
+ _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
1873
+
1874
+
1875
+ def apply_liger_kernel_to_phi3(
1876
+ rope: bool = True,
1877
+ cross_entropy: bool = False,
1878
+ fused_linear_cross_entropy: bool = True,
1879
+ rms_norm: bool = True,
1880
+ swiglu: bool = True,
1881
+ model: PreTrainedModel = None,
1882
+ ) -> None:
1883
+ """
1884
+ Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
1885
+
1886
+ Args:
1887
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
1888
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1889
+ fused_linear_cross_entropy (bool):
1890
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1891
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1892
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1893
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1894
+ swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
1895
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1896
+ loaded. Default is None.
1897
+ """
1898
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1899
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1900
+ )
1901
+
1902
+ from transformers.models.phi3 import modeling_phi3
1903
+ from transformers.models.phi3.modeling_phi3 import Phi3Model
1904
+
1905
+ if rope:
1906
+ modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
1907
+ if rms_norm:
1908
+ modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
1909
+ if swiglu:
1910
+ modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
1911
+ if cross_entropy:
1912
+ from transformers.loss.loss_utils import nn
1913
+
1914
+ nn.functional.cross_entropy = liger_cross_entropy
1915
+ if fused_linear_cross_entropy:
1916
+ if model is not None:
1917
+ model.forward = MethodType(phi3_lce_forward, model)
1918
+ else:
1919
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
1920
+
1921
+ if model is not None:
1922
+ # The model instance already exists, so we need to additionally patch the
1923
+ # instance variables that reference already-instantiated modules
1924
+
1925
+ # get the base model from the model instance
1926
+ base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
1927
+
1928
+ if rms_norm:
1929
+ _patch_rms_norm_module(base_model.norm)
1930
+
1931
+ for decoder_layer in base_model.layers:
1932
+ if swiglu:
1460
1933
  _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
1461
1934
  if rms_norm:
1462
1935
  _patch_rms_norm_module(decoder_layer.input_layernorm)
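# Editor's note: a usage sketch, not part of the diff; the checkpoint name is a placeholder and the
# helper is assumed to be exported from liger_kernel.transformers. Called before loading, the
# class-level patches (RMSNorm, SwiGLU, fused linear cross entropy forward) apply to every Phi3
# model instantiated afterwards.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_phi3

apply_liger_kernel_to_phi3(rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("path/to/phi3-checkpoint")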
@@ -1507,7 +1980,10 @@ def apply_liger_kernel_to_olmo2(
1507
1980
 
1508
1981
  nn.functional.cross_entropy = liger_cross_entropy
1509
1982
  if fused_linear_cross_entropy:
1510
- modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
1983
+ if model is not None:
1984
+ model.forward = MethodType(olmo2_lce_forward, model)
1985
+ else:
1986
+ modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
1511
1987
 
1512
1988
  if model is not None:
1513
1989
  # The model instance already exists, so we need to additionally patch the
@@ -1527,6 +2003,74 @@ def apply_liger_kernel_to_olmo2(
1527
2003
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
1528
2004
 
1529
2005
 
2006
+ def apply_liger_kernel_to_olmo3(
2007
+ rope: bool = True,
2008
+ cross_entropy: bool = False,
2009
+ fused_linear_cross_entropy: bool = True,
2010
+ rms_norm: bool = True,
2011
+ swiglu: bool = True,
2012
+ model: PreTrainedModel = None,
2013
+ ) -> None:
2014
+ """
2015
+ Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
2016
+
2017
+ Args:
2018
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2019
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2020
+ fused_linear_cross_entropy (bool):
2021
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2022
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2023
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2024
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2025
+ swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
2026
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2027
+ loaded. Default is None.
2028
+ """
2029
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2030
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2031
+ )
2032
+
2033
+ from transformers.models.olmo3 import modeling_olmo3
2034
+ from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
2035
+
2036
+ from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
2037
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
2038
+
2039
+ # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
2040
+ if rope:
2041
+ modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
2042
+ if rms_norm:
2043
+ modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2 # same as olmo2
2044
+ if swiglu:
2045
+ modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
2046
+ if cross_entropy:
2047
+ from transformers.loss.loss_utils import nn
2048
+
2049
+ nn.functional.cross_entropy = liger_cross_entropy
2050
+ if fused_linear_cross_entropy:
2051
+ if model is not None:
2052
+ model.forward = MethodType(olmo3_lce_forward, model)
2053
+ else:
2054
+ modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
2055
+
2056
+ if model is not None:
2057
+ # The model instance already exists, so we need to additionally patch the
2058
+ # instance variables that reference already-instantiated modules
2059
+
2060
+ # get the base model from the model instance
2061
+ base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
2062
+
2063
+ if rms_norm:
2064
+ _patch_rms_norm_module(base_model.norm)
2065
+
2066
+ for decoder_layer in base_model.layers:
2067
+ if swiglu:
2068
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2069
+ if rms_norm:
2070
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2071
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2072
+
2073
+
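# Editor's note: a sketch of patching an already-loaded Olmo3 model, not part of the diff; the
# checkpoint name is a placeholder.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_olmo3

model = AutoModelForCausalLM.from_pretrained("path/to/olmo3-checkpoint")
# With `model` given, forward is rebound via MethodType and the existing RMSNorm/MLP modules are
# patched in place (the post-attention/post-feedforward layernorms use in_place=False).
apply_liger_kernel_to_olmo3(model=model)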
1530
2074
  def apply_liger_kernel_to_glm4(
1531
2075
  rope: bool = False,
1532
2076
  cross_entropy: bool = False,
@@ -1571,7 +2115,10 @@ def apply_liger_kernel_to_glm4(
1571
2115
 
1572
2116
  nn.functional.cross_entropy = liger_cross_entropy
1573
2117
  if fused_linear_cross_entropy:
1574
- modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
2118
+ if model is not None:
2119
+ model.forward = MethodType(glm4_lce_forward, model)
2120
+ else:
2121
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
1575
2122
 
1576
2123
  if model is not None:
1577
2124
  # The model instance already exists, so we need to additionally patch the
@@ -1593,6 +2140,764 @@ def apply_liger_kernel_to_glm4(
1593
2140
  _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
1594
2141
 
1595
2142
 
2143
+ def apply_liger_kernel_to_glm4v(
2144
+ rope: bool = False,
2145
+ cross_entropy: bool = False,
2146
+ fused_linear_cross_entropy: bool = True,
2147
+ rms_norm: bool = True,
2148
+ swiglu: bool = True,
2149
+ model: PreTrainedModel = None,
2150
+ ) -> None:
2151
+ """
2152
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
2153
+
2154
+ Args:
2155
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2156
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2157
+ fused_linear_cross_entropy (bool):
2158
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2159
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2160
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2161
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2162
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
2163
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2164
+ loaded. Default is None.
2165
+ """
2166
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2167
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2168
+ )
2169
+
2170
+ from transformers.models.glm4v import modeling_glm4v
2171
+ from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
2172
+ from transformers.models.glm4v.modeling_glm4v import Glm4vModel
2173
+ from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
2174
+ from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
2175
+
2176
+ from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
2177
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2178
+
2179
+ if rope:
2180
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2181
+ if rms_norm:
2182
+ modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
2183
+ if cross_entropy:
2184
+ from transformers.loss.loss_utils import nn
2185
+
2186
+ nn.functional.cross_entropy = liger_cross_entropy
2187
+ if fused_linear_cross_entropy:
2188
+ if model is not None:
2189
+ model.forward = MethodType(glm4v_lce_forward, model)
2190
+ else:
2191
+ modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
2192
+
2193
+ if model is not None:
2194
+ # The model instance already exists, so we need to additionally patch the
2195
+ # instance variables that reference already-instantiated modules
2196
+ if isinstance(model, Glm4vForConditionalGeneration):
2197
+ text_model: Glm4vTextModel = model.model.language_model
2198
+ vision_model: Glm4vVisionModel = model.model.visual
2199
+ elif isinstance(model, Glm4vModel):
2200
+ text_model: Glm4vTextModel = model.language_model
2201
+ vision_model: Glm4vVisionModel = model.visual
2202
+ elif isinstance(model, Glm4vTextModel):
2203
+ text_model: Glm4vTextModel = model
2204
+ vision_model = None
2205
+ else:
2206
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
2207
+ raise TypeError(
2208
+ f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
2209
+ )
2210
+
2211
+ if vision_model is not None:
2212
+ for vision_block in vision_model.blocks:
2213
+ if rms_norm:
2214
+ _patch_rms_norm_module(vision_block.norm1)
2215
+ _patch_rms_norm_module(vision_block.norm2)
2216
+ if swiglu:
2217
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
2218
+
2219
+ if text_model is not None:
2220
+ if rms_norm:
2221
+ _patch_rms_norm_module(text_model.norm)
2222
+ for decoder_layer in text_model.layers:
2223
+ if swiglu:
2224
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
2225
+ if rms_norm:
2226
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2227
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2228
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
2229
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
2230
+
2231
+
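# Editor's note: a usage sketch, not part of the diff. Passing a Glm4vForConditionalGeneration
# instance lets the helper patch both the text decoder layers and the vision blocks; the checkpoint
# path is a placeholder.
from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_glm4v

model = Glm4vForConditionalGeneration.from_pretrained("path/to/glm4v-checkpoint")
apply_liger_kernel_to_glm4v(rms_norm=True, swiglu=True, fused_linear_cross_entropy=True, model=model)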
2232
+ def apply_liger_kernel_to_glm4v_moe(
2233
+ rope: bool = False,
2234
+ cross_entropy: bool = False,
2235
+ fused_linear_cross_entropy: bool = True,
2236
+ rms_norm: bool = True,
2237
+ swiglu: bool = True,
2238
+ model: PreTrainedModel = None,
2239
+ ) -> None:
2240
+ """
2241
+ Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
2242
+
2243
+ Args:
2244
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2245
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2246
+ fused_linear_cross_entropy (bool):
2247
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2248
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2249
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2250
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2251
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
2252
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2253
+ loaded. Default is None.
2254
+ """
2255
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2256
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2257
+ )
2258
+
2259
+ from transformers.models.glm4v_moe import modeling_glm4v_moe
2260
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
2261
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
2262
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
2263
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
2264
+
2265
+ from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
2266
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
2267
+
2268
+ if rope:
2269
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
2270
+ if rms_norm:
2271
+ modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
2272
+ modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
2273
+ if cross_entropy:
2274
+ from transformers.loss.loss_utils import nn
2275
+
2276
+ nn.functional.cross_entropy = liger_cross_entropy
2277
+ if fused_linear_cross_entropy:
2278
+ if model is not None:
2279
+ model.forward = MethodType(glm4v_moe_lce_forward, model)
2280
+ else:
2281
+ modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
2282
+
2283
+ if model is not None:
2284
+ # The model instance already exists, so we need to additionally patch the
2285
+ # instance variables that reference already-instantiated modules
2286
+ if isinstance(model, Glm4vMoeForConditionalGeneration):
2287
+ text_model: Glm4vMoeTextModel = model.model.language_model
2288
+ vision_model: Glm4vMoeVisionModel = model.model.visual
2289
+ Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
2290
+ elif isinstance(model, Glm4vMoeModel):
2291
+ text_model: Glm4vMoeTextModel = model.language_model
2292
+ vision_model: Glm4vMoeVisionModel = model.visual
2293
+ Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
2294
+ elif isinstance(model, Glm4vMoeTextModel):
2295
+ text_model: Glm4vMoeTextModel = model
2296
+ vision_model = None
2297
+ else:
2298
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
2299
+ raise TypeError(
2300
+ f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
2301
+ )
2302
+
2303
+ if vision_model is not None:
2304
+ _patch_rms_norm_module(vision_model.post_conv_layernorm)
2305
+ _patch_rms_norm_module(vision_model.post_layernorm)
2306
+ for vision_block in vision_model.blocks:
2307
+ if rms_norm:
2308
+ _patch_rms_norm_module(vision_block.norm1)
2309
+ _patch_rms_norm_module(vision_block.norm2)
2310
+ if swiglu:
2311
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
2312
+
2313
+ if text_model is not None:
2314
+ if rms_norm:
2315
+ _patch_rms_norm_module(text_model.norm)
2316
+ for decoder_layer in text_model.layers:
2317
+ if swiglu:
2318
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2319
+ if rms_norm:
2320
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2321
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2322
+ if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
2323
+ experts = getattr(decoder_layer.mlp, "experts", None)
2324
+ if experts is not None:
2325
+ for expert in experts:
2326
+ _patch_swiglu_module(expert, LigerSwiGLUMLP)
2327
+ if decoder_layer.mlp.shared_experts is not None:
2328
+ _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
2329
+ for decoder_layer in text_model.layers:
2330
+ if rms_norm:
2331
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2332
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2333
+
2334
+
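# Editor's note: a sketch for the MoE variant, not part of the diff. With an instance, the routed
# experts and shared experts of each Glm4vMoeTextMoE block are patched with LigerSwiGLUMLP; the
# checkpoint path is a placeholder.
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe

model = Glm4vMoeForConditionalGeneration.from_pretrained("path/to/glm4v-moe-checkpoint")
apply_liger_kernel_to_glm4v_moe(model=model)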
2335
+ def apply_liger_kernel_to_internvl(
2336
+ cross_entropy: bool = False,
2337
+ fused_linear_cross_entropy: bool = True,
2338
+ rms_norm: bool = True,
2339
+ layer_norm: bool = True,
2340
+ model: Optional[PreTrainedModel] = None,
2341
+ **kwargs,
2342
+ ) -> None:
2343
+ """
2344
+ Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
2345
+ Because InternVL wraps a separately loaded language model, the `model` instance must be passed so that Liger-Kernel's patch can also be applied to the connected language model.
2346
+ However, if the connected LM is not supported by Liger-Kernel, unexpected side effects may occur.
2347
+ NOTE: InternVL is not available in transformers<4.52.1
2348
+
2349
+ Args:
2350
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2351
+ fused_linear_cross_entropy (bool):
2352
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2353
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2354
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2355
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2356
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
2357
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2358
+ loaded. Default is None.
2359
+ """
2360
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2361
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2362
+ )
2363
+ import torch.nn as torch_nn
2364
+
2365
+ from transformers.models.internvl import modeling_internvl
2366
+ from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
2367
+ from transformers.models.internvl.modeling_internvl import InternVLModel
2368
+ from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
2369
+ from transformers.models.internvl.modeling_internvl import InternVLVisionModel
2370
+ from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
2371
+
2372
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
2373
+ from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
2374
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
2375
+
2376
+ if layer_norm and model is None:
2377
+ modeling_internvl.nn.LayerNorm = LigerLayerNorm
2378
+
2379
+ if cross_entropy:
2380
+ logger.info("Apply liger cross entropy")
2381
+
2382
+ from transformers.loss.loss_utils import nn
2383
+
2384
+ nn.functional.cross_entropy = liger_cross_entropy
2385
+ if fused_linear_cross_entropy:
2386
+ modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
2387
+ if rms_norm:
2388
+ modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
2389
+
2390
+ if model is not None:
2391
+ # The model instance already exists, so we need to additionally patch the
2392
+ # instance variables that reference already-instantiated modules
2393
+ if isinstance(model, InternVLForConditionalGeneration):
2394
+ text_model = model.model.language_model
2395
+ vision_model: InternVLVisionModel = model.model.vision_tower
2396
+ elif isinstance(model, InternVLModel):
2397
+ text_model = model.language_model
2398
+ vision_model: InternVLVisionModel = model.vision_tower
2399
+ else:
2400
+ raise TypeError(
2401
+ f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
2402
+ )
2403
+
2404
+ text_model_name = model.config.text_config.model_type
2405
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
2406
+
2407
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
2408
+ if text_liger_fn:
2409
+ accept_params = inspect.signature(text_liger_fn).parameters
2410
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
2411
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
2412
+
2413
+ if remain_params:
2414
+ logger.warning(
2415
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
2416
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
2417
+ )
2418
+ text_kwargs["model"] = text_model
2419
+ text_liger_fn(**text_kwargs)
2420
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
2421
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
2422
+
2423
+ # Patch vision model RMSNorm layers
2424
+ if rms_norm:
2425
+ for encoder_layer in vision_model.encoder.layer:
2426
+ encoder_layer: InternVLVisionLayer
2427
+ if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
2428
+ _patch_rms_norm_module(encoder_layer.attention.q_norm)
2429
+ if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
2430
+ _patch_rms_norm_module(encoder_layer.attention.k_norm)
2431
+
2432
+ # Patch vision model LayerNorm layers
2433
+ if layer_norm:
2434
+ # Patch layernorm
2435
+ if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
2436
+ _patch_layer_norm_module(vision_model.layernorm)
2437
+
2438
+ # Patch encoder layers
2439
+ for encoder_layer in vision_model.encoder.layer:
2440
+ encoder_layer: InternVLVisionLayer
2441
+ if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
2442
+ _patch_layer_norm_module(encoder_layer.layernorm_before)
2443
+ if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
2444
+ _patch_layer_norm_module(encoder_layer.layernorm_after)
2445
+
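# Editor's note: a sketch, not part of the diff, showing how extra keyword arguments are forwarded
# to the apply_* function of the connected language model (resolved via MODEL_TYPE_TO_APPLY_LIGER_FN).
# The checkpoint path is a placeholder.
from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_internvl

model = InternVLForConditionalGeneration.from_pretrained("path/to/internvl-checkpoint")
# `swiglu` is not a parameter of this helper; it travels through **kwargs to the text model's
# apply function if that function accepts it, otherwise a warning is logged.
apply_liger_kernel_to_internvl(rms_norm=True, layer_norm=True, model=model, swiglu=True)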
2446
+
2447
+ def apply_liger_kernel_to_smolvlm(
2448
+ cross_entropy: bool = False,
2449
+ fused_linear_cross_entropy: bool = True,
2450
+ rms_norm: bool = True,
2451
+ layer_norm: bool = True,
2452
+ model: Optional[PreTrainedModel] = None,
2453
+ **kwargs,
2454
+ ) -> None:
2455
+ """
2456
+ Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
2457
+ Because SmolVLM wraps a separately loaded language model, the `model` instance must be passed so that Liger-Kernel's patch can also be applied to the connected language model.
2458
+ However, if the connected LM is not supported by Liger-Kernel, unexpected side effects may occur.
2459
+ NOTE: SmolVLM is not available in transformers<4.50.0
2460
+
2461
+ Args:
2462
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2463
+ fused_linear_cross_entropy (bool):
2464
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2465
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2466
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2467
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2468
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
2469
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2470
+ loaded. Default is None.
2471
+ """
2472
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2473
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2474
+ )
2475
+
2476
+ from transformers.models.smolvlm import modeling_smolvlm
2477
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
2478
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
2479
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
2480
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
2481
+
2482
+ from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
2483
+
2484
+ # Patch LayerNorm for vision model if model is not provided (pre-initialization)
2485
+ if layer_norm and model is None:
2486
+ modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
2487
+
2488
+ if cross_entropy:
2489
+ logger.info("Apply liger cross entropy")
2490
+
2491
+ from transformers.loss.loss_utils import nn
2492
+
2493
+ nn.functional.cross_entropy = liger_cross_entropy
2494
+ if fused_linear_cross_entropy:
2495
+ if model is not None:
2496
+ model.forward = MethodType(smolvlm_lce_forward, model)
2497
+ else:
2498
+ modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
2499
+ if rms_norm:
2500
+ modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
2501
+
2502
+ if model is not None:
2503
+ # The model instance already exists, so we need to additionally patch the
2504
+ # instance variables that reference already-instantiated modules
2505
+ if isinstance(model, SmolVLMForConditionalGeneration):
2506
+ text_model = model.model.text_model
2507
+ vision_model: SmolVLMVisionTransformer = model.model.vision_model
2508
+ elif isinstance(model, SmolVLMModel):
2509
+ text_model = model.text_model
2510
+ vision_model: SmolVLMVisionTransformer = model.vision_model
2511
+ else:
2512
+ raise TypeError(
2513
+ f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
2514
+ )
2515
+
2516
+ text_model_name = model.config.text_config.model_type
2517
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
2518
+
2519
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
2520
+ if text_liger_fn:
2521
+ accept_params = inspect.signature(text_liger_fn).parameters
2522
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
2523
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
2524
+
2525
+ if remain_params:
2526
+ logger.warning(
2527
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
2528
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
2529
+ )
2530
+ text_kwargs["model"] = text_model
2531
+ text_liger_fn(**text_kwargs)
2532
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
2533
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
2534
+
2535
+ # Patch vision model LayerNorm layers
2536
+ if layer_norm:
2537
+ # Patch post_layernorm
2538
+ _patch_layer_norm_module(vision_model.post_layernorm)
2539
+
2540
+ # Patch encoder layers
2541
+ for encoder_layer in vision_model.encoder.layers:
2542
+ encoder_layer: SmolVLMEncoderLayer
2543
+ _patch_layer_norm_module(encoder_layer.layer_norm1)
2544
+ _patch_layer_norm_module(encoder_layer.layer_norm2)
2545
+
2546
+
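# Editor's note: a usage sketch, not part of the diff; the checkpoint path is a placeholder. With an
# instance, the vision tower's LayerNorm modules are patched individually; calling the helper before
# loading instead patches nn.LayerNorm at the class level.
from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_smolvlm

model = SmolVLMForConditionalGeneration.from_pretrained("path/to/smolvlm-checkpoint")
apply_liger_kernel_to_smolvlm(layer_norm=True, rms_norm=True, model=model)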
2547
+ def apply_liger_kernel_to_falcon_h1(
2548
+ rope: bool = True,
2549
+ cross_entropy: bool = False,
2550
+ fused_linear_cross_entropy: bool = True,
2551
+ rms_norm: bool = True,
2552
+ swiglu: bool = False,
2553
+ model: PreTrainedModel = None,
2554
+ ) -> None:
2555
+ """
2556
+ Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models.
2557
+ Args:
2558
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2559
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2560
+ fused_linear_cross_entropy (bool):
2561
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2562
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2563
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2564
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2565
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False (LigerSwiGLUMLP currently has no effect for Falcon-H1).
2566
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2567
+ loaded. Default is None.
2568
+ """
2569
+
2570
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2571
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2572
+ )
2573
+
2574
+ from transformers.models.falcon_h1 import modeling_falcon_h1
2575
+ from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
2576
+
2577
+ if rope:
2578
+ logger.info("Apply liger rotary pos emb.")
2579
+ modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
2580
+ if rms_norm:
2581
+ logger.info("Apply liger RMSNorm")
2582
+ modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
2583
+ if swiglu:
2584
+ logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
2585
+
2586
+ if cross_entropy:
2587
+ logger.info("Apply liger cross entropy")
2588
+ from transformers.loss.loss_utils import nn
2589
+
2590
+ nn.functional.cross_entropy = liger_cross_entropy
2591
+
2592
+ if fused_linear_cross_entropy:
2593
+ if model is not None:
2594
+ model.forward = MethodType(falcon_h1_lce_forward, model)
2595
+ else:
2596
+ modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
2597
+
2598
+ if model is not None:
2599
+ # The model instance already exists, so we need to additionally patch the
2600
+ # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
2601
+
2602
+ # get the base model from the model instance
2603
+ base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
2604
+
2605
+ if rms_norm:
2606
+ _patch_rms_norm_module(base_model.final_layernorm)
2607
+
2608
+ for decoder_layer in base_model.layers:
2609
+ if swiglu:
2610
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
2611
+ if rms_norm:
2612
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2613
+ _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
2614
+
2615
+
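# Editor's note: a usage sketch, not part of the diff; the checkpoint path is a placeholder.
# `swiglu` is left at its default (False) because LigerSwiGLUMLP currently has no effect for
# Falcon-H1, as the warning above notes.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1

model = AutoModelForCausalLM.from_pretrained("path/to/falcon-h1-checkpoint")
apply_liger_kernel_to_falcon_h1(rope=True, rms_norm=True, fused_linear_cross_entropy=True, model=model)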
2616
+ def apply_liger_kernel_to_qwen3_next(
2617
+ rope: bool = False,
2618
+ cross_entropy: bool = False,
2619
+ fused_linear_cross_entropy: bool = True,
2620
+ rms_norm: bool = True,
2621
+ swiglu: bool = True,
2622
+ model: PreTrainedModel = None,
2623
+ ) -> None:
2624
+ """
2625
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-Next models.
2626
+
2627
+ Args:
2628
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
2629
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2630
+ fused_linear_cross_entropy (bool):
2631
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2632
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2633
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2634
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2635
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
2636
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2637
+ loaded. Default is None.
2638
+ """
2639
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2640
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2641
+ )
2642
+
2643
+ from transformers.models.qwen3_next import modeling_qwen3_next
2644
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
2645
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
2646
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
2647
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
2648
+
2649
+ from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
2650
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
2651
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
2652
+
2653
+ if rope:
2654
+ # It might encounter NaN issues
2655
+ # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
2656
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
2657
+ if rms_norm:
2658
+ modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
2659
+ if cross_entropy:
2660
+ from transformers.loss.loss_utils import nn
2661
+
2662
+ nn.functional.cross_entropy = liger_cross_entropy
2663
+ if fused_linear_cross_entropy:
2664
+ if model is not None:
2665
+ if isinstance(model, Qwen3NextForCausalLM):
2666
+ model.forward = MethodType(qwen3_next_lce_forward, model)
2667
+ else:
2668
+ raise TypeError(
2669
+ f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
2670
+ )
2671
+ else:
2672
+ modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
2673
+ if swiglu:
2674
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
2675
+ modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
2676
+
2677
+ if model is not None:
2678
+ # The model instance already exists, so we need to additionally patch the
2679
+ # instance variables that reference already-instantiated modules
2680
+ if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
2681
+ base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
2682
+ else:
2683
+ raise TypeError(
2684
+ f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
2685
+ )
2686
+
2687
+ if rms_norm:
2688
+ _patch_rms_norm_module(base_model.norm)
2689
+
2690
+ for decoder_layer in base_model.layers:
2691
+ if rms_norm:
2692
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2693
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2694
+
2695
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
2696
+ if swiglu:
2697
+ if isinstance(decoder_layer.mlp, Qwen3NextMLP):
2698
+ _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
2699
+ if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
2700
+ _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
2701
+ experts = getattr(decoder_layer.mlp, "experts", None)
2702
+ if experts is not None:
2703
+ for expert in experts:
2704
+ _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
2705
+
2706
+
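# Editor's note: a sketch, not part of the diff. Fused linear cross entropy can only be bound to a
# Qwen3NextForCausalLM instance; passing a bare Qwen3NextModel with fused_linear_cross_entropy=True
# raises a TypeError. The checkpoint path is a placeholder.
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next

model = Qwen3NextForCausalLM.from_pretrained("path/to/qwen3-next-checkpoint")
apply_liger_kernel_to_qwen3_next(rms_norm=True, swiglu=True, fused_linear_cross_entropy=True, model=model)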
2707
+ def apply_liger_kernel_to_hunyuan_v1_dense(
2708
+ rope: bool = True,
2709
+ cross_entropy: bool = False,
2710
+ fused_linear_cross_entropy: bool = True,
2711
+ rms_norm: bool = True,
2712
+ swiglu: bool = True,
2713
+ model: PreTrainedModel = None,
2714
+ ) -> None:
2715
+ """
2716
+ Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
2717
+ """
2718
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2719
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2720
+ )
2721
+
2722
+ from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
2723
+ from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
2724
+
2725
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
2726
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
2727
+
2728
+ if rope:
2729
+ modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
2730
+
2731
+ if rms_norm:
2732
+ modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
2733
+
2734
+ if cross_entropy:
2735
+ from transformers.loss.loss_utils import nn
2736
+
2737
+ nn.functional.cross_entropy = liger_cross_entropy
2738
+
2739
+ if fused_linear_cross_entropy:
2740
+ if model is not None:
2741
+ model.forward = MethodType(hunyuan_v1_lce_forward, model)
2742
+ else:
2743
+ modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
2744
+
2745
+ if swiglu:
2746
+ modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
2747
+
2748
+ if model is not None:
2749
+ # The model instance already exists, so we need to additionally patch the
2750
+ # instance variables that reference already-instantiated modules
2751
+
2752
+ # get the base model from the model instance
2753
+ base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
2754
+
2755
+ if rms_norm:
2756
+ _patch_rms_norm_module(base_model.norm)
2757
+ for decoder_layer in base_model.layers:
2758
+ if swiglu:
2759
+ _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
2760
+ if rms_norm:
2761
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2762
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2763
+
2764
+
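# Editor's note: a minimal pre-load sketch, not part of the diff; the checkpoint path is a placeholder.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense

apply_liger_kernel_to_hunyuan_v1_dense()  # class-level patches apply to models loaded afterwards
model = AutoModelForCausalLM.from_pretrained("path/to/hunyuan-v1-dense-checkpoint")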
2765
+ def apply_liger_kernel_to_hunyuan_v1_moe(
2766
+ rope: bool = True,
2767
+ cross_entropy: bool = False,
2768
+ fused_linear_cross_entropy: bool = True,
2769
+ rms_norm: bool = True,
2770
+ swiglu: bool = True,
2771
+ model: PreTrainedModel = None,
2772
+ ) -> None:
2773
+ """
2774
+ Apply Liger kernels to replace original implementation in HuggingFace HunYuan v1 MoE models.
2775
+ """
2776
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2777
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2778
+ )
2779
+
2780
+ from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
2781
+ from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
2782
+
2783
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
2784
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
2785
+
2786
+ if rope:
2787
+ modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
2788
+
2789
+ if rms_norm:
2790
+ modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
2791
+
2792
+ if cross_entropy:
2793
+ from transformers.loss.loss_utils import nn
2794
+
2795
+ nn.functional.cross_entropy = liger_cross_entropy
2796
+
2797
+ if fused_linear_cross_entropy:
2798
+ if model is not None:
2799
+ model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
2800
+ else:
2801
+ modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
2802
+
2803
+ if swiglu:
2804
+ modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
2805
+
2806
+ if model is not None:
2807
+ # The model instance already exists, so we need to additionally patch the
2808
+ # instance variables that reference already-instantiated modules
2809
+
2810
+ # get the base model from the model instance
2811
+ base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
2812
+
2813
+ if rms_norm:
2814
+ _patch_rms_norm_module(base_model.norm)
2815
+ for decoder_layer in base_model.layers:
2816
+ if swiglu:
2817
+ for mlp_expert in decoder_layer.mlp.experts:
2818
+ _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
2819
+ if rms_norm:
2820
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
2821
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2822
+
2823
+
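# Editor's note: a sketch for the MoE variant, not part of the diff. With an instance, every expert
# MLP in each decoder layer is patched with LigerHunyuanV1SwiGLUMLP; the path is a placeholder.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_moe

model = AutoModelForCausalLM.from_pretrained("path/to/hunyuan-v1-moe-checkpoint")
apply_liger_kernel_to_hunyuan_v1_moe(swiglu=True, rms_norm=True, model=model)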
2824
+ def apply_liger_kernel_to_exaone4(
2825
+ rope: bool = True,
2826
+ cross_entropy: bool = False,
2827
+ fused_linear_cross_entropy: bool = True,
2828
+ rms_norm: bool = True,
2829
+ swiglu: bool = True,
2830
+ model: PreTrainedModel = None,
2831
+ ) -> None:
2832
+ """
2833
+ Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models.
2834
+
2835
+ Args:
2836
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2837
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2838
+ fused_linear_cross_entropy (bool):
2839
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2840
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2841
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2842
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2843
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
2844
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2845
+ loaded. Default is None.
2846
+ """
2847
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2848
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2849
+ )
2850
+
2851
+ from transformers.models.exaone4 import modeling_exaone4
2852
+ from transformers.models.exaone4.modeling_exaone4 import Exaone4Model
2853
+
2854
+ from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward
2855
+
2856
+ if rope:
2857
+ modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb
2858
+
2859
+ if rms_norm:
2860
+ # EXAONE4 requires in_place=False to avoid gradient issues
2861
+ class Exaone4LigerRMSNorm(LigerRMSNorm):
2862
+ def __init__(self, hidden_size, eps=1e-6, **kwargs):
2863
+ super().__init__(hidden_size, eps, **kwargs)
2864
+ self.in_place = False
2865
+
2866
+ modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm
2867
+
2868
+ if cross_entropy:
2869
+ from transformers.loss.loss_utils import nn
2870
+
2871
+ nn.functional.cross_entropy = liger_cross_entropy
2872
+
2873
+ if fused_linear_cross_entropy:
2874
+ if model is not None:
2875
+ model.forward = MethodType(exaone4_lce_forward, model)
2876
+ else:
2877
+ modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward
2878
+
2879
+ if swiglu:
2880
+ modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP
2881
+
2882
+ if model is not None:
2883
+ # The model instance already exists, so we need to additionally patch the
2884
+ # instance variables that reference already-instantiated modules
2885
+
2886
+ # get the base model from the model instance
2887
+ base_model: Exaone4Model = getattr(model, model.base_model_prefix, model)
2888
+
2889
+ if rms_norm:
2890
+ _patch_rms_norm_module(base_model.norm, in_place=False)
2891
+ for decoder_layer in base_model.layers:
2892
+ if swiglu:
2893
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
2894
+ if rms_norm:
2895
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2896
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2897
+ _patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False)
2898
+ _patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False)
2899
+
2900
+
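# Editor's note: a usage sketch, not part of the diff; the checkpoint path is a placeholder. All
# EXAONE4 RMSNorm patches use in_place=False (see the Exaone4LigerRMSNorm subclass above).
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_exaone4

model = AutoModelForCausalLM.from_pretrained("path/to/exaone4-checkpoint")
apply_liger_kernel_to_exaone4(rope=True, rms_norm=True, swiglu=True, model=model)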
1596
2901
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
1597
2902
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
1598
2903
  "gemma": apply_liger_kernel_to_gemma,
@@ -1600,7 +2905,13 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
1600
2905
  "gemma3_text": apply_liger_kernel_to_gemma3_text,
1601
2906
  "gemma3": apply_liger_kernel_to_gemma3,
1602
2907
  "glm4": apply_liger_kernel_to_glm4,
2908
+ "glm4v": apply_liger_kernel_to_glm4v,
2909
+ "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
2910
+ "gpt_oss": apply_liger_kernel_to_gpt_oss,
2911
+ "internvl": apply_liger_kernel_to_internvl,
1603
2912
  "llama": apply_liger_kernel_to_llama,
2913
+ "llama4_text": apply_liger_kernel_to_llama4,
2914
+ "llama4": apply_liger_kernel_to_llama4,
1604
2915
  "llava": apply_liger_kernel_to_llava,
1605
2916
  "granite": apply_liger_kernel_to_granite,
1606
2917
  "mllama": apply_liger_kernel_to_mllama,
@@ -1608,6 +2919,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
1608
2919
  "mistral": apply_liger_kernel_to_mistral,
1609
2920
  "mixtral": apply_liger_kernel_to_mixtral,
1610
2921
  "olmo2": apply_liger_kernel_to_olmo2,
2922
+ "olmo3": apply_liger_kernel_to_olmo3,
1611
2923
  "qwen2": apply_liger_kernel_to_qwen2,
1612
2924
  "qwen3": apply_liger_kernel_to_qwen3,
1613
2925
  "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
@@ -1615,8 +2927,19 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
1615
2927
  "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
1616
2928
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
1617
2929
  "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
2930
+ "qwen3_next": apply_liger_kernel_to_qwen3_next,
2931
+ "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
2932
+ "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
2933
+ "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
2934
+ "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
2935
+ "smollm3": apply_liger_kernel_to_smollm3,
1618
2936
  "phi3": apply_liger_kernel_to_phi3,
1619
2937
  "paligemma": apply_liger_kernel_to_paligemma,
2938
+ "falcon_h1": apply_liger_kernel_to_falcon_h1,
2939
+ "smolvlm": apply_liger_kernel_to_smolvlm,
2940
+ "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
2941
+ "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
2942
+ "exaone4": apply_liger_kernel_to_exaone4,
1620
2943
  }
1621
2944
 
1622
2945
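# Editor's note: a dispatch sketch, not part of the diff, showing how the table above can pick the
# right patch function from a loaded model's config. The import path assumes this table lives in
# liger_kernel.transformers.monkey_patch as in prior releases; the checkpoint path is a placeholder.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

model = AutoModelForCausalLM.from_pretrained("path/to/any-supported-checkpoint")
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model.config.model_type)
if apply_fn is not None:
    apply_fn(model=model)  # patch the already-instantiated modules in place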