liger-kernel-nightly 0.6.3.dev20251121010306__py3-none-any.whl → 0.6.3.dev20251121200119__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -42,6 +42,8 @@ if TYPE_CHECKING:
42
42
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
43
43
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
44
44
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
45
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
46
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
45
47
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
46
48
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
47
49
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
@@ -50,6 +52,7 @@ if TYPE_CHECKING:
50
52
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
51
53
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
52
54
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
55
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3 # noqa: F401
53
56
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
54
57
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
55
58
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
@@ -116,6 +119,7 @@ def __getattr__(name: str):
116
119
  "apply_liger_kernel_to_mixtral",
117
120
  "apply_liger_kernel_to_mllama",
118
121
  "apply_liger_kernel_to_olmo2",
122
+ "apply_liger_kernel_to_olmo3",
119
123
  "apply_liger_kernel_to_paligemma",
120
124
  "apply_liger_kernel_to_phi3",
121
125
  "apply_liger_kernel_to_qwen2",
@@ -128,6 +132,8 @@ def __getattr__(name: str):
128
132
  "apply_liger_kernel_to_qwen3_vl_moe",
129
133
  "apply_liger_kernel_to_smollm3",
130
134
  "apply_liger_kernel_to_smolvlm",
135
+ "apply_liger_kernel_to_hunyuan_v1_dense",
136
+ "apply_liger_kernel_to_hunyuan_v1_moe",
131
137
  }
132
138
 
133
139
  if name in monkey_patch_symbols:
@@ -190,6 +196,7 @@ if _TRANSFORMERS_AVAILABLE:
190
196
  "apply_liger_kernel_to_mixtral",
191
197
  "apply_liger_kernel_to_mllama",
192
198
  "apply_liger_kernel_to_olmo2",
199
+ "apply_liger_kernel_to_olmo3",
193
200
  "apply_liger_kernel_to_paligemma",
194
201
  "apply_liger_kernel_to_phi3",
195
202
  "apply_liger_kernel_to_qwen2",
@@ -202,5 +209,7 @@ if _TRANSFORMERS_AVAILABLE:
202
209
  "apply_liger_kernel_to_qwen3_vl_moe",
203
210
  "apply_liger_kernel_to_smollm3",
204
211
  "apply_liger_kernel_to_smolvlm",
212
+ "apply_liger_kernel_to_hunyuan_v1_dense",
213
+ "apply_liger_kernel_to_hunyuan_v1_moe",
205
214
  ]
206
215
  )
