liger-kernel 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  2. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  3. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  4. liger_kernel/chunked_loss/jsd_loss.py +5 -2
  5. liger_kernel/ops/cross_entropy.py +59 -53
  6. liger_kernel/ops/fused_linear_cross_entropy.py +83 -17
  7. liger_kernel/ops/layer_norm.py +4 -6
  8. liger_kernel/ops/llama4_rope.py +225 -0
  9. liger_kernel/ops/poly_norm.py +386 -0
  10. liger_kernel/transformers/__init__.py +32 -0
  11. liger_kernel/transformers/experimental/__init__.py +5 -0
  12. liger_kernel/transformers/functional.py +9 -0
  13. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -1
  14. liger_kernel/transformers/llama4_rope.py +93 -0
  15. liger_kernel/transformers/model/falcon_h1.py +108 -0
  16. liger_kernel/transformers/model/gemma.py +2 -1
  17. liger_kernel/transformers/model/gemma2.py +8 -2
  18. liger_kernel/transformers/model/gemma3.py +27 -2
  19. liger_kernel/transformers/model/glm4.py +2 -1
  20. liger_kernel/transformers/model/glm4v.py +151 -0
  21. liger_kernel/transformers/model/glm4v_moe.py +153 -0
  22. liger_kernel/transformers/model/internvl.py +150 -0
  23. liger_kernel/transformers/model/llama.py +2 -1
  24. liger_kernel/transformers/model/llama4.py +2 -1
  25. liger_kernel/transformers/model/llava.py +6 -2
  26. liger_kernel/transformers/model/loss_utils.py +3 -0
  27. liger_kernel/transformers/model/mistral.py +2 -1
  28. liger_kernel/transformers/model/mixtral.py +8 -2
  29. liger_kernel/transformers/model/mllama.py +6 -3
  30. liger_kernel/transformers/model/olmo2.py +2 -1
  31. liger_kernel/transformers/model/paligemma.py +19 -0
  32. liger_kernel/transformers/model/phi3.py +10 -160
  33. liger_kernel/transformers/model/qwen2.py +2 -1
  34. liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
  35. liger_kernel/transformers/model/qwen2_vl.py +7 -2
  36. liger_kernel/transformers/model/qwen3.py +2 -1
  37. liger_kernel/transformers/model/qwen3_moe.py +8 -2
  38. liger_kernel/transformers/model/qwen3_next.py +134 -0
  39. liger_kernel/transformers/model/smollm3.py +2 -1
  40. liger_kernel/transformers/model/smolvlm.py +158 -0
  41. liger_kernel/transformers/monkey_patch.py +552 -23
  42. liger_kernel/transformers/multi_token_attention.py +1 -1
  43. liger_kernel/transformers/poly_norm.py +42 -0
  44. liger_kernel/transformers/rms_norm.py +7 -0
  45. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +14 -11
  46. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +50 -39
  47. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
  48. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
  49. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
  50. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py
@@ -4,6 +4,7 @@ import logging
  from functools import partial
  from types import MethodType
  from typing import Callable
+ from typing import Optional

  import transformers

@@ -14,6 +15,7 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
  from liger_kernel.transformers.functional import liger_cross_entropy
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
  from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
@@ -26,7 +28,6 @@ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_f
  from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
  from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
- from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
  from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
  from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
  from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
@@ -449,7 +450,7 @@ def apply_liger_kernel_to_llava(


  def apply_liger_kernel_to_llama4(
- rope: bool = False,
+ rope: bool = True,
  cross_entropy: bool = False,
  fused_linear_cross_entropy: bool = True,
  rms_norm: bool = True,
@@ -468,7 +469,7 @@ def apply_liger_kernel_to_llama4(
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
  loaded. Default is None.
  """
@@ -485,7 +486,9 @@ def apply_liger_kernel_to_llama4(
  from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward

  if rope:
- raise NotImplementedError("liger_rotary_pos_emb is not available for Llama4 models.")
+ from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full
+
+ apply_liger_llama4_rope_full(modeling_llama4)
  if rms_norm:
  modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
  if swiglu:
@@ -519,7 +522,10 @@ def apply_liger_kernel_to_llama4(
  _patch_rms_norm_module(text_model.norm)
  for decoder_layer in text_model.layers:
  if swiglu:
- _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+ if decoder_layer.is_moe_layer:
+ _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+ else:
+ _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
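For orientation, the hunks above flip the Llama4 defaults (rope is now True, routed through the new llama4_rope op instead of raising NotImplementedError) and make the SwiGLU patch MoE-aware. A minimal usage sketch, not part of the diff, assuming only what is shown here (the function lives in liger_kernel/transformers/monkey_patch.py and accepts these keyword arguments):

```python
# Hedged sketch: calling the Llama4 entry point with the 0.6.3 defaults above.
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4

# Patch the Llama4 modeling module before the model is instantiated; with the
# new defaults, rope=True swaps in the Liger Llama4 RoPE implementation.
apply_liger_kernel_to_llama4()

# Or patch an already-loaded model instance; per the hunk above, MoE layers
# then patch only the shared expert's SwiGLU MLP.
# apply_liger_kernel_to_llama4(model=my_llama4_model)
```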
@@ -1333,7 +1339,6 @@ def apply_liger_kernel_to_qwen2(
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
- print("Applied Liger kernels to Qwen2")


  def apply_liger_kernel_to_qwen3(
@@ -1675,25 +1680,14 @@ def apply_liger_kernel_to_phi3(
  if swiglu:
  modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
  if cross_entropy:
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- from transformers.loss.loss_utils import nn
+ from transformers.loss.loss_utils import nn

- nn.functional.cross_entropy = liger_cross_entropy
- else:
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+ nn.functional.cross_entropy = liger_cross_entropy
  if fused_linear_cross_entropy:
- if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
- if model is not None:
- model.forward = MethodType(phi3_lce_forward, model)
- else:
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
- else: # if version < 4.46.1
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- if model is not None:
- model.forward = MethodType(phi3_lce_forward_deprecated, model)
- else:
- modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+ if model is not None:
+ model.forward = MethodType(phi3_lce_forward, model)
+ else:
+ modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

  if model is not None:
  # The model instance already exists, so we need to additionally patch the
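This hunk drops the pre-4.46.1 transformers fallback for Phi3, leaving only the MethodType binding of the Liger forward. For readers unfamiliar with that pattern, here is a tiny standard-library illustration (generic Python, not Liger code):

```python
# MethodType binds a plain function as an instance method, so forward() is
# overridden for this one model object without touching the class itself.
from types import MethodType


class Model:
    def forward(self, x):
        return x


def patched_forward(self, x):
    # stand-in for phi3_lce_forward: same signature, different implementation
    return x * 2


m = Model()
m.forward = MethodType(patched_forward, m)  # instance-level patch
assert m.forward(3) == 6 and Model().forward(3) == 3
```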
@@ -1849,6 +1843,535 @@ def apply_liger_kernel_to_glm4(
  _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)


+ def apply_liger_kernel_to_glm4v(
+ rope: bool = False,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4v models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.glm4v import modeling_glm4v
+ from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
+ from transformers.models.glm4v.modeling_glm4v import Glm4vModel
+ from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
+ from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
+
+ from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+ if rope:
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+ if rms_norm:
+ modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(glm4v_lce_forward, model)
+ else:
+ modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
+ # Note: language_model and visual properties can be accessed through the conditional class for BC.
+ # Not sure if it is subject to changes in the future.
+ # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
+ text_model: Glm4vTextModel = model.language_model
+ vision_model: Glm4vVisionModel = model.visual
+ elif isinstance(model, Glm4vTextModel):
+ text_model: Glm4vTextModel = model
+ vision_model = None
+ else:
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+ raise TypeError(
+ f"Unsupported glm4.1v model type. `model` must be `Glm4VLForConditionalGeneration`, `Glm4vVisionModel` or `Glm4vTextModel`. Got: {type(model)}"
+ )
+
+ if vision_model is not None:
+ for vision_block in vision_model.blocks:
+ if rms_norm:
+ _patch_rms_norm_module(vision_block.norm1)
+ _patch_rms_norm_module(vision_block.norm2)
+ if swiglu:
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+ if text_model is not None:
+ if rms_norm:
+ _patch_rms_norm_module(text_model.norm)
+ for decoder_layer in text_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
+
+
+ def apply_liger_kernel_to_glm4v_moe(
+ rope: bool = False,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.glm4v_moe import modeling_glm4v_moe
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
+ from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
+
+ from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+ if rope:
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+ if rms_norm:
+ modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
+ modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(glm4v_moe_lce_forward, model)
+ else:
+ modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
+ # Note: language_model and visual properties can be accessed through the conditional class for BC.
+ # Not sure if it is subject to changes in the future.
+ # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
+ text_model: Glm4vMoeTextModel = model.language_model
+ vision_model: Glm4vMoeVisionModel = model.visual
+ Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+ elif isinstance(model, Glm4vMoeTextModel):
+ text_model: Glm4vMoeTextModel = model
+ vision_model = None
+ else:
+ # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+ raise TypeError(
+ f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
+ )
+
+ if vision_model is not None:
+ _patch_rms_norm_module(vision_model.post_conv_layernorm)
+ _patch_rms_norm_module(vision_model.post_layernorm)
+ for vision_block in vision_model.blocks:
+ if rms_norm:
+ _patch_rms_norm_module(vision_block.norm1)
+ _patch_rms_norm_module(vision_block.norm2)
+ if swiglu:
+ _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+ if text_model is not None:
+ if rms_norm:
+ _patch_rms_norm_module(text_model.norm)
+ for decoder_layer in text_model.layers:
+ if swiglu:
+ decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+ if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
+ experts = getattr(decoder_layer.mlp, "experts", None)
+ if experts is not None:
+ for expert in experts:
+ _patch_swiglu_module(expert, LigerSwiGLUMLP)
+ if decoder_layer.mlp.shared_experts is not None:
+ _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
+ for decoder_layer in text_model.layers:
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_internvl(
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ model: Optional[PreTrainedModel] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
+ Due to the characteristics of InternVL, the model must be passed to apply Liger-Kernel's patch to other models connected to InternVL.
+ However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
+ NOTE: InternVL is not available in transformers<4.52.1
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.internvl import modeling_internvl
+
+ from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+
+ if cross_entropy:
+ logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+ modeling_internvl.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+ if fused_linear_cross_entropy:
+ modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
+ if rms_norm:
+ modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
+
+ if model is not None:
+ text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+ vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+ if text_liger_fn:
+ accept_params = inspect.signature(text_liger_fn).parameters
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+ if remain_params:
+ logger.warning(
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+ )
+ text_kwargs["model"] = model.language_model
+ text_liger_fn(**text_kwargs)
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+ if vision_liger_fn:
+ accept_params = inspect.signature(vision_liger_fn).parameters
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+ vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+ if remain_params:
+ logger.warning(
+ f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
+ f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+ )
+ vision_kwargs["model"] = model.vision_tower
+ vision_liger_fn(**vision_kwargs)
+ elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
+ def apply_liger_kernel_to_smolvlm(
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ layer_norm: bool = True,
+ model: Optional[PreTrainedModel] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
+ Due to the characteristics of SmolVLM, the model must be passed to apply Liger-Kernel's patch to other models connected to SmolVLM.
+ However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
+ NOTE: SmolVLM is not available in transformers<4.50.0
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.smolvlm import modeling_smolvlm
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+
+ from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
+
+ # Patch LayerNorm for vision model if model is not provided (pre-initialization)
+ if layer_norm and model is None:
+ modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
+
+ if cross_entropy:
+ logger.info("Apply liger cross entropy")
+
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(smolvlm_lce_forward, model)
+ else:
+ modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
+ if rms_norm:
+ modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, SmolVLMForConditionalGeneration):
+ text_model = model.model.text_model
+ vision_model: SmolVLMVisionTransformer = model.model.vision_model
+ elif isinstance(model, SmolVLMModel):
+ text_model = model.text_model
+ vision_model: SmolVLMVisionTransformer = model.vision_model
+ else:
+ raise TypeError(
+ f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
+ )
+
+ text_model_name = model.config.text_config.model_type
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+ if text_liger_fn:
+ accept_params = inspect.signature(text_liger_fn).parameters
+ remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+ if remain_params:
+ logger.warning(
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+ )
+ text_kwargs["model"] = text_model
+ text_liger_fn(**text_kwargs)
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+ # Patch vision model LayerNorm layers
+ if layer_norm:
+ # Patch post_layernorm
+ _patch_layer_norm_module(vision_model.post_layernorm)
+
+ # Patch encoder layers
+ for encoder_layer in vision_model.encoder.layers:
+ encoder_layer: SmolVLMEncoderLayer
+ _patch_layer_norm_module(encoder_layer.layer_norm1)
+ _patch_layer_norm_module(encoder_layer.layer_norm2)
+
+
+ def apply_liger_kernel_to_falcon_h1(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.falcon_h1 import modeling_falcon_h1
+ from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
+
+ if rope:
+ logger.info("Apply liger rotary pos emb.")
+ modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if rms_norm:
+ logger.info("Apply liger RMSNorm")
+ modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
+ if swiglu:
+ logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
+
+ if cross_entropy:
+ logger.info("Apply liger cross entropy")
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(falcon_h1_lce_forward, model)
+ else:
+ modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
+
+ # get the base model from the model instance
+ base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.final_layernorm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
+
+
+ def apply_liger_kernel_to_qwen3_next(
+ rope: bool = False,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3Next models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3_next import modeling_qwen3_next
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+ from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+ if rope:
+ # It might encounter NaN issues
+ # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
+ if rms_norm:
+ modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ if isinstance(model, Qwen3NextForCausalLM):
+ model.forward = MethodType(qwen3_next_lce_forward, model)
+ else:
+ raise TypeError(
+ f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
+ )
+ else:
+ modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
+ if swiglu:
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+ modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+ base_model: Qwen3NextForCausalLM = getattr(model, model.base_model_prefix, model)
+ else:
+ raise TypeError(
+ f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
+ )
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+ if swiglu:
+ if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+ _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+ if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+ _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+ experts = getattr(decoder_layer.mlp, "experts", None)
+ if experts is not None:
+ for expert in experts:
+ _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
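Before the registry context below, a hedged usage sketch for one of the entry points this release adds. The import path mirrors where the function is defined in this diff; the checkpoint id is a placeholder, not something named in the diff:

```python
# Hedged sketch: applying the new GLM-4V patch added in 0.6.3.
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v

# Option 1: patch the GLM-4V modeling classes before the model is created.
apply_liger_kernel_to_glm4v(rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)

# Option 2: patch an already-instantiated model; this walks the vision blocks
# and text decoder layers exactly as the function body above does.
# from transformers import AutoModelForImageTextToText
# model = AutoModelForImageTextToText.from_pretrained("<glm4v-checkpoint>")  # placeholder id
# apply_liger_kernel_to_glm4v(model=model)
```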
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "gemma": apply_liger_kernel_to_gemma,
@@ -1856,6 +2379,9 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "gemma3_text": apply_liger_kernel_to_gemma3_text,
  "gemma3": apply_liger_kernel_to_gemma3,
  "glm4": apply_liger_kernel_to_glm4,
+ "glm4v": apply_liger_kernel_to_glm4v,
+ "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+ "internvl": apply_liger_kernel_to_internvl,
  "llama": apply_liger_kernel_to_llama,
  "llama4_text": apply_liger_kernel_to_llama4,
  "llama4": apply_liger_kernel_to_llama4,
@@ -1873,9 +2399,12 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
  "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+ "qwen3_next": apply_liger_kernel_to_qwen3_next,
  "smollm3": apply_liger_kernel_to_smollm3,
  "phi3": apply_liger_kernel_to_phi3,
  "paligemma": apply_liger_kernel_to_paligemma,
+ "falcon_h1": apply_liger_kernel_to_falcon_h1,
+ "smolvlm": apply_liger_kernel_to_smolvlm,
  }
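The registry above keys patch functions by config.model_type, including the new glm4v, glm4v_moe, internvl, qwen3_next, falcon_h1, and smolvlm entries. A hedged sketch of dispatching through it follows; the helper below is illustrative only, not a function defined in this diff:

```python
# Illustrative dispatch through MODEL_TYPE_TO_APPLY_LIGER_FN; the helper name
# is hypothetical and exists only for this sketch.
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN


def apply_liger_by_model_type(model, **kwargs):
    model_type = model.config.model_type
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        raise ValueError(f"{model_type} is not supported by Liger kernel.")
    # e.g. "falcon_h1" -> apply_liger_kernel_to_falcon_h1
    apply_fn(model=model, **kwargs)
```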
liger_kernel/transformers/multi_token_attention.py
@@ -9,7 +9,7 @@ from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunct


  class LigerMultiTokenAttention(nn.Module):
- """
+ r"""
  Multi-Token Attention:
  out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))
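The only change in this file is turning the class docstring into a raw string, so the backslash in `mask_{-\inf}` is kept verbatim instead of being parsed as an escape sequence (which newer Python versions flag with a SyntaxWarning). A quick generic illustration, unrelated to Liger internals:

```python
# Raw strings preserve backslashes literally, which is why the docstring above
# was changed from """...""" to r"""...""".
plain = "mask_{-\\inf}"  # ordinary string: backslash must be escaped by hand
raw = r"mask_{-\inf}"    # raw string: written exactly as it appears
assert plain == raw
```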