liger-kernel-nightly 0.6.2.dev20250919191028__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (67)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +18 -5
  6. liger_kernel/ops/cross_entropy.py +120 -63
  7. liger_kernel/ops/dyt.py +5 -2
  8. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  9. liger_kernel/ops/fused_linear_cross_entropy.py +43 -12
  10. liger_kernel/ops/geglu.py +2 -1
  11. liger_kernel/ops/group_norm.py +2 -1
  12. liger_kernel/ops/grpo_loss.py +3 -1
  13. liger_kernel/ops/layer_norm.py +88 -70
  14. liger_kernel/ops/poly_norm.py +390 -0
  15. liger_kernel/ops/rms_norm.py +7 -2
  16. liger_kernel/ops/tiled_mlp.py +136 -0
  17. liger_kernel/ops/utils.py +2 -0
  18. liger_kernel/transformers/__init__.py +33 -0
  19. liger_kernel/transformers/cross_entropy.py +8 -3
  20. liger_kernel/transformers/functional.py +29 -6
  21. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  22. liger_kernel/transformers/grpo_loss.py +56 -1
  23. liger_kernel/transformers/model/falcon_h1.py +122 -0
  24. liger_kernel/transformers/model/gemma.py +19 -7
  25. liger_kernel/transformers/model/gemma2.py +22 -7
  26. liger_kernel/transformers/model/gemma3.py +52 -14
  27. liger_kernel/transformers/model/glm4.py +18 -5
  28. liger_kernel/transformers/model/glm4v.py +18 -5
  29. liger_kernel/transformers/model/glm4v_moe.py +25 -5
  30. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  31. liger_kernel/transformers/model/internvl.py +157 -0
  32. liger_kernel/transformers/model/llama.py +16 -6
  33. liger_kernel/transformers/model/llama4.py +18 -5
  34. liger_kernel/transformers/model/llava.py +18 -6
  35. liger_kernel/transformers/model/loss_utils.py +31 -3
  36. liger_kernel/transformers/model/mistral.py +17 -7
  37. liger_kernel/transformers/model/mixtral.py +24 -9
  38. liger_kernel/transformers/model/mllama.py +14 -5
  39. liger_kernel/transformers/model/olmo2.py +18 -5
  40. liger_kernel/transformers/model/olmo3.py +142 -0
  41. liger_kernel/transformers/model/output_classes.py +147 -0
  42. liger_kernel/transformers/model/paligemma.py +41 -5
  43. liger_kernel/transformers/model/phi3.py +16 -8
  44. liger_kernel/transformers/model/qwen2.py +18 -4
  45. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  46. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  47. liger_kernel/transformers/model/qwen3.py +22 -6
  48. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  49. liger_kernel/transformers/model/qwen3_next.py +146 -0
  50. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  51. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  52. liger_kernel/transformers/model/smollm3.py +17 -7
  53. liger_kernel/transformers/model/smolvlm.py +158 -0
  54. liger_kernel/transformers/monkey_patch.py +729 -4
  55. liger_kernel/transformers/poly_norm.py +42 -0
  56. liger_kernel/transformers/rms_norm.py +7 -0
  57. liger_kernel/transformers/rope.py +43 -0
  58. liger_kernel/transformers/swiglu.py +17 -0
  59. liger_kernel/transformers/tiled_mlp.py +133 -0
  60. liger_kernel/utils.py +25 -0
  61. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +13 -6
  62. liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
  63. liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/RECORD +0 -105
  64. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
  65. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
  66. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
  67. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
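The detailed diff below covers liger_kernel/transformers/monkey_patch.py (+729 -4), which gains monkey-patch entry points for additional architectures (Qwen3-VL, Qwen3-VL-MoE, Qwen3-Next, Olmo3, Falcon-H1, InternVL, SmolVLM, Hunyuan v1 dense/MoE). As a rough orientation before the hunks, the sketch below shows how one of the new patchers might be invoked on an already-loaded model; it is illustrative only, assuming the function is importable from liger_kernel.transformers.monkey_patch as defined in the diff, and the checkpoint id is a placeholder.

    # Illustrative sketch only; apply_liger_kernel_to_olmo3 and its parameters come from
    # the diff below, but the checkpoint id is a placeholder.
    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3

    model = AutoModelForCausalLM.from_pretrained("path/to/olmo3-checkpoint")  # placeholder id

    # Passing the loaded model patches the already-instantiated RMSNorm/MLP modules in place
    # and binds the fused-linear-cross-entropy forward to this instance.
    apply_liger_kernel_to_olmo3(
        rope=True,
        fused_linear_cross_entropy=True,  # logits are not materialized
        rms_norm=True,
        swiglu=True,
        model=model,
    )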
@@ -4,6 +4,7 @@ import logging
  from functools import partial
  from types import MethodType
  from typing import Callable
+ from typing import Optional

  import transformers

@@ -14,6 +15,7 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
  from liger_kernel.transformers.functional import liger_cross_entropy
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
  from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
@@ -32,6 +34,8 @@ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_f
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -467,7 +471,7 @@ def apply_liger_kernel_to_llama4(
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
  loaded. Default is None.
  """
@@ -520,7 +524,10 @@ def apply_liger_kernel_to_llama4(
  _patch_rms_norm_module(text_model.norm)
  for decoder_layer in text_model.layers:
  if swiglu:
- _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+ if decoder_layer.is_moe_layer:
+ _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+ else:
+ _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -1334,7 +1341,6 @@ def apply_liger_kernel_to_qwen2(
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
- print("Applied Liger kernels to Qwen2")


  def apply_liger_kernel_to_qwen3(
@@ -1639,6 +1645,158 @@ apply_liger_kernel_to_qwen2_5_vl(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_qwen3_vl(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3_vl import modeling_qwen3_vl
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+ from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+
+ if rope:
+ modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+ modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
+ if rms_norm:
+ modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(qwen3_vl_lce_forward, model)
+ else:
+ modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+ if model is not None and rms_norm:
+ if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
+ text_model: Qwen3VLTextModel = model.language_model
+ elif isinstance(model, Qwen3VLTextModel):
+ text_model = model
+ else:
+ raise TypeError(
+ f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+ )
+
+ _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+ if text_model is not None:
+ _patch_qwen3_vl_rms_norm(text_model.norm)
+ for decoder_layer in text_model.layers:
+ _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+ _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+ self_attn = getattr(decoder_layer, "self_attn", None)
+ if self_attn is not None:
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+ _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+ _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
+ def apply_liger_kernel_to_qwen3_vl_moe(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is False.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+ from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
+
+ if rope:
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
+ if rms_norm:
+ modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+ else:
+ modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+ if model is not None and rms_norm:
+ if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
+ text_model: Qwen3VLMoeTextModel = model.language_model
+ elif isinstance(model, Qwen3VLMoeTextModel):
+ text_model = model
+ else:
+ raise TypeError(
+ f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+ )
+
+ _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+ if text_model is not None:
+ _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+ for decoder_layer in text_model.layers:
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+ self_attn = getattr(decoder_layer, "self_attn", None)
+ if self_attn is not None:
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+ _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+ _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
  def apply_liger_kernel_to_phi3(
  rope: bool = True,
  cross_entropy: bool = False,
@@ -1770,6 +1928,74 @@ def apply_liger_kernel_to_olmo2(
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)


+ def apply_liger_kernel_to_olmo3(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.olmo3 import modeling_olmo3
+ from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+ from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+ # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+ if rope:
+ modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if rms_norm:
+ modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2 # same as olmo2
+ if swiglu:
+ modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(olmo3_lce_forward, model)
+ else:
+ modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
  def apply_liger_kernel_to_glm4(
  rope: bool = False,
  cross_entropy: bool = False,
@@ -1967,7 +2193,8 @@ apply_liger_kernel_to_glm4v_moe(
  if rope:
  raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
  if rms_norm:
- modeling_glm4v_moe.Glm4vRMSNorm = LigerRMSNormForGlm4
+ modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
+ modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
  if cross_entropy:
  from transformers.loss.loss_utils import nn

@@ -2029,6 +2256,493 @@ def apply_liger_kernel_to_glm4v_moe(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_internvl(
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ layer_norm: bool = True,
+ model: Optional[PreTrainedModel] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
+ Due to the characteristics of InternVL, the model must be passed to apply Liger-Kernel's patch to other models connected to InternVL.
+ However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
+ NOTE: InternVL is not available in transformers<4.52.1
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+ import torch.nn as torch_nn
+
+ from transformers.models.internvl import modeling_internvl
+ from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+ from transformers.models.internvl.modeling_internvl import InternVLModel
+ from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+ from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+ from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
+
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+ if layer_norm and model is None:
+ modeling_internvl.nn.LayerNorm = LigerLayerNorm
+
+ if cross_entropy:
+ logger.info("Apply liger cross entropy")
+
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
+ if rms_norm:
+ modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
+ # NOTE: language_model and visual properties can be accessed throught conditional class.
+ text_model = model.language_model
+ vision_model: InternVLVisionModel = model.vision_tower
+ else:
+ raise TypeError(
+ f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
+ )
+
+ text_model_name = model.config.text_config.model_type
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+ if text_liger_fn:
+ accept_params = inspect.signature(text_liger_fn).parameters
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+ if remain_params:
+ logger.warning(
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+ )
+ text_kwargs["model"] = text_model
+ text_liger_fn(**text_kwargs)
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+ # Patch vision model RMSNorm layers
+ if rms_norm:
+ for encoder_layer in vision_model.encoder.layer:
+ encoder_layer: InternVLVisionLayer
+ if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+ _patch_rms_norm_module(encoder_layer.attention.q_norm)
+ if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+ _patch_rms_norm_module(encoder_layer.attention.k_norm)
+
+ # Patch vision model LayerNorm layers
+ if layer_norm:
+ # Patch layernorm
+ if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+ _patch_layer_norm_module(vision_model.layernorm)
+
+ # Patch encoder layers
+ for encoder_layer in vision_model.encoder.layer:
+ encoder_layer: InternVLVisionLayer
+ if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+ _patch_layer_norm_module(encoder_layer.layernorm_before)
+ if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+ _patch_layer_norm_module(encoder_layer.layernorm_after)
+
+
+ def apply_liger_kernel_to_smolvlm(
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ layer_norm: bool = True,
+ model: Optional[PreTrainedModel] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
+ Due to the characteristics of SmolVLM, the model must be passed to apply Liger-Kernel's patch to other models connected to SmolVLM.
+ However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
+ NOTE: SmolVLM is not available in transformers<4.50.0
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.smolvlm import modeling_smolvlm
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+
+ from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
+
+ # Patch LayerNorm for vision model if model is not provided (pre-initialization)
+ if layer_norm and model is None:
+ modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
+
+ if cross_entropy:
+ logger.info("Apply liger cross entropy")
+
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(smolvlm_lce_forward, model)
+ else:
+ modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
+ if rms_norm:
+ modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, SmolVLMForConditionalGeneration):
+ text_model = model.model.text_model
+ vision_model: SmolVLMVisionTransformer = model.model.vision_model
+ elif isinstance(model, SmolVLMModel):
+ text_model = model.text_model
+ vision_model: SmolVLMVisionTransformer = model.vision_model
+ else:
+ raise TypeError(
+ f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
+ )
+
+ text_model_name = model.config.text_config.model_type
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+ if text_liger_fn:
+ accept_params = inspect.signature(text_liger_fn).parameters
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+ if remain_params:
+ logger.warning(
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+ )
+ text_kwargs["model"] = text_model
+ text_liger_fn(**text_kwargs)
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+ # Patch vision model LayerNorm layers
+ if layer_norm:
+ # Patch post_layernorm
+ _patch_layer_norm_module(vision_model.post_layernorm)
+
+ # Patch encoder layers
+ for encoder_layer in vision_model.encoder.layers:
+ encoder_layer: SmolVLMEncoderLayer
+ _patch_layer_norm_module(encoder_layer.layer_norm1)
+ _patch_layer_norm_module(encoder_layer.layer_norm2)
+
+
+ def apply_liger_kernel_to_falcon_h1(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is False.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.falcon_h1 import modeling_falcon_h1
+ from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
+
+ if rope:
+ logger.info("Apply liger rotary pos emb.")
+ modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if rms_norm:
+ logger.info("Apply liger RMSNorm")
+ modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
+ if swiglu:
+ logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
+
+ if cross_entropy:
+ logger.info("Apply liger cross entropy")
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(falcon_h1_lce_forward, model)
+ else:
+ modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
+
+ # get the base model from the model instance
+ base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.final_layernorm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
+
+
+ def apply_liger_kernel_to_qwen3_next(
+ rope: bool = False,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3_next import modeling_qwen3_next
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+ from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+ if rope:
+ # It might enocunter nan issue
+ # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
+ if rms_norm:
+ modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ if isinstance(model, Qwen3NextForCausalLM):
+ model.forward = MethodType(qwen3_next_lce_forward, model)
+ else:
+ raise TypeError(
+ f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
+ )
+ else:
+ modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
+ if swiglu:
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+ modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+ base_model: Qwen3NextForCausalLM = getattr(model, model.base_model_prefix, model)
+ else:
+ raise TypeError(
+ f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
+ )
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+ if swiglu:
+ if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+ _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+ if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+ _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+ experts = getattr(decoder_layer.mlp, "experts", None)
+ if experts is not None:
+ for expert in experts:
+ _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
+ def apply_liger_kernel_to_hunyuan_v1_dense(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+ from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+ if rope:
+ modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(hunyuan_v1_lce_forward, model)
+ else:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+ if swiglu:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_hunyuan_v1_moe(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+ from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+ if rope:
+ modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+ else:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
+
+ if swiglu:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ for mlp_expert in decoder_layer.mlp.experts:
+ _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "gemma": apply_liger_kernel_to_gemma,
@@ -2038,6 +2752,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "glm4": apply_liger_kernel_to_glm4,
  "glm4v": apply_liger_kernel_to_glm4v,
  "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+ "internvl": apply_liger_kernel_to_internvl,
  "llama": apply_liger_kernel_to_llama,
  "llama4_text": apply_liger_kernel_to_llama4,
  "llama4": apply_liger_kernel_to_llama4,
@@ -2048,6 +2763,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "mistral": apply_liger_kernel_to_mistral,
  "mixtral": apply_liger_kernel_to_mixtral,
  "olmo2": apply_liger_kernel_to_olmo2,
+ "olmo3": apply_liger_kernel_to_olmo3,
  "qwen2": apply_liger_kernel_to_qwen2,
  "qwen3": apply_liger_kernel_to_qwen3,
  "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
@@ -2055,9 +2771,18 @@
  "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
  "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+ "qwen3_next": apply_liger_kernel_to_qwen3_next,
+ "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+ "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+ "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+ "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
  "smollm3": apply_liger_kernel_to_smollm3,
  "phi3": apply_liger_kernel_to_phi3,
  "paligemma": apply_liger_kernel_to_paligemma,
+ "falcon_h1": apply_liger_kernel_to_falcon_h1,
+ "smolvlm": apply_liger_kernel_to_smolvlm,
+ "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+ "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
  }

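The last three hunks register the new patchers in MODEL_TYPE_TO_APPLY_LIGER_FN, keyed by the model_type strings defined in transformers/models/auto/modeling_auto.py. A minimal dispatch sketch follows, assuming only the mapping and the model= keyword shown above; the helper apply_liger_by_model_type is hypothetical and not part of the package.

    # Hypothetical helper, not part of liger-kernel: dispatch on config.model_type
    # using the mapping extended in this release (e.g. "qwen3_vl", "falcon_h1", "olmo3").
    from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN


    def apply_liger_by_model_type(model, **kwargs):
        model_type = getattr(model.config, "model_type", None)
        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
        if apply_fn is None:
            raise ValueError(f"No Liger patcher registered for model_type={model_type!r}")
        # Every registered patcher accepts a `model` keyword for already-loaded instances.
        apply_fn(model=model, **kwargs)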