@@ -0,0 +1,134 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Union
4
+
5
+ import torch
6
+
7
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
8
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
9
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
10
+
11
+
12
def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> LigerCausalLMOutputWithPast:
    r"""
    Causal-LM forward for HunYuan v1 models with optional Liger fused linear cross entropy.

    This single forward is installed on both ``HunYuanDenseV1ForCausalLM`` and
    ``HunYuanMoEV1ForCausalLM`` by ``apply_liger_kernel_to_hunyuan_v1_dense`` and
    ``apply_liger_kernel_to_hunyuan_v1_moe``.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        skip_logits (`bool`, *optional*):
            If truthy, the loss is computed directly from the hidden states with `LigerForCausalLMLoss`
            and `logits` is returned as `None`. This requires `labels` or a `shift_labels` kwarg;
            otherwise a `ValueError` is raised. If `None`, it defaults to `True` only when the module
            is in training mode and `labels`/`shift_labels` is given.

        return_dict (`bool`, *optional*):
            Falls back to `self.config.use_return_dict`. When falsy, a tuple
            `(loss, logits, *outputs[1:], token_accuracy)` is returned. `loss` and `token_accuracy`
            appear only when they are not `None`.

        **kwargs:
            Forwarded to `self.model` and to the loss function. A `shift_labels` entry, if present,
            is popped *after* the base-model call and used as pre-shifted targets.

    Returns:
        `LigerCausalLMOutputWithPast` carrying `loss`, `logits` (may be `None` when logits were
        skipped), `past_key_values`, `hidden_states`, `attentions` and `token_accuracy`. A tuple is
        returned instead when `return_dict` is falsy.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, HunYuanDenseV1ForCausalLM

    >>> # Replace with a real HunYuan v1 dense checkpoint id or a local path.
    >>> checkpoint = "path/to/hunyuan-v1-dense-checkpoint"
    >>> model = HunYuanDenseV1ForCausalLM.from_pretrained(checkpoint)
    >>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Run the base decoder stack; outputs[0] is the final hidden state.
    # NOTE(review): `shift_labels` (if present in kwargs) is also forwarded here because it is
    # only popped below; kept as-is to match the other Liger forwards.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only project the sequence positions selected by `logits_to_keep` (int 0 keeps everything).
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    # Compute loss
    if skip_logits:
        # Fused path: loss straight from hidden states + lm_head weight, no logits tensor.
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)

    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = ((loss,) + output) if loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        return output

    # Return custom output class with accuracy field
    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
    )
@@ -0,0 +1,142 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from transformers.modeling_outputs import BaseModelOutputWithPast
9
+ from transformers.utils.deprecation import deprecate_kwarg
10
+
11
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
12
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
13
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
14
+
15
+
16
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
    r"""
    Olmo3 causal-LM forward that can compute the loss with Liger's fused linear cross entropy.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Target token ids for the language-modeling loss. Values must be in
            `[0, ..., config.vocab_size]` or `-100`; positions labelled `-100` are ignored.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            An `int` keeps the last `logits_to_keep` positions (`0` keeps all of them). A 1D
            tensor selects explicit indices along the sequence dimension, e.g. for packed inputs.

        skip_logits (`bool`, *optional*):
            When truthy, the loss is taken straight from the hidden states via
            `LigerForCausalLMLoss` and `logits` stays `None`. It requires `labels` or a
            `shift_labels` kwarg; otherwise a `ValueError` is raised. When `None`, it is enabled
            only in training mode with targets present.

        return_dict (`bool`, *optional*):
            Falls back to `self.config.use_return_dict`. When falsy, a tuple is returned whose
            optional `loss` (first) and `token_accuracy` (last) entries appear only if not `None`.

    Returns:
        `LigerCausalLMOutputWithPast`, or a plain tuple when `return_dict` is falsy.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Olmo3ForCausalLM

    >>> model = Olmo3ForCausalLM.from_pretrained("allenai/Olmo-3-7B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Instruct")

    >>> inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    ```
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if return_dict is None:
        return_dict = cfg.use_return_dict

    # Base decoder stack; the first element of its output is the last hidden state.
    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    # Restrict the projection to the positions requested by `logits_to_keep`.
    if isinstance(logits_to_keep, int):
        keep = slice(-logits_to_keep, None)
    else:
        keep = logits_to_keep
    kept_hidden_states = outputs[0][:, keep, :]

    shift_labels = kwargs.pop("shift_labels", None)
    has_targets = labels is not None or shift_labels is not None

    if skip_logits and not has_targets:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # Training with targets: avoid materializing logits by default.
        skip_logits = self.training and has_targets

    logits = loss = token_accuracy = None
    if skip_logits:
        fused_result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=cfg.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(fused_result)
    else:
        logits = self.lm_head(kept_hidden_states)
        if has_targets:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=cfg.vocab_size,
                **kwargs,
            )

    if return_dict:
        # Custom output class that adds the token_accuracy field.
        return LigerCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            token_accuracy=token_accuracy,
        )

    packed = (logits,) + outputs[1:]
    if loss is not None:
        packed = (loss,) + packed
    if token_accuracy is not None:
        packed = packed + (token_accuracy,)
    return packed
@@ -1928,6 +1928,74 @@ def apply_liger_kernel_to_olmo2(
1928
1928
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
1929
1929
 
1930
1930
 
1931
def apply_liger_kernel_to_olmo3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Patch HuggingFace Olmo3 modeling code (and optionally a loaded model) with Liger kernels.

    Args:
        rope (bool): Use Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Use Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Use Liger's fused linear cross entropy loss, which avoids materializing logits.
            Default is True. Mutually exclusive with `cross_entropy`.
        rms_norm (bool): Use Liger's RMSNorm (the Olmo2 variant). Default is True.
        swiglu (bool): Use Liger's SwiGLU in place of Olmo3MLP. Default is True.
        model (PreTrainedModel): An already-loaded model whose existing submodules should also be
            patched. Default is None, in which case only the modeling module is patched.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.olmo3 import modeling_olmo3 as hf_olmo3
    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model

    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2

    # Olmo3 mirrors Olmo2 closely, so the Olmo2 building blocks are reused unchanged.
    if rope:
        hf_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        hf_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2
    if swiglu:
        hf_olmo3.Olmo3MLP = LigerSwiGLUMLP
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is None:
            hf_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
        else:
            model.forward = MethodType(olmo3_lce_forward, model)

    if model is None:
        return

    # Modules already built on `model` were created before the swaps above; patch them directly.
    base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)

    if rms_norm:
        _patch_rms_norm_module(base_model.norm)

    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerSwiGLUMLP)
        if rms_norm:
            for norm_name in ("post_attention_layernorm", "post_feedforward_layernorm"):
                _patch_rms_norm_module(getattr(layer, norm_name), in_place=False)
+
1998
+
1931
1999
  def apply_liger_kernel_to_glm4(
1932
2000
  rope: bool = False,
1933
2001
  cross_entropy: bool = False,
@@ -2558,6 +2626,123 @@ def apply_liger_kernel_to_qwen3_next(
2558
2626
  _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
2559
2627
 
2560
2628
 
2629
def apply_liger_kernel_to_hunyuan_v1_dense(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU to HunYuanDenseV1MLP. Default is True.
        model (PreTrainedModel): An already-loaded model instance whose existing submodules should
            also be patched. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense as hf_dense
    from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model

    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP

    # Swap names in the HF modeling module; models constructed afterwards pick these up.
    if rope:
        hf_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        hf_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        if model is None:
            hf_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
        else:
            model.forward = MethodType(hunyuan_v1_lce_forward, model)
    if swiglu:
        hf_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP

    if model is None:
        return

    # An existing instance already holds the original submodules; patch them in place.
    base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
    if rms_norm:
        _patch_rms_norm_module(base_model.norm)
    for layer in base_model.layers:
        if swiglu:
            _patch_swiglu_module(layer.mlp, LigerHunyuanV1SwiGLUMLP)
        if rms_norm:
            for norm_name in ("input_layernorm", "post_attention_layernorm"):
                _patch_rms_norm_module(getattr(layer, norm_name))
2685
+
2686
+
2687
def apply_liger_kernel_to_hunyuan_v1_moe(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU to HunYuanMoEV1MLP modules. On an existing
            `model` this covers the routed experts and, when the MoE block has one, its
            `shared_mlp`. Default is True.
        model (PreTrainedModel): An already-loaded model instance whose existing submodules should
            also be patched. Default is None.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
    from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model

    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP

    if rope:
        modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb

    if rms_norm:
        modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
        else:
            modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward

    if swiglu:
        # Replaces every later construction of HunYuanMoEV1MLP (experts and shared MLP alike).
        modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)
        for decoder_layer in base_model.layers:
            if swiglu:
                moe_block = decoder_layer.mlp
                for mlp_expert in moe_block.experts:
                    _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
                # Keep existing instances consistent with the module-level swap above, which
                # also affects the shared expert. Assumes `shared_mlp` is a HunYuanMoEV1MLP
                # (hinted by LigerHunyuanV1SwiGLUMLP's `is_shared_mlp` arg) — TODO confirm.
                # getattr makes this a no-op for MoE blocks without a shared MLP.
                shared_mlp = getattr(moe_block, "shared_mlp", None)
                if shared_mlp is not None:
                    _patch_swiglu_module(shared_mlp, LigerHunyuanV1SwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2744
+
2745
+
2561
2746
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
2562
2747
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
2563
2748
  "gemma": apply_liger_kernel_to_gemma,
@@ -2578,6 +2763,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
2578
2763
  "mistral": apply_liger_kernel_to_mistral,
2579
2764
  "mixtral": apply_liger_kernel_to_mixtral,
2580
2765
  "olmo2": apply_liger_kernel_to_olmo2,
2766
+ "olmo3": apply_liger_kernel_to_olmo3,
2581
2767
  "qwen2": apply_liger_kernel_to_qwen2,
2582
2768
  "qwen3": apply_liger_kernel_to_qwen3,
2583
2769
  "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
@@ -2595,6 +2781,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
2595
2781
  "paligemma": apply_liger_kernel_to_paligemma,
2596
2782
  "falcon_h1": apply_liger_kernel_to_falcon_h1,
2597
2783
  "smolvlm": apply_liger_kernel_to_smolvlm,
2784
+ "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
2785
+ "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
2598
2786
  }
2599
2787
 
2600
2788
 
@@ -77,3 +77,20 @@ class LigerQwen3MoeSwiGLUMLP(nn.Module):
77
77
 
78
78
  def forward(self, x):
79
79
  return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
80
+
81
+
82
class LigerHunyuanV1SwiGLUMLP(nn.Module):
    """SwiGLU MLP for HunYuan v1 that fuses the SiLU-gate multiply with Liger's kernel.

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``. All three projections are
    bias-free linear layers.
    """

    def __init__(self, config, layer_idx=None, is_shared_mlp=False):
        # `is_shared_mlp` is accepted but not used here; presumably it exists so this class can be
        # constructed with the same arguments as the HF MLP it replaces — confirm against caller.
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.config = config
        self.hidden_size = hidden
        self.intermediate_size = inner
        # Registration order (gate, up, down) is kept so state_dict key order is unchanged.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.layer_idx = layer_idx
        # Only SiLU/Swish gating is implemented by the fused kernel.
        supported = ["silu", "swish"]
        if config.hidden_act not in supported:
            raise ValueError(f"Activation function {config.hidden_act} not supported.")

    def forward(self, x):
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        return self.down_proj(LigerSiLUMulFunction.apply(gate, up))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.3.dev20251121010306
3
+ Version: 0.6.3.dev20251121200119
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -310,8 +310,11 @@ loss.backward()
310
310
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
311
311
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
312
312
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
313
+ | OLMo3 | `liger_kernel.transformers.apply_liger_kernel_to_olmo3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
313
314
  | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
