liger-kernel 0.6.3__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +18 -5
  6. liger_kernel/ops/cross_entropy.py +59 -9
  7. liger_kernel/ops/fused_linear_cross_entropy.py +30 -4
  8. liger_kernel/ops/grpo_loss.py +3 -1
  9. liger_kernel/ops/layer_norm.py +84 -65
  10. liger_kernel/ops/tiled_mlp.py +136 -0
  11. liger_kernel/transformers/__init__.py +19 -0
  12. liger_kernel/transformers/cross_entropy.py +8 -3
  13. liger_kernel/transformers/functional.py +24 -6
  14. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  15. liger_kernel/transformers/grpo_loss.py +56 -1
  16. liger_kernel/transformers/model/falcon_h1.py +19 -5
  17. liger_kernel/transformers/model/gemma.py +17 -6
  18. liger_kernel/transformers/model/gemma2.py +14 -5
  19. liger_kernel/transformers/model/gemma3.py +25 -12
  20. liger_kernel/transformers/model/glm4.py +16 -4
  21. liger_kernel/transformers/model/glm4v.py +16 -4
  22. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  23. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  24. liger_kernel/transformers/model/internvl.py +12 -5
  25. liger_kernel/transformers/model/llama.py +14 -5
  26. liger_kernel/transformers/model/llama4.py +16 -4
  27. liger_kernel/transformers/model/llava.py +12 -4
  28. liger_kernel/transformers/model/loss_utils.py +31 -3
  29. liger_kernel/transformers/model/mistral.py +15 -6
  30. liger_kernel/transformers/model/mixtral.py +16 -7
  31. liger_kernel/transformers/model/mllama.py +12 -4
  32. liger_kernel/transformers/model/olmo2.py +16 -4
  33. liger_kernel/transformers/model/olmo3.py +142 -0
  34. liger_kernel/transformers/model/output_classes.py +147 -0
  35. liger_kernel/transformers/model/paligemma.py +22 -5
  36. liger_kernel/transformers/model/phi3.py +14 -7
  37. liger_kernel/transformers/model/qwen2.py +16 -3
  38. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  39. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  40. liger_kernel/transformers/model/qwen3.py +20 -5
  41. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  42. liger_kernel/transformers/model/qwen3_next.py +17 -5
  43. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  44. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  45. liger_kernel/transformers/model/smollm3.py +15 -6
  46. liger_kernel/transformers/monkey_patch.py +398 -20
  47. liger_kernel/transformers/rope.py +43 -0
  48. liger_kernel/transformers/swiglu.py +17 -0
  49. liger_kernel/transformers/tiled_mlp.py +133 -0
  50. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/METADATA +4 -1
  51. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/RECORD +55 -48
  52. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/WHEEL +0 -0
  53. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/licenses/LICENSE +0 -0
  54. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/licenses/NOTICE +0 -0
  55. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/top_level.txt +0 -0
@@ -7,11 +7,12 @@ from typing import Union
 import torch
 
 from torch.distributed.fsdp import FullyShardedDataParallel
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 from liger_kernel.utils import PEFT_AVAILABLE
 
 if TYPE_CHECKING:
@@ -38,7 +39,7 @@ def lce_forward(
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -101,6 +102,8 @@ def lce_forward(
     shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
+    token_accuracy = None
+
     # if in training mode, don't materialize logits
     if skip_logits and labels is None and shift_labels is None:
         raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -109,8 +112,9 @@ def lce_forward(
         # By default, if in training mode, don't materialize logits
         skip_logits = self.training and (labels is not None or shift_labels is not None)
 
+    # Compute loss
     if skip_logits:
-        loss = lce_maybe_trainable_lm_head(
+        result = lce_maybe_trainable_lm_head(
            self,
            hidden_states=kept_hidden_states,
            hidden_size=self.config.hidden_size,
@@ -118,6 +122,7 @@ def lce_forward(
            shift_labels=shift_labels,
            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
     else:
         logits = self.lm_head(kept_hidden_states)
@@ -131,15 +136,19 @@ def lce_forward(
            )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return CausalLMOutputWithPast(
+    # Return custom output class with token_accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
 
 
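Note on the change above: the forward now returns the new LigerCausalLMOutputWithPast, which adds a token_accuracy field populated on the fused-loss path. A minimal, hypothetical sketch of reading it during training (the model/batch variables are placeholders, and token_accuracy is assumed to be an optional scalar tensor):

    # assumes a causal LM patched by liger-kernel 0.6.4 and a batch dict with input_ids/labels
    outputs = model(**batch)              # LigerCausalLMOutputWithPast
    loss = outputs.loss
    if getattr(outputs, "token_accuracy", None) is not None:
        print(f"token accuracy: {float(outputs.token_accuracy):.4f}")  # assuming a scalar tensor
    loss.backward()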
@@ -34,6 +34,8 @@ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_f
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast
+from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast_and_leading_batch
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -1643,6 +1645,158 @@ def apply_liger_kernel_to_qwen2_5_vl(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_qwen3_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl import modeling_qwen3_vl
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+
+    if rope:
+        modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+        modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
+    if rms_norm:
+        modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_lce_forward, model)
+        else:
+            modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
+            text_model: Qwen3VLTextModel = model.language_model
+        elif isinstance(model, Qwen3VLTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
+def apply_liger_kernel_to_qwen3_vl_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+    from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
+
+    if rope:
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
+        modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_with_cast_and_leading_batch
+
+    if rms_norm:
+        modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+        else:
+            modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+    if model is not None and rms_norm:
+        if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
+            text_model: Qwen3VLMoeTextModel = model.language_model
+        elif isinstance(model, Qwen3VLMoeTextModel):
+            text_model = model
+        else:
+            raise TypeError(
+                f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+            )
+
+        _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+        if text_model is not None:
+            _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+            for decoder_layer in text_model.layers:
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+                _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+                self_attn = getattr(decoder_layer, "self_attn", None)
+                if self_attn is not None:
+                    if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+                    if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+                        _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
 def apply_liger_kernel_to_phi3(
     rope: bool = True,
     cross_entropy: bool = False,
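For context, a hedged sketch of using the new Qwen3-VL entry point on an already-loaded model (the checkpoint id is a placeholder; the function itself is defined in liger_kernel/transformers/monkey_patch.py as shown above):

    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl

    model = Qwen3VLForConditionalGeneration.from_pretrained("<qwen3-vl-checkpoint>")  # placeholder id
    # swaps in Liger RoPE/RMSNorm and binds the fused-linear-cross-entropy forward to this instance
    apply_liger_kernel_to_qwen3_vl(model=model)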
@@ -1774,6 +1928,74 @@ def apply_liger_kernel_to_olmo2(
                 _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
 
 
+def apply_liger_kernel_to_olmo3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo3 import modeling_olmo3
+    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+    if rope:
+        modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2  # same as olmo2
+    if swiglu:
+        modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo3_lce_forward, model)
+        else:
+            modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
 def apply_liger_kernel_to_glm4(
     rope: bool = False,
     cross_entropy: bool = False,
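Similarly, a hedged sketch of the module-level path for Olmo3, patching the classes before the model is instantiated (the checkpoint id is a placeholder):

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3

    # patch Olmo3RMSNorm, Olmo3MLP, apply_rotary_pos_emb and the ForCausalLM forward at module level
    apply_liger_kernel_to_olmo3()

    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained("<olmo3-checkpoint>")  # placeholder id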
@@ -2038,6 +2260,7 @@ def apply_liger_kernel_to_internvl(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
+    layer_norm: bool = True,
     model: Optional[PreTrainedModel] = None,
     **kwargs,
 ) -> None:
@@ -2048,37 +2271,60 @@ def apply_liger_kernel_to_internvl(
     NOTE: InternVL is not available in transformers<4.52.1
 
     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
+    import torch.nn as torch_nn
 
     from transformers.models.internvl import modeling_internvl
+    from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+    from transformers.models.internvl.modeling_internvl import InternVLModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+    from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
 
+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
     from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+    if layer_norm and model is None:
+        modeling_internvl.nn.LayerNorm = LigerLayerNorm
 
     if cross_entropy:
-        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-        modeling_internvl.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
         modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
     if rms_norm:
         modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
 
     if model is not None:
-        text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
+            # NOTE: language_model and visual properties can be accessed through the conditional class.
+            text_model = model.language_model
+            vision_model: InternVLVisionModel = model.vision_tower
+        else:
+            raise TypeError(
+                f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration` or `InternVLModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
         text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
-        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
 
         kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
         if text_liger_fn:
@@ -2091,25 +2337,33 @@ def apply_liger_kernel_to_internvl(
                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                 )
-            text_kwargs["model"] = model.language_model
+            text_kwargs["model"] = text_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
 
-        if vision_liger_fn:
-            accept_params = inspect.signature(vision_liger_fn).parameters
-            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
-            vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+        # Patch vision model RMSNorm layers
+        if rms_norm:
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.q_norm)
+                if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.k_norm)
 
-            if remain_params:
-                logger.warning(
-                    f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
-                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
-                )
-            vision_kwargs["model"] = model.vision_tower
-            vision_liger_fn(**vision_kwargs)
-        elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
-            logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch layernorm
+            if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+                _patch_layer_norm_module(vision_model.layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_before)
+                if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_after)
 
 
 def apply_liger_kernel_to_smolvlm(
 
@@ -2372,6 +2626,123 @@ def apply_liger_kernel_to_qwen3_next(
                     _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
 
 
+def apply_liger_kernel_to_hunyuan_v1_dense(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+    from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_hunyuan_v1_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+    from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
@@ -2392,6 +2763,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "olmo2": apply_liger_kernel_to_olmo2,
+    "olmo3": apply_liger_kernel_to_olmo3,
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen3": apply_liger_kernel_to_qwen3,
     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
@@ -2400,11 +2772,17 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
     "qwen3_next": apply_liger_kernel_to_qwen3_next,
+    "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+    "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+    "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
     "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
     "falcon_h1": apply_liger_kernel_to_falcon_h1,
     "smolvlm": apply_liger_kernel_to_smolvlm,
+    "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+    "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
 }
 
 
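The registry additions above make the new model types dispatchable by config.model_type. A hedged sketch of dispatching through the registry directly (liger's higher-level auto wrappers are not shown in this diff):

    from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model.config.model_type)  # e.g. "olmo3", "qwen3_vl", "hunyuan_v1_moe"
    if apply_fn is not None:
        apply_fn(model=model)  # patch the already-instantiated modules in place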
@@ -1,3 +1,8 @@
+from typing import Optional
+from typing import Tuple
+
+import torch
+
 from liger_kernel.ops.rope import LigerRopeFunction
 
 
@@ -18,3 +23,41 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
 
     return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_rotary_pos_emb_with_cast(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    position_ids: Optional[torch.Tensor] = None,
+    unsqueeze_dim: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    q32 = q.to(torch.float32)
+    k32 = k.to(torch.float32)
+    cos32 = cos.to(torch.float32)
+    sin32 = sin.to(torch.float32)
+
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+    return q_out.to(orig_q_dtype), k_out.to(orig_k_dtype)
+
+
+def liger_rotary_pos_emb_with_cast_and_leading_batch(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    position_ids: Optional[torch.Tensor] = None,
+    unsqueeze_dim: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    q32 = q.to(torch.float32).unsqueeze(0)
+    k32 = k.to(torch.float32).unsqueeze(0)
+    cos32 = cos.to(torch.float32).unsqueeze(0)
+    sin32 = sin.to(torch.float32).unsqueeze(0)
+
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+    return q_out.to(orig_q_dtype).squeeze(0), k_out.to(orig_k_dtype).squeeze(0)
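The two new wrappers up-cast q/k/cos/sin to float32 before calling the Triton RoPE kernel and cast the outputs back to the original dtype; the leading-batch variant additionally adds and strips a batch dimension for the vision path. A hedged sketch of the dtype round trip (shapes are illustrative, not prescribed by this diff, and a CUDA device is assumed since the underlying kernel is Triton):

    import torch
    from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast

    q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16, device="cuda")   # (batch, heads, seq, head_dim), illustrative
    k = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16, device="cuda")
    cos = torch.randn(1, 16, 64, dtype=torch.bfloat16, device="cuda")    # (1, seq, head_dim), illustrative
    sin = torch.randn(1, 16, 64, dtype=torch.bfloat16, device="cuda")

    q_rot, k_rot = liger_rotary_pos_emb_with_cast(q, k, cos, sin)
    assert q_rot.dtype == torch.bfloat16 and k_rot.dtype == torch.bfloat16  # cast back to the input dtypes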
@@ -77,3 +77,20 @@ class LigerQwen3MoeSwiGLUMLP(nn.Module):
 
     def forward(self, x):
         return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
+
+
+class LigerHunyuanV1SwiGLUMLP(nn.Module):
+    def __init__(self, config, layer_idx=None, is_shared_mlp=False):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.layer_idx = layer_idx
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
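Finally, a hedged sketch of instantiating the new Hunyuan SwiGLU MLP on its own; the config object is a stand-in carrying only the fields the class reads, and a CUDA device is assumed because LigerSiLUMulFunction runs a Triton kernel:

    from types import SimpleNamespace

    import torch

    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP

    cfg = SimpleNamespace(hidden_size=1024, intermediate_size=2816, hidden_act="silu")  # stand-in config
    mlp = LigerHunyuanV1SwiGLUMLP(cfg).to("cuda")
    x = torch.randn(2, 16, 1024, device="cuda")
    y = mlp(x)  # same shape as the input: (2, 16, 1024)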