liger-kernel 0.6.2__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  2. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  3. liger_kernel/chunked_loss/jsd_loss.py +5 -2
  4. liger_kernel/ops/cross_entropy.py +59 -53
  5. liger_kernel/ops/fused_linear_cross_entropy.py +68 -10
  6. liger_kernel/ops/layer_norm.py +4 -6
  7. liger_kernel/ops/poly_norm.py +386 -0
  8. liger_kernel/transformers/__init__.py +17 -0
  9. liger_kernel/transformers/functional.py +7 -0
  10. liger_kernel/transformers/fused_linear_cross_entropy.py +5 -1
  11. liger_kernel/transformers/model/falcon_h1.py +108 -0
  12. liger_kernel/transformers/model/gemma.py +2 -1
  13. liger_kernel/transformers/model/gemma2.py +8 -2
  14. liger_kernel/transformers/model/gemma3.py +27 -2
  15. liger_kernel/transformers/model/glm4.py +2 -1
  16. liger_kernel/transformers/model/glm4v.py +3 -2
  17. liger_kernel/transformers/model/glm4v_moe.py +153 -0
  18. liger_kernel/transformers/model/internvl.py +150 -0
  19. liger_kernel/transformers/model/llama.py +2 -1
  20. liger_kernel/transformers/model/llama4.py +2 -1
  21. liger_kernel/transformers/model/llava.py +6 -2
  22. liger_kernel/transformers/model/loss_utils.py +1 -0
  23. liger_kernel/transformers/model/mistral.py +2 -1
  24. liger_kernel/transformers/model/mixtral.py +8 -2
  25. liger_kernel/transformers/model/mllama.py +2 -1
  26. liger_kernel/transformers/model/olmo2.py +2 -1
  27. liger_kernel/transformers/model/paligemma.py +19 -0
  28. liger_kernel/transformers/model/phi3.py +2 -1
  29. liger_kernel/transformers/model/qwen2.py +2 -1
  30. liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
  31. liger_kernel/transformers/model/qwen2_vl.py +7 -2
  32. liger_kernel/transformers/model/qwen3.py +2 -1
  33. liger_kernel/transformers/model/qwen3_moe.py +8 -2
  34. liger_kernel/transformers/model/qwen3_next.py +134 -0
  35. liger_kernel/transformers/model/smollm3.py +2 -1
  36. liger_kernel/transformers/model/smolvlm.py +158 -0
  37. liger_kernel/transformers/monkey_patch.py +452 -3
  38. liger_kernel/transformers/multi_token_attention.py +1 -1
  39. liger_kernel/transformers/poly_norm.py +42 -0
  40. liger_kernel/transformers/rms_norm.py +7 -0
  41. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +13 -10
  42. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +46 -39
  43. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
  44. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
  45. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
  46. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py

@@ -4,6 +4,7 @@ import logging
  from functools import partial
  from types import MethodType
  from typing import Callable
+ from typing import Optional

  import transformers

@@ -14,6 +15,7 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
  from liger_kernel.transformers.functional import liger_cross_entropy
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
  from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
  from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
@@ -467,7 +469,7 @@ def apply_liger_kernel_to_llama4(
              `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
              If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
          rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
          model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
          loaded. Default is None.
      """
@@ -520,7 +522,10 @@ def apply_liger_kernel_to_llama4(
              _patch_rms_norm_module(text_model.norm)
          for decoder_layer in text_model.layers:
              if swiglu:
-                 _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                 if decoder_layer.is_moe_layer:
+                     _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+                 else:
+                     _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -1334,7 +1339,6 @@ def apply_liger_kernel_to_qwen2(
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
-     print("Applied Liger kernels to Qwen2")


  def apply_liger_kernel_to_qwen3(
@@ -1928,6 +1932,446 @@ def apply_liger_kernel_to_glm4v(
                  _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)


+ def apply_liger_kernel_to_glm4v_moe(
+     rope: bool = False,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+         loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.glm4v_moe import modeling_glm4v_moe
+     from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
+     from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
+     from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
+     from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
+
+     from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+     if rope:
+         raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+     if rms_norm:
+         modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
+         modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+     if fused_linear_cross_entropy:
+         if model is not None:
+             model.forward = MethodType(glm4v_moe_lce_forward, model)
+         else:
+             modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+         if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
+             # Note: language_model and visual properties can be accessed through the conditional class for BC.
+             # Not sure if it is subject to changes in the future.
+             # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
+             text_model: Glm4vMoeTextModel = model.language_model
+             vision_model: Glm4vMoeVisionModel = model.visual
+             Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+         elif isinstance(model, Glm4vMoeTextModel):
+             text_model: Glm4vMoeTextModel = model
+             vision_model = None
+         else:
+             # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+             raise TypeError(
+                 f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
+             )
+
+         if vision_model is not None:
+             _patch_rms_norm_module(vision_model.post_conv_layernorm)
+             _patch_rms_norm_module(vision_model.post_layernorm)
+             for vision_block in vision_model.blocks:
+                 if rms_norm:
+                     _patch_rms_norm_module(vision_block.norm1)
+                     _patch_rms_norm_module(vision_block.norm2)
+                 if swiglu:
+                     _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+         if text_model is not None:
+             if rms_norm:
+                 _patch_rms_norm_module(text_model.norm)
+             for decoder_layer in text_model.layers:
+                 if swiglu:
+                     decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                 if rms_norm:
+                     _patch_rms_norm_module(decoder_layer.input_layernorm)
+                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                 if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
+                     experts = getattr(decoder_layer.mlp, "experts", None)
+                     if experts is not None:
+                         for expert in experts:
+                             _patch_swiglu_module(expert, LigerSwiGLUMLP)
+                     if decoder_layer.mlp.shared_experts is not None:
+                         _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
+             for decoder_layer in text_model.layers:
+                 if rms_norm:
+                     _patch_rms_norm_module(decoder_layer.input_layernorm)
+                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_internvl(
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     model: Optional[PreTrainedModel] = None,
+     **kwargs,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace InternVL models.
+     Due to the characteristics of InternVL, the model must be passed to apply Liger-Kernel's patch to other models connected to InternVL.
+     However, if an LM not supported by Liger-Kernel is connected to InternVL, unexpected side effects may occur.
+     NOTE: InternVL is not available in transformers<4.52.1
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+         loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.internvl import modeling_internvl
+
+     from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+
+     if cross_entropy:
+         logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+         modeling_internvl.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
+     if rms_norm:
+         modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
+
+     if model is not None:
+         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+         text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+         vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+         kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+         if text_liger_fn:
+             accept_params = inspect.signature(text_liger_fn).parameters
+             remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+             text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+             if remain_params:
+                 logger.warning(
+                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                 )
+             text_kwargs["model"] = model.language_model
+             text_liger_fn(**text_kwargs)
+         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+         if vision_liger_fn:
+             accept_params = inspect.signature(vision_liger_fn).parameters
+             remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+             vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+             if remain_params:
+                 logger.warning(
+                     f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
+                     f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+                 )
+             vision_kwargs["model"] = model.vision_tower
+             vision_liger_fn(**vision_kwargs)
+         elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
+ def apply_liger_kernel_to_smolvlm(
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     layer_norm: bool = True,
+     model: Optional[PreTrainedModel] = None,
+     **kwargs,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
+     Due to the characteristics of SmolVLM, the model must be passed to apply Liger-Kernel's patch to other models connected to SmolVLM.
+     However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
+     NOTE: SmolVLM is not available in transformers<4.50.0
+
+     Args:
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+         loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.smolvlm import modeling_smolvlm
+     from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
+     from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+     from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
+     from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+
+     from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
+
+     # Patch LayerNorm for vision model if model is not provided (pre-initialization)
+     if layer_norm and model is None:
+         modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
+
+     if cross_entropy:
+         logger.info("Apply liger cross entropy")
+
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+     if fused_linear_cross_entropy:
+         if model is not None:
+             model.forward = MethodType(smolvlm_lce_forward, model)
+         else:
+             modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
+     if rms_norm:
+         modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+         if isinstance(model, SmolVLMForConditionalGeneration):
+             text_model = model.model.text_model
+             vision_model: SmolVLMVisionTransformer = model.model.vision_model
+         elif isinstance(model, SmolVLMModel):
+             text_model = model.text_model
+             vision_model: SmolVLMVisionTransformer = model.vision_model
+         else:
+             raise TypeError(
+                 f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
+             )
+
+         text_model_name = model.config.text_config.model_type
+         text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+         kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+         if text_liger_fn:
+             accept_params = inspect.signature(text_liger_fn).parameters
+             remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+             text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+             if remain_params:
+                 logger.warning(
+                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                 )
+             text_kwargs["model"] = text_model
+             text_liger_fn(**text_kwargs)
+         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+         # Patch vision model LayerNorm layers
+         if layer_norm:
+             # Patch post_layernorm
+             _patch_layer_norm_module(vision_model.post_layernorm)
+
+             # Patch encoder layers
+             for encoder_layer in vision_model.encoder.layers:
+                 encoder_layer: SmolVLMEncoderLayer
+                 _patch_layer_norm_module(encoder_layer.layer_norm1)
+                 _patch_layer_norm_module(encoder_layer.layer_norm2)
+
+
+ def apply_liger_kernel_to_falcon_h1(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = False,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Falcon-H1 models
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+         loaded. Default is None.
+     """
+
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.falcon_h1 import modeling_falcon_h1
+     from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1Model
+
+     if rope:
+         logger.info("Apply liger rotary pos emb.")
+         modeling_falcon_h1.apply_rotary_pos_emb = liger_rotary_pos_emb
+     if rms_norm:
+         logger.info("Apply liger RMSNorm")
+         modeling_falcon_h1.FalconH1RMSNorm = LigerRMSNorm
+     if swiglu:
+         logger.warning("LigerSwiGLUMLP is not available for Falcon-H1 models. There will be no effect.")
+
+     if cross_entropy:
+         logger.info("Apply liger cross entropy")
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+
+     if fused_linear_cross_entropy:
+         if model is not None:
+             model.forward = MethodType(falcon_h1_lce_forward, model)
+         else:
+             modeling_falcon_h1.FalconH1ForCausalLM.forward = falcon_h1_lce_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
+
+         # get the base model from the model instance
+         base_model: FalconH1Model = getattr(model, model.base_model_prefix, model)
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.final_layernorm)
+
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)
+
+
+ def apply_liger_kernel_to_qwen3_next(
+     rope: bool = False,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Qwen3Next models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+         loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.qwen3_next import modeling_qwen3_next
+     from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+     from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+     from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+     from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+     from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+     from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+     from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+     if rope:
+         # It might encounter nan issue
+         # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+         raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
+     if rms_norm:
+         modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+     if fused_linear_cross_entropy:
+         if model is not None:
+             if isinstance(model, Qwen3NextForCausalLM):
+                 model.forward = MethodType(qwen3_next_lce_forward, model)
+             else:
+                 raise TypeError(
+                     f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
+                 )
+         else:
+             modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
+     if swiglu:
+         # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+         modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+         if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+             base_model: Qwen3NextForCausalLM = getattr(model, model.base_model_prefix, model)
+         else:
+             raise TypeError(
+                 f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
+             )
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+
+         for decoder_layer in base_model.layers:
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+             # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+             if swiglu:
+                 if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+                     _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+                 if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+                     _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+                     experts = getattr(decoder_layer.mlp, "experts", None)
+                     if experts is not None:
+                         for expert in experts:
+                             _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "gemma": apply_liger_kernel_to_gemma,
@@ -1936,6 +2380,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "gemma3": apply_liger_kernel_to_gemma3,
      "glm4": apply_liger_kernel_to_glm4,
      "glm4v": apply_liger_kernel_to_glm4v,
+     "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+     "internvl": apply_liger_kernel_to_internvl,
      "llama": apply_liger_kernel_to_llama,
      "llama4_text": apply_liger_kernel_to_llama4,
      "llama4": apply_liger_kernel_to_llama4,
@@ -1953,9 +2399,12 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
      "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
      "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+     "qwen3_next": apply_liger_kernel_to_qwen3_next,
      "smollm3": apply_liger_kernel_to_smollm3,
      "phi3": apply_liger_kernel_to_phi3,
      "paligemma": apply_liger_kernel_to_paligemma,
+     "falcon_h1": apply_liger_kernel_to_falcon_h1,
+     "smolvlm": apply_liger_kernel_to_smolvlm,
  }

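The new `apply_liger_kernel_to_*` entry points above follow the same calling convention as the existing ones: call them before loading a model to patch the transformers classes, or pass an already-instantiated model via `model=`. A minimal sketch for Falcon-H1 is shown below, assuming the function is imported directly from `liger_kernel.transformers.monkey_patch` (0.6.3 also touches `liger_kernel/transformers/__init__.py`, so a shorter import path may exist) and using a placeholder checkpoint id.

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1

# Patch the Falcon-H1 classes before instantiation so RoPE, RMSNorm, and the
# fused linear cross entropy forward are swapped in for training.
apply_liger_kernel_to_falcon_h1(rope=True, rms_norm=True, fused_linear_cross_entropy=True)

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/Falcon-H1-0.5B-Base",  # placeholder checkpoint id, substitute your own
    torch_dtype=torch.bfloat16,
)
```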
liger_kernel/transformers/multi_token_attention.py

@@ -9,7 +9,7 @@ from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunct


  class LigerMultiTokenAttention(nn.Module):
-     """
+     r"""
      Multi-Token Attention:
          out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))

liger_kernel/transformers/poly_norm.py

@@ -0,0 +1,42 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.poly_norm import LigerPolyNormFunction
+
+
+ class LigerPolyNorm(nn.Module):
+     """
+     PolyNorm layer wrapper for Liger kernel.
+
+     PolyNorm formula:
+         y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+     where norm(u) = u / sqrt(mean(u²) + ε)
+
+     Reference:
+         https://github.com/BryceZhuo/PolyCom/
+
+     Args:
+         eps: epsilon for numerical stability (default: 1e-6)
+         in_place: whether to in-place modify grad_output in backward to save memory (default: False).
+             Set to True to save memory if grad_output is not needed elsewhere.
+     """
+
+     def __init__(self, eps=1e-6, in_place=True):
+         super().__init__()
+         # Align with PolyCom reference: initialize weights to (1/3, 1/3, 1/3) and bias to 1.0
+         self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
+         self.bias = nn.Parameter(torch.tensor(1.0))
+         self.variance_epsilon = eps
+         self.in_place = in_place
+
+     def forward(self, hidden_states):
+         return LigerPolyNormFunction.apply(
+             hidden_states,
+             self.weight,
+             self.bias,
+             self.variance_epsilon,
+             self.in_place,
+         )
+
+     def extra_repr(self):
+         return f"weight_shape={tuple(self.weight.shape)}, eps={self.variance_epsilon}, in_place={self.in_place}"
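For reference, a minimal usage sketch of the new module follows. It assumes a CUDA (or other Triton-capable) device, since `LigerPolyNormFunction` is a Triton kernel.

```python
import torch

from liger_kernel.transformers.poly_norm import LigerPolyNorm

norm = LigerPolyNorm(eps=1e-6).to("cuda")
x = torch.randn(2, 128, 4096, device="cuda", requires_grad=True)

# y = w0*norm(x^3) + w1*norm(x^2) + w2*norm(x) + b, per the docstring above
y = norm(x)
y.sum().backward()  # gradients flow to x, norm.weight, and norm.bias
```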
liger_kernel/transformers/rms_norm.py

@@ -77,3 +77,10 @@ class LigerRMSNormForGlm4(LigerRMSNorm):
          self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
      ):
          super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+ class LigerRMSNormForQwen3Next(LigerRMSNorm):
+     def __init__(
+         self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+     ):
+         super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
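The new subclass only pins constructor defaults for Qwen3-Next: `offset=1.0` with `init_fn="zeros"` corresponds to the Gemma-style `(1 + weight)` scaling, and `casting_mode="gemma"` keeps the normalization math in fp32. A hypothetical eager-mode reference for that parameterization (not code from the package) would look like:

```python
import torch

def qwen3_next_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # offset=1.0 with zero-initialized weights means the kernel effectively
    # scales the normalized input by (1 + weight), computed in fp32.
    x_fp32 = x.float()
    variance = x_fp32.pow(2).mean(-1, keepdim=True)
    normed = x_fp32 * torch.rsqrt(variance + eps)
    return ((1.0 + weight.float()) * normed).to(x.dtype)
```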
{liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: liger_kernel
- Version: 0.6.2
+ Version: 0.6.3
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
          Copyright 2024 LinkedIn Corporation
@@ -35,15 +35,14 @@ Requires-Dist: triton>=2.3.1
  Provides-Extra: dev
  Requires-Dist: transformers>=4.49.0; extra == "dev"
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
- Requires-Dist: flake8>=4.0.1.1; extra == "dev"
- Requires-Dist: black>=24.4.2; extra == "dev"
- Requires-Dist: isort>=5.13.2; extra == "dev"
+ Requires-Dist: ruff>=0.12.0; extra == "dev"
  Requires-Dist: pytest>=7.1.2; extra == "dev"
  Requires-Dist: pytest-xdist; extra == "dev"
+ Requires-Dist: pytest-cov; extra == "dev"
+ Requires-Dist: pytest-asyncio; extra == "dev"
  Requires-Dist: pytest-rerunfailures; extra == "dev"
  Requires-Dist: datasets>=2.19.2; extra == "dev"
  Requires-Dist: seaborn; extra == "dev"
- Requires-Dist: mkdocs; extra == "dev"
  Requires-Dist: mkdocs-material; extra == "dev"
  Requires-Dist: torchvision>=0.20; extra == "dev"
  Dynamic: license-file
@@ -181,8 +180,8 @@ y = orpo_loss(lm_head.weight, x, target)
  - `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)

  ```bash
- # Need to pass the url when installing
- pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
+ pip install -e .[dev]
+ pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3/
  ```

  ### Optional Dependencies
@@ -216,6 +215,9 @@ pip install -e .

  # Setup Development Dependencies
  pip install -e ".[dev]"
+
+ # NOTE -> For AMD users only
+ pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3/
  ```

@@ -312,6 +314,7 @@ loss.backward()
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |


  ## Low-level APIs
@@ -391,17 +394,17 @@ loss.backward()
  <td style="padding: 10px;">
      <div style="display: block;">
          <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
-             <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
+             <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?branch=main&event=push" alt="Build">
          </a>
      </div>
      <div style="display: block;">
          <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
-             <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
+             <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?branch=main&event=push" alt="Build">
          </a>
      </div>
      <div style="display: block;">
          <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
-             <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
+             <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?branch=main&event=push" alt="Build">
          </a>
      </div>
  </td>