314
315
  | InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
316
+ | HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
317
+ | HunyuanV1 MoE | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
315
318
 
316
319
 
317
320
  ## Low-level APIs
@@ -43,7 +43,7 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
43
43
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
44
44
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
45
45
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
46
- liger_kernel/transformers/__init__.py,sha256=XX1ySRgZXeQe0or-6GNclAsNQG_VkABQlkwqpB1Wn8A,10090
46
+ liger_kernel/transformers/__init__.py,sha256=CgwhrY5cdx6OcRgR2ZZJbOIkLswQWPTr-BAaoxDNNOY,10687
47
47
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
48
48
  liger_kernel/transformers/cross_entropy.py,sha256=DMtHkKrVJDSsels7KgGQJqrXkEAd6Zopcdr-5oRmQgE,2010
49
49
  liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
@@ -60,7 +60,7 @@ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCc
60
60
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
61
61
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
62
62
  liger_kernel/transformers/llama4_rope.py,sha256=kS6PSHEwf3dS7hD7C7p8S0geugx2EMCiP0h0F7LsUoY,3639
63
- liger_kernel/transformers/monkey_patch.py,sha256=ZGnLygHuCiKGd6hT-C0pt1aY85f6GNFdV98oCDpxHHo,124742
63
+ liger_kernel/transformers/monkey_patch.py,sha256=4LV6LSz_AAop6HWk1spZm1QigPN9nUDPJu9tK21-jIo,132446
64
64
  liger_kernel/transformers/multi_token_attention.py,sha256=K3NIY9_5TPgZ4_Rahn0xnkMXxD_fmlJHK4CWGYvGQp0,1752
