liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  12. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  13. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  14. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  15. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  16. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  17. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  18. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  19. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  20. liger_kernel/ops/backends/registry.py +61 -0
  21. liger_kernel/ops/cross_entropy.py +71 -11
  22. liger_kernel/ops/dyt.py +5 -2
  23. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  24. liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
  25. liger_kernel/ops/geglu.py +5 -3
  26. liger_kernel/ops/group_norm.py +12 -8
  27. liger_kernel/ops/grpo_loss.py +3 -1
  28. liger_kernel/ops/kl_div.py +8 -11
  29. liger_kernel/ops/layer_norm.py +89 -69
  30. liger_kernel/ops/poly_norm.py +19 -21
  31. liger_kernel/ops/rms_norm.py +149 -71
  32. liger_kernel/ops/tiled_mlp.py +136 -0
  33. liger_kernel/ops/utils.py +25 -0
  34. liger_kernel/transformers/__init__.py +25 -0
  35. liger_kernel/transformers/auto_model.py +21 -0
  36. liger_kernel/transformers/cross_entropy.py +9 -4
  37. liger_kernel/transformers/dyt.py +1 -1
  38. liger_kernel/transformers/experimental/embedding.py +1 -1
  39. liger_kernel/transformers/functional.py +44 -26
  40. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  41. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  42. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  43. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  44. liger_kernel/transformers/geglu.py +1 -1
  45. liger_kernel/transformers/group_norm.py +1 -1
  46. liger_kernel/transformers/grpo_loss.py +57 -2
  47. liger_kernel/transformers/jsd.py +1 -1
  48. liger_kernel/transformers/kl_div.py +1 -1
  49. liger_kernel/transformers/layer_norm.py +1 -1
  50. liger_kernel/transformers/llama4_rope.py +1 -1
  51. liger_kernel/transformers/model/exaone4.py +136 -0
  52. liger_kernel/transformers/model/falcon_h1.py +19 -5
  53. liger_kernel/transformers/model/gemma.py +17 -6
  54. liger_kernel/transformers/model/gemma2.py +17 -8
  55. liger_kernel/transformers/model/gemma3.py +35 -16
  56. liger_kernel/transformers/model/glm4.py +16 -4
  57. liger_kernel/transformers/model/glm4v.py +16 -4
  58. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  59. liger_kernel/transformers/model/gpt_oss.py +211 -0
  60. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  61. liger_kernel/transformers/model/internvl.py +12 -5
  62. liger_kernel/transformers/model/llama.py +14 -5
  63. liger_kernel/transformers/model/llama4.py +16 -4
  64. liger_kernel/transformers/model/llava.py +12 -4
  65. liger_kernel/transformers/model/loss_utils.py +37 -3
  66. liger_kernel/transformers/model/mistral.py +15 -6
  67. liger_kernel/transformers/model/mixtral.py +16 -7
  68. liger_kernel/transformers/model/mllama.py +12 -4
  69. liger_kernel/transformers/model/olmo2.py +16 -4
  70. liger_kernel/transformers/model/olmo3.py +142 -0
  71. liger_kernel/transformers/model/output_classes.py +147 -0
  72. liger_kernel/transformers/model/paligemma.py +23 -5
  73. liger_kernel/transformers/model/phi3.py +14 -7
  74. liger_kernel/transformers/model/qwen2.py +16 -3
  75. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  76. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  77. liger_kernel/transformers/model/qwen3.py +20 -5
  78. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  79. liger_kernel/transformers/model/qwen3_next.py +17 -5
  80. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  81. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  82. liger_kernel/transformers/model/smollm3.py +15 -6
  83. liger_kernel/transformers/monkey_patch.py +584 -49
  84. liger_kernel/transformers/multi_token_attention.py +1 -1
  85. liger_kernel/transformers/poly_norm.py +1 -1
  86. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  87. liger_kernel/transformers/rms_norm.py +8 -3
  88. liger_kernel/transformers/rope.py +45 -1
  89. liger_kernel/transformers/softmax.py +1 -1
  90. liger_kernel/transformers/sparsemax.py +1 -1
  91. liger_kernel/transformers/swiglu.py +18 -1
  92. liger_kernel/transformers/tiled_mlp.py +125 -0
  93. liger_kernel/transformers/tvd.py +1 -1
  94. liger_kernel/utils.py +54 -0
  95. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
  96. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  97. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  98. liger_kernel-0.6.3.dist-info/RECORD +0 -111
  99. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  100. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  101. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
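The hunks reproduced below appear to come from liger_kernel/transformers/monkey_patch.py (item 83 above, +584 -49). They add patch entry points for GPT-OSS, Olmo3, Qwen3-VL (dense and MoE), Hunyuan v1, and EXAONE4, and move several multimodal patchers to the nested model.model.language_model / model.model.visual layout used by newer transformers releases. A minimal usage sketch, not taken from the diff: it assumes transformers >= 4.55.0 for GPT-OSS and imports the new entry point from the monkey_patch module where it is defined below; the checkpoint id is illustrative.

    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss

    # Patch the GPT-OSS modeling classes before constructing the model so that
    # RoPE, RMSNorm, and the fused linear cross entropy forward are replaced.
    apply_liger_kernel_to_gpt_oss(rope=True, rms_norm=True, fused_linear_cross_entropy=True)

    model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")  # illustrative checkpoint id

    # Alternatively, pass an already-loaded instance so its existing modules are patched in place:
    # apply_liger_kernel_to_gpt_oss(model=model)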
@@ -20,6 +20,7 @@ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forwa
  from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
  from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+ from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
  from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -34,6 +35,7 @@ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_f
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -428,7 +430,7 @@ def apply_liger_kernel_to_llava(
  f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
  f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
  )
- text_kwargs["model"] = model.language_model
+ text_kwargs["model"] = model.model.language_model
  text_liger_fn(**text_kwargs)
  elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
  logger.warning(f"{text_model_name} is not supported by Liger kernel.")
@@ -443,7 +445,7 @@ def apply_liger_kernel_to_llava(
  f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
  f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
  )
- vision_kwargs["model"] = model.vision_tower
+ vision_kwargs["model"] = model.model.vision_tower
  vision_liger_fn(**vision_kwargs)
  elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
  logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
@@ -613,8 +615,8 @@ def apply_liger_kernel_to_mllama(
  # instance variables that reference already-instantiated modules

  if isinstance(model, MllamaForConditionalGeneration):
- language_model: MllamaForCausalLM = model.language_model
- vision_model: MllamaVisionModel = model.vision_model
+ language_model: MllamaForCausalLM = model.model.language_model
+ vision_model: MllamaVisionModel = model.model.vision_model
  if isinstance(language_model, MllamaForCausalLM):
  text_model: MllamaTextModel = language_model.model
  else:
@@ -1116,8 +1118,8 @@ def apply_liger_kernel_to_gemma3(
  # instance variables that reference already-instantiated modules

  if isinstance(model, Gemma3ForConditionalGeneration):
- if isinstance(model.vision_tower, SiglipVisionModel):
- vision_tower = model.vision_tower
+ if isinstance(model.model.vision_tower, SiglipVisionModel):
+ vision_tower = model.model.vision_tower

  _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -1130,7 +1132,7 @@ def apply_liger_kernel_to_gemma3(
  raise TypeError("The vision tower must be SiglipVisionModel")

  if rms_norm:
- _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+ _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)

  apply_liger_kernel_to_gemma3_text(
  rope=rope,
@@ -1138,7 +1140,7 @@ def apply_liger_kernel_to_gemma3(
  fused_linear_cross_entropy=False,
  rms_norm=rms_norm,
  geglu=geglu,
- model=model.language_model,
+ model=model.model.language_model,
  )

  else:
@@ -1226,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
  if not isinstance(model, PaliGemmaForConditionalGeneration):
  raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

- vision_tower: SiglipVisionModel = model.vision_tower
+ vision_tower: SiglipVisionModel = model.model.vision_tower

  _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -1236,7 +1238,7 @@ def apply_liger_kernel_to_paligemma(
  _patch_layer_norm_module(layer.layer_norm1)
  _patch_layer_norm_module(layer.layer_norm2)

- language_model = model.language_model
+ language_model = model.model.language_model

  if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
  apply_liger_kernel_to_gemma(
@@ -1457,6 +1459,79 @@ def apply_liger_kernel_to_qwen3_moe(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_gpt_oss(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+ NOTE: GPT-OSS is supported in transformers >= 4.55.0
+ NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+ implementation with clamping and MXFP4 quantization.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ if version.parse(transformers.__version__) < version.parse("4.55.0"):
+ logger.warning("GPT-OSS support requires transformers >= 4.55.0")
+ return
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.gpt_oss import modeling_gpt_oss
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
+
+ if rope:
+ modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(gpt_oss_lce_forward, model)
+ else:
+ modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
+
+ # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
+ # with clamping (swiglu_limit=7.0) and MXFP4 quantization
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_qwen2_vl(
  rope: bool = True,
  cross_entropy: bool = False,
@@ -1518,11 +1593,10 @@ def apply_liger_kernel_to_qwen2_vl(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
-
- if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
- # Note: language_model and visual properties can be accessed throught conditional class for BC.
- # Not sure if it is subject to changes in the future.
- # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+ if isinstance(model, Qwen2VLForConditionalGeneration):
+ text_model: Qwen2VLTextModel = model.model.language_model
+ vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
+ elif isinstance(model, Qwen2VLModel):
  text_model: Qwen2VLTextModel = model.language_model
  vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
  elif isinstance(model, Qwen2VLTextModel):
@@ -1609,11 +1683,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
-
- if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
- # Note: language_model and visual properties can be accessed throught conditional class for BC.
- # Not sure if it is subject to changes in the future.
- # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+ if isinstance(model, Qwen2_5_VLForConditionalGeneration):
+ text_model: Qwen2_5_VLTextModel = model.model.language_model
+ vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
+ elif isinstance(model, Qwen2_5_VLModel):
  text_model: Qwen2_5_VLTextModel = model.language_model
  vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
  elif isinstance(model, Qwen2_5_VLTextModel):
@@ -1627,7 +1700,7 @@ def apply_liger_kernel_to_qwen2_5_vl(

  if vision_model is not None:
  # Patch Qwen2_5_VisionTransformerPretrainedModel
- for vision_block in model.visual.blocks:
+ for vision_block in vision_model.blocks:
  if rms_norm:
  _patch_rms_norm_module(vision_block.norm1)
  _patch_rms_norm_module(vision_block.norm2)
@@ -1643,6 +1716,162 @@ def apply_liger_kernel_to_qwen2_5_vl(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_qwen3_vl(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3_vl import modeling_qwen3_vl
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+ from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+
+ if rope:
+ modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
+ modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+ if rms_norm:
+ modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(qwen3_vl_lce_forward, model)
+ else:
+ modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+ if model is not None and rms_norm:
+ if isinstance(model, Qwen3VLForConditionalGeneration):
+ text_model: Qwen3VLTextModel = model.model.language_model
+ elif isinstance(model, Qwen3VLModel):
+ text_model: Qwen3VLTextModel = model.language_model
+ elif isinstance(model, Qwen3VLTextModel):
+ text_model = model
+ else:
+ raise TypeError(
+ f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+ )
+
+ _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+ if text_model is not None:
+ _patch_qwen3_vl_rms_norm(text_model.norm)
+ for decoder_layer in text_model.layers:
+ _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+ _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+ self_attn = getattr(decoder_layer, "self_attn", None)
+ if self_attn is not None:
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+ _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+ _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
+ def apply_liger_kernel_to_qwen3_vl_moe(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is False.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+ from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
+
+ if rope:
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+ if rms_norm:
+ modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+ else:
+ modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+ if model is not None and rms_norm:
+ if isinstance(model, Qwen3VLMoeForConditionalGeneration):
+ text_model: Qwen3VLMoeTextModel = model.model.language_model
+ elif isinstance(model, Qwen3VLMoeModel):
+ text_model: Qwen3VLMoeTextModel = model.language_model
+ elif isinstance(model, Qwen3VLMoeTextModel):
+ text_model = model
+ else:
+ raise TypeError(
+ f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+ )
+
+ _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+ if text_model is not None:
+ _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+ for decoder_layer in text_model.layers:
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+ self_attn = getattr(decoder_layer, "self_attn", None)
+ if self_attn is not None:
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+ _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+ _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
  def apply_liger_kernel_to_phi3(
  rope: bool = True,
  cross_entropy: bool = False,
@@ -1774,6 +2003,74 @@ def apply_liger_kernel_to_olmo2(
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)


+ def apply_liger_kernel_to_olmo3(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.olmo3 import modeling_olmo3
+ from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+ from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+ # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+ if rope:
+ modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if rms_norm:
+ modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2 # same as olmo2
+ if swiglu:
+ modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(olmo3_lce_forward, model)
+ else:
+ modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
  def apply_liger_kernel_to_glm4(
  rope: bool = False,
  cross_entropy: bool = False,
@@ -1896,10 +2193,10 @@ def apply_liger_kernel_to_glm4v(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
- # Note: language_model and visual properties can be accessed throught conditional class for BC.
- # Not sure if it is subject to changes in the future.
- # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
+ if isinstance(model, Glm4vForConditionalGeneration):
+ text_model: Glm4vTextModel = model.model.language_model
+ vision_model: Glm4vVisionModel = model.model.visual
+ elif isinstance(model, Glm4vModel):
  text_model: Glm4vTextModel = model.language_model
  vision_model: Glm4vVisionModel = model.visual
  elif isinstance(model, Glm4vTextModel):
@@ -1986,10 +2283,11 @@ def apply_liger_kernel_to_glm4v_moe(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
- # Note: language_model and visual properties can be accessed throught conditional class for BC.
- # Not sure if it is subject to changes in the future.
- # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
+ if isinstance(model, Glm4vMoeForConditionalGeneration):
+ text_model: Glm4vMoeTextModel = model.model.language_model
+ vision_model: Glm4vMoeVisionModel = model.model.visual
+ Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+ elif isinstance(model, Glm4vMoeModel):
  text_model: Glm4vMoeTextModel = model.language_model
  vision_model: Glm4vMoeVisionModel = model.visual
  Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
@@ -2038,6 +2336,7 @@ def apply_liger_kernel_to_internvl(
  cross_entropy: bool = False,
  fused_linear_cross_entropy: bool = True,
  rms_norm: bool = True,
+ layer_norm: bool = True,
  model: Optional[PreTrainedModel] = None,
  **kwargs,
  ) -> None:
@@ -2048,37 +2347,62 @@ def apply_liger_kernel_to_internvl(
  NOTE: InternVL is not available in transformers<4.52.1

  Args:
- rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
  fused_linear_cross_entropy (bool):
  Whether to apply Liger's fused linear cross entropy loss. Default is True.
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
  loaded. Default is None.
  """
  assert not (cross_entropy and fused_linear_cross_entropy), (
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
  )
+ import torch.nn as torch_nn

  from transformers.models.internvl import modeling_internvl
+ from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+ from transformers.models.internvl.modeling_internvl import InternVLModel
+ from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+ from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+ from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm

+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
  from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+ if layer_norm and model is None:
+ modeling_internvl.nn.LayerNorm = LigerLayerNorm

  if cross_entropy:
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_internvl.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+ logger.info("Apply liger cross entropy")
+
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
  if fused_linear_cross_entropy:
  modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
  if rms_norm:
  modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm

  if model is not None:
- text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, InternVLForConditionalGeneration):
+ text_model = model.model.language_model
+ vision_model: InternVLVisionModel = model.model.vision_tower
+ elif isinstance(model, InternVLModel):
+ text_model = model.language_model
+ vision_model: InternVLVisionModel = model.vision_tower
+ else:
+ raise TypeError(
+ f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
+ )
+
+ text_model_name = model.config.text_config.model_type
  text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
- vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)

  kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
  if text_liger_fn:
@@ -2091,25 +2415,33 @@ def apply_liger_kernel_to_internvl(
  f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
  f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
  )
- text_kwargs["model"] = model.language_model
+ text_kwargs["model"] = text_model
  text_liger_fn(**text_kwargs)
  elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
  logger.warning(f"{text_model_name} is not supported by Liger kernel.")

- if vision_liger_fn:
- accept_params = inspect.signature(vision_liger_fn).parameters
- remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
- vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+ # Patch vision model RMSNorm layers
+ if rms_norm:
+ for encoder_layer in vision_model.encoder.layer:
+ encoder_layer: InternVLVisionLayer
+ if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+ _patch_rms_norm_module(encoder_layer.attention.q_norm)
+ if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+ _patch_rms_norm_module(encoder_layer.attention.k_norm)

- if remain_params:
- logger.warning(
- f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
- f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
- )
- vision_kwargs["model"] = model.vision_tower
- vision_liger_fn(**vision_kwargs)
- elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
- logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+ # Patch vision model LayerNorm layers
+ if layer_norm:
+ # Patch layernorm
+ if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+ _patch_layer_norm_module(vision_model.layernorm)
+
+ # Patch encoder layers
+ for encoder_layer in vision_model.encoder.layer:
+ encoder_layer: InternVLVisionLayer
+ if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+ _patch_layer_norm_module(encoder_layer.layernorm_before)
+ if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+ _patch_layer_norm_module(encoder_layer.layernorm_after)


  def apply_liger_kernel_to_smolvlm(
@@ -2372,6 +2704,200 @@ def apply_liger_kernel_to_qwen3_next(
  _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)


+ def apply_liger_kernel_to_hunyuan_v1_dense(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+ from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+ if rope:
+ modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(hunyuan_v1_lce_forward, model)
+ else:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+ if swiglu:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_hunyuan_v1_moe(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+ from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+ if rope:
+ modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+ else:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
+
+ if swiglu:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ for mlp_expert in decoder_layer.mlp.experts:
+ _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_exaone4(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.exaone4 import modeling_exaone4
+ from transformers.models.exaone4.modeling_exaone4 import Exaone4Model
+
+ from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward
+
+ if rope:
+ modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ # EXAONE4 requires in_place=False to avoid gradient issues
+ class Exaone4LigerRMSNorm(LigerRMSNorm):
+ def __init__(self, hidden_size, eps=1e-6, **kwargs):
+ super().__init__(hidden_size, eps, **kwargs)
+ self.in_place = False
+
+ modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(exaone4_lce_forward, model)
+ else:
+ modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward
+
+ if swiglu:
+ modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: Exaone4Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm, in_place=False)
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+ _patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False)
+ _patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False)
+
+
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "gemma": apply_liger_kernel_to_gemma,
@@ -2381,6 +2907,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "glm4": apply_liger_kernel_to_glm4,
  "glm4v": apply_liger_kernel_to_glm4v,
  "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+ "gpt_oss": apply_liger_kernel_to_gpt_oss,
  "internvl": apply_liger_kernel_to_internvl,
  "llama": apply_liger_kernel_to_llama,
  "llama4_text": apply_liger_kernel_to_llama4,
@@ -2392,6 +2919,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "mistral": apply_liger_kernel_to_mistral,
  "mixtral": apply_liger_kernel_to_mixtral,
  "olmo2": apply_liger_kernel_to_olmo2,
+ "olmo3": apply_liger_kernel_to_olmo3,
  "qwen2": apply_liger_kernel_to_qwen2,
  "qwen3": apply_liger_kernel_to_qwen3,
  "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
@@ -2400,11 +2928,18 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
  "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
  "qwen3_next": apply_liger_kernel_to_qwen3_next,
+ "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+ "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+ "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+ "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
  "smollm3": apply_liger_kernel_to_smollm3,
  "phi3": apply_liger_kernel_to_phi3,
  "paligemma": apply_liger_kernel_to_paligemma,
  "falcon_h1": apply_liger_kernel_to_falcon_h1,
  "smolvlm": apply_liger_kernel_to_smolvlm,
+ "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+ "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
+ "exaone4": apply_liger_kernel_to_exaone4,
  }

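The final hunks register the new entry points in MODEL_TYPE_TO_APPLY_LIGER_FN, the mapping used by the model-type dispatch. A hedged sketch of how that mapping is typically exercised, assuming AutoLigerKernelForCausalLM is exported from liger_kernel.transformers as in earlier releases; the checkpoint path is a placeholder:

    from liger_kernel.transformers import AutoLigerKernelForCausalLM

    # from_pretrained reads config.model_type (e.g. "olmo3", "gpt_oss", "exaone4"),
    # looks it up in MODEL_TYPE_TO_APPLY_LIGER_FN, and applies the matching
    # apply_liger_kernel_to_* patch before loading the weights.
    model = AutoLigerKernelForCausalLM.from_pretrained("path/to/olmo3-checkpoint")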