liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (68)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  7. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  8. liger_kernel/chunked_loss/jsd_loss.py +23 -7
  9. liger_kernel/ops/cross_entropy.py +118 -62
  10. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  11. liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
  12. liger_kernel/ops/geglu.py +1 -1
  13. liger_kernel/ops/layer_norm.py +124 -89
  14. liger_kernel/ops/llama4_rope.py +225 -0
  15. liger_kernel/ops/poly_norm.py +386 -0
  16. liger_kernel/ops/rms_norm.py +2 -2
  17. liger_kernel/ops/rope.py +1 -1
  18. liger_kernel/ops/swiglu.py +1 -1
  19. liger_kernel/ops/tiled_mlp.py +136 -0
  20. liger_kernel/transformers/__init__.py +50 -0
  21. liger_kernel/transformers/cross_entropy.py +8 -3
  22. liger_kernel/transformers/experimental/__init__.py +5 -0
  23. liger_kernel/transformers/functional.py +38 -6
  24. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  25. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
  26. liger_kernel/transformers/llama4_rope.py +93 -0
  27. liger_kernel/transformers/model/falcon_h1.py +122 -0
  28. liger_kernel/transformers/model/gemma.py +28 -8
  29. liger_kernel/transformers/model/gemma2.py +31 -8
  30. liger_kernel/transformers/model/gemma3.py +100 -110
  31. liger_kernel/transformers/model/glm4.py +18 -5
  32. liger_kernel/transformers/model/glm4v.py +163 -0
  33. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  34. liger_kernel/transformers/model/internvl.py +157 -0
  35. liger_kernel/transformers/model/llama.py +26 -7
  36. liger_kernel/transformers/model/llama4.py +121 -0
  37. liger_kernel/transformers/model/llava.py +18 -6
  38. liger_kernel/transformers/model/loss_utils.py +34 -3
  39. liger_kernel/transformers/model/mistral.py +17 -10
  40. liger_kernel/transformers/model/mixtral.py +24 -9
  41. liger_kernel/transformers/model/mllama.py +18 -7
  42. liger_kernel/transformers/model/olmo2.py +18 -5
  43. liger_kernel/transformers/model/output_classes.py +147 -0
  44. liger_kernel/transformers/model/paligemma.py +41 -5
  45. liger_kernel/transformers/model/phi3.py +24 -159
  46. liger_kernel/transformers/model/qwen2.py +26 -4
  47. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  48. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  49. liger_kernel/transformers/model/qwen3.py +22 -6
  50. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  51. liger_kernel/transformers/model/qwen3_next.py +146 -0
  52. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  53. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  54. liger_kernel/transformers/model/smollm3.py +199 -0
  55. liger_kernel/transformers/model/smolvlm.py +158 -0
  56. liger_kernel/transformers/monkey_patch.py +1090 -116
  57. liger_kernel/transformers/multi_token_attention.py +1 -1
  58. liger_kernel/transformers/poly_norm.py +42 -0
  59. liger_kernel/transformers/rms_norm.py +7 -0
  60. liger_kernel/transformers/rope.py +43 -0
  61. liger_kernel/transformers/tiled_mlp.py +133 -0
  62. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +26 -24
  63. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  64. liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
  65. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  66. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  67. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +0 -0
  68. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
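The diff below covers liger_kernel/transformers/monkey_patch.py (+1090 -116), the largest change in this release: new apply_liger_kernel_to_* entry points (SmolLM3, Llama4, Qwen3-VL, Qwen3-VL MoE, GLM-4V, GLM-4V MoE, ...) and a switch from overwriting forward on the transformers model class to binding the Liger forward onto the passed instance with types.MethodType, so patching one loaded model no longer affects every other instance of that class. A minimal usage sketch of the patching API, assuming apply_liger_kernel_to_llama is exported from liger_kernel.transformers as in previous releases (the checkpoint id is only a placeholder):

# Sketch only: illustrates the two patching modes shown in the monkey_patch.py diff below.
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Mode 1: patch the transformers classes first, then construct the model.
apply_liger_kernel_to_llama(rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder checkpoint

# Mode 2: patch an already-instantiated model. With this release the forward
# override is bound to this instance via types.MethodType instead of mutating
# LlamaForCausalLM.forward globally; submodules (RMSNorm, MLP) are patched in place.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder checkpoint
apply_liger_kernel_to_llama(model=model)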
@@ -2,7 +2,9 @@ import inspect
  import logging

  from functools import partial
+ from types import MethodType
  from typing import Callable
+ from typing import Optional

  import transformers

@@ -13,6 +15,7 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
  from liger_kernel.transformers.functional import liger_cross_entropy
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
  from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
@@ -25,12 +28,14 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_f
  from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
  from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
- from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
  from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
  from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -76,8 +81,8 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
  _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
  _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
  _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
- module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
- module.original_module.__class__.__name__ = LigerRMSNorm.__name__
+ _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+ _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
  else:
  module.offset = offset
  module.casting_mode = casting_mode
@@ -86,7 +91,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
  module.row_mode = row_mode
  _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
  _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
- module.__class__.__name__ = LigerRMSNorm.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)


  def _patch_layer_norm_module(module, eps=1e-6):
@@ -108,28 +113,28 @@ def _patch_layer_norm_module(module, eps=1e-6):
  module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
  module, "normalized_shape", None
  )
- _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
- _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
- _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
- _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
- module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
- module.original_module.__class__.__name__ = LigerLayerNorm.__name__
+ _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+ _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+ _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+ _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+ _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+ _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
  else:
  module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
  module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
  _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
  _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
- module.__class__.__name__ = LigerLayerNorm.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)


  def _patch_swiglu_module(module, liger_module):
  _bind_method_to_module(module, "forward", liger_module.forward)
- module.__class__.__name__ = liger_module.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)


  def _patch_geglu_module(module):
  _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
- module.__class__.__name__ = LigerGEGLUMLP.__name__
+ _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)


  def apply_liger_kernel_to_granite(
@@ -260,10 +265,16 @@ def apply_liger_kernel_to_llama(

  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+ if model is not None:
+ model.forward = MethodType(llama_lce_forward, model)
+ else:
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(llama_lce_forward_deprecated, model)
+ else:
+ modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -283,6 +294,77 @@ def apply_liger_kernel_to_llama(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_smollm3(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.smollm3 import modeling_smollm3
+ from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+ if rope:
+ modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if rms_norm:
+ modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+ if swiglu:
+ modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+ if cross_entropy:
+ if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ else:
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+ modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(smollm3_lce_forward, model)
+ else:
+ modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+ # get the base model from the model instance
+ base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_llava(
  cross_entropy: bool = False,
  fused_linear_cross_entropy: bool = True,
@@ -318,9 +400,15 @@ def apply_liger_kernel_to_llava(
  modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse("4.52.0"):
- modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+ if model is not None:
+ model.forward = MethodType(llava_lce_forward, model)
+ else:
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
  elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
- modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(llava_lce_forward_deprecated, model)
+ else:
+ modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
  else: # if version < 4.49.0
  logger.warning(
  "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
@@ -363,6 +451,97 @@ def apply_liger_kernel_to_llava(
  logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


+ def apply_liger_kernel_to_llama4(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ layer_norm: bool = True,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.llama4 import modeling_llama4
+ from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+ from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+ from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+ from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+ from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+ if rope:
+ from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+ apply_liger_llama4_rope_full(modeling_llama4)
+ if rms_norm:
+ modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+ if swiglu:
+ modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+ if cross_entropy:
+ modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+ if fused_linear_cross_entropy:
+ modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, Llama4ForConditionalGeneration):
+ language_model: Llama4ForCausalLM = model.language_model
+ vision_model: Llama4VisionModel = model.vision_model
+ text_model: Llama4TextModel = language_model.model
+ elif isinstance(model, Llama4ForCausalLM):
+ text_model = model.model
+ vision_model = None
+ elif isinstance(model, Llama4TextModel):
+ text_model = model
+ vision_model = None
+
+ else:
+ raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+ if text_model:
+ if rms_norm:
+ _patch_rms_norm_module(text_model.norm)
+ for decoder_layer in text_model.layers:
+ if swiglu:
+ if decoder_layer.is_moe_layer:
+ _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+ else:
+ _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+ if vision_model:
+ _patch_layer_norm_module(vision_model.layernorm_pre)
+ _patch_layer_norm_module(vision_model.layernorm_post)
+
+ for layer in vision_model.model.layers:
+ if layer_norm:
+ _patch_layer_norm_module(layer.input_layernorm)
+ _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_mllama(
  rope: bool = True,
  cross_entropy: bool = False,
@@ -404,7 +583,7 @@ def apply_liger_kernel_to_mllama(

  if rope:
  modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
- if layer_norm:
+ if layer_norm and model is None:
  modeling_mllama.nn.LayerNorm = LigerLayerNorm
  if rms_norm:
  modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -420,10 +599,16 @@ def apply_liger_kernel_to_mllama(
  modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+ if model is not None:
+ model.forward = MethodType(mllama_lce_forward, model)
+ else:
+ modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(mllama_lce_forward_deprecated, model)
+ else:
+ modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -432,7 +617,10 @@ def apply_liger_kernel_to_mllama(
  if isinstance(model, MllamaForConditionalGeneration):
  language_model: MllamaForCausalLM = model.language_model
  vision_model: MllamaVisionModel = model.vision_model
- text_model: MllamaTextModel = language_model
+ if isinstance(language_model, MllamaForCausalLM):
+ text_model: MllamaTextModel = language_model.model
+ else:
+ text_model = language_model
  elif isinstance(model, MllamaForCausalLM):
  text_model = model.model
  vision_model = None
@@ -506,7 +694,17 @@ def apply_liger_kernel_to_mistral(
  if cross_entropy:
  modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
- modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+ if transformer_version >= version.parse("4.49.0"):
+ if model is not None:
+ model.forward = MethodType(mistral_lce_forward, model)
+ else:
+ modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+ else:
+ logger.warning(
+ "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+ )
+ logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
  if swiglu:
  modeling_mistral.MistralMLP = LigerSwiGLUMLP

@@ -574,10 +772,16 @@ def apply_liger_kernel_to_mixtral(

  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+ if model is not None:
+ model.forward = MethodType(mixtral_lce_forward, model)
+ else:
+ modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+ else:
+ modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
  if swiglu:
  modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

@@ -651,10 +855,16 @@ def apply_liger_kernel_to_gemma(
  modeling_gemma.GemmaMLP = LigerGEGLUMLP
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+ if model is not None:
+ model.forward = MethodType(gemma_lce_forward, model)
+ else:
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(gemma_lce_forward_deprecated, model)
+ else:
+ modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -726,10 +936,16 @@ def apply_liger_kernel_to_gemma2(
  modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+ if model is not None:
+ model.forward = MethodType(gemma2_lce_forward, model)
+ else:
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
  else:
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+ if model is not None:
+ model.forward = MethodType(gemma2_lce_forward_deprected, model)
+ else:
+ modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
  if geglu:
  modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -808,7 +1024,10 @@ def apply_liger_kernel_to_gemma3_text(
  nn.functional.cross_entropy = liger_cross_entropy

  if fused_linear_cross_entropy:
- modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+ if model is not None:
+ model.forward = MethodType(causal_forward, model)
+ else:
+ modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -878,7 +1097,7 @@ def apply_liger_kernel_to_gemma3(
  _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
  )

- if layer_norm:
+ if layer_norm and model is None:
  modeling_siglip.nn.LayerNorm = LigerLayerNorm

  apply_liger_kernel_to_gemma3_text(
@@ -889,7 +1108,10 @@ def apply_liger_kernel_to_gemma3(
  modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss

  if fused_linear_cross_entropy:
- modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+ if model is not None:
+ model.forward = MethodType(multimodal_forward, model)
+ else:
+ modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -957,7 +1179,9 @@ def apply_liger_kernel_to_paligemma(
  # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']

  from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+ from transformers.models.gemma.modeling_gemma import GemmaModel
  from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
  from transformers.models.paligemma import modeling_paligemma
  from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
  from transformers.models.siglip import modeling_siglip
@@ -968,7 +1192,7 @@ def apply_liger_kernel_to_paligemma(
  from liger_kernel.transformers.model.paligemma import lce_forward_deprecated

  # The vision_tower is a SiglipVisionModel
- if layer_norm:
+ if layer_norm and model is None:
  modeling_siglip.nn.LayerNorm = LigerLayerNorm

  # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -986,10 +1210,16 @@ def apply_liger_kernel_to_paligemma(
  modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+ if model is not None:
+ model.forward = MethodType(lce_forward, model)
+ else:
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(lce_forward_deprecated, model)
+ else:
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
@@ -1010,7 +1240,7 @@ def apply_liger_kernel_to_paligemma(

  language_model = model.language_model

- if isinstance(language_model, GemmaForCausalLM):
+ if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
  apply_liger_kernel_to_gemma(
  rope=rope,
  cross_entropy=False,
@@ -1020,7 +1250,7 @@ def apply_liger_kernel_to_paligemma(
  model=language_model,
  )

- elif isinstance(language_model, Gemma2ForCausalLM):
+ elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
  apply_liger_kernel_to_gemma2(
  rope=rope,
  cross_entropy=False,
@@ -1081,10 +1311,16 @@ def apply_liger_kernel_to_qwen2(

  if fused_linear_cross_entropy:
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen2_lce_forward, model)
+ else:
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
  else: # if version < 4.46.1
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+ else:
+ modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

  if swiglu:
  modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1105,7 +1341,6 @@ def apply_liger_kernel_to_qwen2(
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
- print("Applied Liger kernels to Qwen2")


  def apply_liger_kernel_to_qwen3(
@@ -1140,7 +1375,10 @@ def apply_liger_kernel_to_qwen3(
  nn.functional.cross_entropy = liger_cross_entropy

  if fused_linear_cross_entropy:
- modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen3_lce_forward, model)
+ else:
+ modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward

  if swiglu:
  modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1195,7 +1433,10 @@ def apply_liger_kernel_to_qwen3_moe(
  nn.functional.cross_entropy = liger_cross_entropy

  if fused_linear_cross_entropy:
- modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen3_lce_forward, model)
+ else:
+ modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward

  if swiglu:
  modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
@@ -1264,12 +1505,15 @@ def apply_liger_kernel_to_qwen2_vl(
  if rms_norm:
  # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
  modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
- if layer_norm:
+ if layer_norm and model is None:
  modeling_qwen2_vl.LayerNorm = LigerLayerNorm
  if cross_entropy:
  modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
- modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen2_vl_lce_forward, model)
+ else:
+ modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
  if swiglu:
  modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

@@ -1357,7 +1601,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
  if cross_entropy:
  modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
  if fused_linear_cross_entropy:
- modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+ if model is not None:
+ model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+ else:
+ modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
  if swiglu:
  modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP

@@ -1398,141 +1645,160 @@ def apply_liger_kernel_to_qwen2_5_vl(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


- def apply_liger_kernel_to_phi3(
+ def apply_liger_kernel_to_qwen3_vl(
  rope: bool = True,
  cross_entropy: bool = False,
  fused_linear_cross_entropy: bool = True,
  rms_norm: bool = True,
- swiglu: bool = True,
+ swiglu: bool = False,
  model: PreTrainedModel = None,
  ) -> None:
  """
- Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.

  Args:
- rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
  fused_linear_cross_entropy (bool):
  Whether to apply Liger's fused linear cross entropy loss. Default is True.
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
- swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
  loaded. Default is None.
  """
+
  assert not (cross_entropy and fused_linear_cross_entropy), (
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
  )

- from transformers.models.phi3 import modeling_phi3
- from transformers.models.phi3.modeling_phi3 import Phi3Model
+ from transformers.models.qwen3_vl import modeling_qwen3_vl
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+ from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward

  if rope:
- modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
+ modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+ modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
  if rms_norm:
- modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
- if swiglu:
- modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+ modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
  if cross_entropy:
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- from transformers.loss.loss_utils import nn
+ from transformers.loss.loss_utils import nn

- nn.functional.cross_entropy = liger_cross_entropy
- else:
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
- if fused_linear_cross_entropy:
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
- else: # if version < 4.46.1
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+ nn.functional.cross_entropy = liger_cross_entropy

- if model is not None:
- # The model instance already exists, so we need to additionally patch the
- # instance variables that reference already-instantiated modules
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(qwen3_vl_lce_forward, model)
+ else:
+ modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward

- # get the base model from the model instance
- base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
+ if model is not None and rms_norm:
+ if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
+ text_model: Qwen3VLTextModel = model.language_model
+ elif isinstance(model, Qwen3VLTextModel):
+ text_model = model
+ else:
+ raise TypeError(
+ f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+ )

- if rms_norm:
- _patch_rms_norm_module(base_model.norm)
+ _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")

- for decoder_layer in base_model.layers:
- if swiglu:
- _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
- if rms_norm:
- _patch_rms_norm_module(decoder_layer.input_layernorm)
- _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+ if text_model is not None:
+ _patch_qwen3_vl_rms_norm(text_model.norm)
+ for decoder_layer in text_model.layers:
+ _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+ _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+ self_attn = getattr(decoder_layer, "self_attn", None)
+ if self_attn is not None:
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+ _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+ _patch_qwen3_vl_rms_norm(self_attn.k_norm)


- def apply_liger_kernel_to_olmo2(
+ def apply_liger_kernel_to_qwen3_vl_moe(
  rope: bool = True,
  cross_entropy: bool = False,
  fused_linear_cross_entropy: bool = True,
  rms_norm: bool = True,
- swiglu: bool = True,
+ swiglu: bool = False,
  model: PreTrainedModel = None,
  ) -> None:
  """
- Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.

  Args:
- rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
  fused_linear_cross_entropy (bool):
- Whether to apply Liger's fused linear cross entropy loss. Default is True.
- `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
- If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ Whether to apply Liger's fused linear cross entropy loss. Default is False.
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
- swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
  loaded. Default is None.
  """
+
  assert not (cross_entropy and fused_linear_cross_entropy), (
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
  )

- from transformers.models.olmo2 import modeling_olmo2
- from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
+ from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel

- from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
- from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+ from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward

  if rope:
- modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
  if rms_norm:
- modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
- if swiglu:
- modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
+ modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
  if cross_entropy:
  from transformers.loss.loss_utils import nn

  nn.functional.cross_entropy = liger_cross_entropy
- if fused_linear_cross_entropy:
- modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward

- if model is not None:
- # The model instance already exists, so we need to additionally patch the
- # instance variables that reference already-instantiated modules
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+ else:
+ modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward

- # get the base model from the model instance
- base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
+ if model is not None and rms_norm:
+ if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
+ text_model: Qwen3VLMoeTextModel = model.language_model
+ elif isinstance(model, Qwen3VLMoeTextModel):
+ text_model = model
+ else:
+ raise TypeError(
+ f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+ )

- if rms_norm:
- _patch_rms_norm_module(base_model.norm)
+ _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")

- for decoder_layer in base_model.layers:
- if swiglu:
- _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
- if rms_norm:
- _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
- _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+ if text_model is not None:
+ _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+ for decoder_layer in text_model.layers:
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+ self_attn = getattr(decoder_layer, "self_attn", None)
+ if self_attn is not None:
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+ _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+ _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)


- def apply_liger_kernel_to_glm4(
- rope: bool = False,
+ def apply_liger_kernel_to_phi3(
+ rope: bool = True,
  cross_entropy: bool = False,
  fused_linear_cross_entropy: bool = True,
  rms_norm: bool = True,
@@ -1540,10 +1806,141 @@ def apply_liger_kernel_to_glm4(
  model: PreTrainedModel = None,
  ) -> None:
  """
- Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
+ Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.

  Args:
- rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.phi3 import modeling_phi3
+ from transformers.models.phi3.modeling_phi3 import Phi3Model
+
+ if rope:
+ modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
+ if rms_norm:
+ modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
+ if swiglu:
+ modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(phi3_lce_forward, model)
+ else:
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_olmo2(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.olmo2 import modeling_olmo2
+ from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
+
+ from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+ if rope:
+ modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if rms_norm:
+ modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
+ if swiglu:
+ modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(olmo2_lce_forward, model)
+ else:
+ modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
+ def apply_liger_kernel_to_glm4(
+ rope: bool = False,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
  fused_linear_cross_entropy (bool):
  Whether to apply Liger's fused linear cross entropy loss. Default is True.
@@ -1575,7 +1972,10 @@ def apply_liger_kernel_to_glm4(
1575
1972
 
1576
1973
  nn.functional.cross_entropy = liger_cross_entropy
1577
1974
  if fused_linear_cross_entropy:
1578
- modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
1975
+ if model is not None:
1976
+ model.forward = MethodType(glm4_lce_forward, model)
1977
+ else:
1978
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
1579
1979
 
1580
1980
  if model is not None:
1581
1981
  # The model instance already exists, so we need to additionally patch the
@@ -1597,6 +1997,567 @@ def apply_liger_kernel_to_glm4(
1597
1997
  _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
1598
1998
 
1599
1999
 
2000
+ def apply_liger_kernel_to_glm4v(
2001
+ rope: bool = False,
2002
+ cross_entropy: bool = False,
2003
+ fused_linear_cross_entropy: bool = True,
2004
+ rms_norm: bool = True,
2005
+ swiglu: bool = True,
2006
+ model: PreTrainedModel = None,
2007
+ ) -> None:
2008
+ """
2009
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
2010
+
2011
+ Args:
2012
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v import modeling_glm4v
+    from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
+    from transformers.models.glm4v.modeling_glm4v import Glm4vModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
+
+    from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_lce_forward, model)
+        else:
+            modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
+            # Note: the language_model and visual properties can be accessed through the conditional
+            # generation class for backward compatibility; this may change in future transformers releases.
+            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
+            text_model: Glm4vTextModel = model.language_model
+            vision_model: Glm4vVisionModel = model.visual
+        elif isinstance(model, Glm4vTextModel):
+            text_model: Glm4vTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching the vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4.1v model type. `model` must be `Glm4vForConditionalGeneration`, `Glm4vModel` or `Glm4vTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
+
+
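For reviewers, a minimal usage sketch for the function above (not part of the diff). It assumes `apply_liger_kernel_to_glm4v` is re-exported from `liger_kernel.transformers` like the existing `apply_liger_kernel_to_*` helpers; the checkpoint path is a placeholder.

```python
# Hypothetical example: patch an already-loaded GLM-4V model in place.
from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_glm4v  # assumed re-export

model = Glm4vForConditionalGeneration.from_pretrained("path/to/glm4v-checkpoint")  # placeholder
apply_liger_kernel_to_glm4v(rms_norm=True, swiglu=True, model=model)
```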
+def apply_liger_kernel_to_glm4v_moe(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace GLM4v_moe models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v_moe import modeling_glm4v_moe
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
+
+    from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
+        modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_moe_lce_forward, model)
+        else:
+            modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        Glm4vMoeTextMoE = getattr(modeling_glm4v_moe, "Glm4vMoeTextMoE", None)
+        if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
+            # Note: the language_model and visual properties can be accessed through the conditional
+            # generation class for backward compatibility; this may change in future transformers releases.
+            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
+            text_model: Glm4vMoeTextModel = model.language_model
+            vision_model: Glm4vMoeVisionModel = model.visual
+        elif isinstance(model, Glm4vMoeTextModel):
+            text_model: Glm4vMoeTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching the vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(vision_model.post_conv_layernorm)
+                _patch_rms_norm_module(vision_model.post_layernorm)
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    if Glm4vMoeTextMoE is not None and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
+                        experts = getattr(decoder_layer.mlp, "experts", None)
+                        if experts is not None:
+                            for expert in experts:
+                                _patch_swiglu_module(expert, LigerSwiGLUMLP)
+                        if decoder_layer.mlp.shared_experts is not None:
+                            _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
+                    else:
+                        _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
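A sketch of the pre-initialization path for the MoE variant (not part of the diff). Calling the helper before `from_pretrained` relies on the module-level class swaps above; the `liger_kernel.transformers` re-export and the checkpoint path are assumptions.

```python
# Hypothetical example: apply the patches before the model is built, so the swapped-in
# Liger classes (e.g. LigerRMSNormForGlm4) are used during construction.
from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe  # assumed re-export

apply_liger_kernel_to_glm4v_moe(rms_norm=True)  # model=None -> patch classes, not instances

from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration

model = Glm4vMoeForConditionalGeneration.from_pretrained("path/to/glm4v-moe-checkpoint")  # placeholder
```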
+def apply_liger_kernel_to_internvl(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    model: Optional[PreTrainedModel] = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace InternVL models.
+    Due to the characteristics of InternVL, the loaded model instance must be passed so that Liger-Kernel's patches can also be applied to the language model connected to InternVL.
+    However, if a language model not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
+    NOTE: InternVL is not available in transformers<4.52.1
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+    import torch.nn as torch_nn
+
+    from transformers.models.internvl import modeling_internvl
+    from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+    from transformers.models.internvl.modeling_internvl import InternVLModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+    from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
+
+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
+    from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+    if layer_norm and model is None:
+        modeling_internvl.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
+    if rms_norm:
+        modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
+            # NOTE: the language_model and vision_tower properties can be accessed through the conditional generation class.
+            text_model = model.language_model
+            vision_model: InternVLVisionModel = model.vision_tower
+        else:
+            raise TypeError(
+                f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration` or `InternVLModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}: {list(remain_params)}. "
+                    f"Only the supported parameters {list(text_kwargs.keys())} will be applied.\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = text_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        # Patch vision model RMSNorm layers
+        if rms_norm:
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.q_norm)
+                if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.k_norm)
+
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch layernorm
+            if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+                _patch_layer_norm_module(vision_model.layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_before)
+                if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_after)
+
+
+def apply_liger_kernel_to_smolvlm(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    model: Optional[PreTrainedModel] = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace SmolVLM models.
+    Due to the characteristics of SmolVLM, the loaded model instance must be passed so that Liger-Kernel's patches can also be applied to the language model connected to SmolVLM.
+    However, if a language model not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
+    NOTE: SmolVLM is not available in transformers<4.50.0
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smolvlm import modeling_smolvlm
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
+    from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+
+    from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
+
+    # Patch LayerNorm for vision model if model is not provided (pre-initialization)
+    if layer_norm and model is None:
+        modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smolvlm_lce_forward, model)
+        else:
+            modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
+    if rms_norm:
+        modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, SmolVLMForConditionalGeneration):
+            text_model = model.model.text_model
+            vision_model: SmolVLMVisionTransformer = model.model.vision_model
+        elif isinstance(model, SmolVLMModel):
+            text_model = model.text_model
+            vision_model: SmolVLMVisionTransformer = model.vision_model
+        else:
+            raise TypeError(
+                f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration` or `SmolVLMModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}: {list(remain_params)}. "
+                    f"Only the supported parameters {list(text_kwargs.keys())} will be applied.\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = text_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch post_layernorm
+            _patch_layer_norm_module(vision_model.post_layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layers:
+                encoder_layer: SmolVLMEncoderLayer
+                _patch_layer_norm_module(encoder_layer.layer_norm1)
+                _patch_layer_norm_module(encoder_layer.layer_norm2)
+
+
+def apply_liger_kernel_to_falcon_h1(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace Falcon-H1 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.falcon_h1 import modeling_falcon_h1
+    from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
+
+    if rope:
+        logger.info("Apply liger rotary pos emb.")
+        modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        logger.info("Apply liger RMSNorm")
+        modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
+    if swiglu:
+        logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
+
+    if cross_entropy:
+        logger.info("Apply liger cross entropy")
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(falcon_h1_lce_forward, model)
+        else:
+            modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. FalconH1RMSNorm)
+
+        # get the base model from the model instance
+        base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.final_layernorm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
+
+
+def apply_liger_kernel_to_qwen3_next(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace Qwen3-Next models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_next import modeling_qwen3_next
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+    from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        # It might encounter NaN issues:
+        # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
+    if rms_norm:
+        modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            if isinstance(model, Qwen3NextForCausalLM):
+                model.forward = MethodType(qwen3_next_lce_forward, model)
+            else:
+                raise TypeError(
+                    f"fused_linear_cross_entropy is only applicable to Qwen3NextForCausalLM. Got: {type(model)}"
+                )
+        else:
+            modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
+    if swiglu:
+        # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+        modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+            base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
+        else:
+            raise TypeError(
+                f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM` or `Qwen3NextModel`. Got: {type(model)}"
+            )
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+            # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+            if swiglu:
+                if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+                    _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+                if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+                    _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+                    experts = getattr(decoder_layer.mlp, "experts", None)
+                    if experts is not None:
+                        for expert in experts:
+                            _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
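A short usage sketch for Qwen3-Next (not part of the diff). The function patches both dense `Qwen3NextMLP` layers and the experts of `Qwen3NextSparseMoeBlock`; the re-export and checkpoint path are assumptions.

```python
# Hypothetical example: patch a loaded Qwen3-Next causal LM, covering dense and MoE layers.
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next  # assumed re-export

model = Qwen3NextForCausalLM.from_pretrained("path/to/qwen3-next-checkpoint")  # placeholder
apply_liger_kernel_to_qwen3_next(rms_norm=True, swiglu=True, model=model)
```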
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
@@ -1604,7 +2565,12 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma3_text": apply_liger_kernel_to_gemma3_text,
     "gemma3": apply_liger_kernel_to_gemma3,
     "glm4": apply_liger_kernel_to_glm4,
+    "glm4v": apply_liger_kernel_to_glm4v,
+    "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+    "internvl": apply_liger_kernel_to_internvl,
     "llama": apply_liger_kernel_to_llama,
+    "llama4_text": apply_liger_kernel_to_llama4,
+    "llama4": apply_liger_kernel_to_llama4,
     "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
@@ -1619,8 +2585,16 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen3_next": apply_liger_kernel_to_qwen3_next,
+    "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+    "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
+    "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
+    "falcon_h1": apply_liger_kernel_to_falcon_h1,
+    "smolvlm": apply_liger_kernel_to_smolvlm,
 }

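For context, a hypothetical sketch (not part of the diff) of how the registry above can be used to dispatch on a loaded model's `config.model_type`; the import path reflects where this mapping appears to live, and the helper name is made up for illustration.

```python
# Hypothetical helper: look up and apply the registered Liger patch for a model instance.
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN  # assumed location


def apply_liger_by_model_type(model, **kwargs):
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model.config.model_type)
    if apply_fn is None:
        raise ValueError(f"No Liger kernel patch registered for {model.config.model_type!r}")
    apply_fn(model=model, **kwargs)
```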