65
65
  liger_kernel/transformers/poly_norm.py,sha256=g5tC75i3qy1_N26ZUP-jfpct7ivQAEdJfIfx8IXzeyE,1377
66
66
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
@@ -68,7 +68,7 @@ liger_kernel/transformers/rms_norm.py,sha256=HwddVqrqS58jE-M2_4NkFGARtCDBhGnkKyj
68
68
  liger_kernel/transformers/rope.py,sha256=VMlDZI6zss9mLaLcN5XCE_ktmYRwAi_Eh4TIgO6NrIQ,2361
69
69
  liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
70
70
  liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
71
- liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
71
+ liger_kernel/transformers/swiglu.py,sha256=dRR69wDWSWfdjtnsTECyxQqWVo5QkdXdXm9SpSQ4Jvw,4291
72
72
  liger_kernel/transformers/tiled_mlp.py,sha256=J51-kpzwikDMMhT5bX-RZCKMaXBK6zZc1bhgRYTK5F0,4651
73
73
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
74
74
  liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
@@ -82,6 +82,7 @@ liger_kernel/transformers/model/gemma3.py,sha256=mEV3Kuy-dqfTk_b899Vb-InuD4_DvwH
82
82
  liger_kernel/transformers/model/glm4.py,sha256=bSp22iPIjsli4-c_usUOsyh1Bs2gIK8X6ynS0azseUs,5900
83
83
  liger_kernel/transformers/model/glm4v.py,sha256=dd-BQpccDCp1SbIxcJ5rG8xcwYQK3KOv1Tgm9TGnZc4,6594
84
84
  liger_kernel/transformers/model/glm4v_moe.py,sha256=zKhMdOOrRhlrvCSFaeVYfddL1ubpY8edEO91TN81n98,7135
85
+ liger_kernel/transformers/model/hunyuan_v1.py,sha256=MJvP9xkUFePIV0HLETJM4YPbVCEPkAE1ZI5Jxyiebh0,5731
85
86
  liger_kernel/transformers/model/internvl.py,sha256=OOutracs9qrPHSU7FVYar08yinvGrHQVPvo39JEws6w,6473
86
87
  liger_kernel/transformers/model/llama.py,sha256=kqZeONzwTBzudoChlKMzq1w23BtYGbxWZC1l1V__JTw,13410
87
88
  liger_kernel/transformers/model/llama4.py,sha256=PfkynGVI0xxMs3EtyYpCgaALI6stu25OIrTIymE-pvg,4853
@@ -91,6 +92,7 @@ liger_kernel/transformers/model/mistral.py,sha256=OcwOzVDMwwDbVccVPv-AaocznzWwzL
91
92
  liger_kernel/transformers/model/mixtral.py,sha256=YcBDoTEJDgLFJ_RTo180DYGxR8D5Ad9-idumif7kCPE,12130
92
93
  liger_kernel/transformers/model/mllama.py,sha256=vAHwCm63sn4kpAY0rDGf_N0HR7KRTBVpBYDVTPOaZTg,12079
93
94
  liger_kernel/transformers/model/olmo2.py,sha256=-h2bUOeuPfY1MdShdRvq5_wFDHKP4PEimgIl0fL-BT4,5902
95
+ liger_kernel/transformers/model/olmo3.py,sha256=k2zYOlS8U_b5MwjdToB3tDRQ0bH_mWapVQqJcH8-qAo,6007
94
96
  liger_kernel/transformers/model/output_classes.py,sha256=0BGXVR4dYQpSHLkSqpRoXuHMryrceGSlTYRu6pvd8ZY,4542
95
97
  liger_kernel/transformers/model/paligemma.py,sha256=r0smHLADkEwfLS6d6ArWoSWEeLt2d_8pmgOO5F04b1o,20793
96
98
  liger_kernel/transformers/model/phi3.py,sha256=PT7Kw6yySg-7TsssWfi82eVMN3SWujCqzCqHigAdfeQ,4574
@@ -108,9 +110,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
108
110
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
109
111
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
110
112
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
111
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
112
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/METADATA,sha256=HgCaZORIkj1lSLvj1vsjOJ0r9ouWZ-lqPCQ3JrJJMFU,24777
113
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
114
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
115
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
116
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD,,
113
+ liger_kernel_nightly-0.6.3.dev20251121200119.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
114
+ liger_kernel_nightly-0.6.3.dev20251121200119.dist-info/METADATA,sha256=dTCc8yabO75aXtlWdPFHw23yAhHuEr5K06YDaMH4OHU,25238
115
+ liger_kernel_nightly-0.6.3.dev20251121200119.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
116
+ liger_kernel_nightly-0.6.3.dev20251121200119.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
117
+ liger_kernel_nightly-0.6.3.dev20251121200119.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
118
+ liger_kernel_nightly-0.6.3.dev20251121200119.dist-info/RECORD,,