liger-kernel 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  2. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  3. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  4. liger_kernel/ops/dyt.py +113 -179
  5. liger_kernel/ops/grpo_loss.py +310 -0
  6. liger_kernel/ops/sparsemax.py +167 -0
  7. liger_kernel/transformers/__init__.py +5 -0
  8. liger_kernel/transformers/dyt.py +5 -3
  9. liger_kernel/transformers/fsdp.py +55 -0
  10. liger_kernel/transformers/functional.py +8 -0
  11. liger_kernel/transformers/grpo_loss.py +98 -0
  12. liger_kernel/transformers/model/gemma.py +0 -8
  13. liger_kernel/transformers/model/gemma2.py +0 -6
  14. liger_kernel/transformers/model/gemma3.py +0 -8
  15. liger_kernel/transformers/model/glm4.py +0 -6
  16. liger_kernel/transformers/model/llama.py +56 -11
  17. liger_kernel/transformers/model/llava.py +0 -8
  18. liger_kernel/transformers/model/mistral.py +0 -6
  19. liger_kernel/transformers/model/mixtral.py +0 -8
  20. liger_kernel/transformers/model/mllama.py +0 -7
  21. liger_kernel/transformers/model/olmo2.py +0 -6
  22. liger_kernel/transformers/model/paligemma.py +0 -8
  23. liger_kernel/transformers/model/phi3.py +0 -8
  24. liger_kernel/transformers/model/qwen2.py +0 -8
  25. liger_kernel/transformers/model/qwen2_5_vl.py +0 -6
  26. liger_kernel/transformers/model/qwen2_vl.py +0 -6
  27. liger_kernel/transformers/model/qwen3.py +0 -6
  28. liger_kernel/transformers/model/qwen3_moe.py +128 -0
  29. liger_kernel/transformers/monkey_patch.py +122 -13
  30. liger_kernel/transformers/sparsemax.py +16 -0
  31. liger_kernel/transformers/swiglu.py +21 -0
  32. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  33. liger_kernel/utils.py +11 -0
  34. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +34 -20
  35. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +39 -33
  36. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
  37. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
  38. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
  39. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/mixtral.py CHANGED
@@ -7,19 +7,13 @@ import torch

  from torch.nn import CrossEntropyLoss
  from transformers.modeling_outputs import MoeCausalLMOutputWithPast
- from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
- from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
  from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -146,8 +140,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  # Ignore copy
  def lce_forward(
      self,
liger_kernel/transformers/model/mllama.py CHANGED
@@ -8,17 +8,12 @@ import torch
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -135,8 +130,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/olmo2.py CHANGED
@@ -6,18 +6,12 @@ from typing import Union
  import torch

  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
- from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/paligemma.py CHANGED
@@ -7,13 +7,9 @@ import torch

  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
- from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
- from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
  from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
- from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import is_torchdynamo_compiling
  from transformers.utils import logging
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -21,8 +17,6 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinea
  logger = logging.get_logger(__name__)


- @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -206,8 +200,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/phi3.py CHANGED
@@ -7,18 +7,12 @@ import torch

  from torch.nn import CrossEntropyLoss
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
- from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -128,8 +122,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/qwen2.py CHANGED
@@ -7,18 +7,12 @@ import torch

  from torch.nn import CrossEntropyLoss
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
- from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -127,8 +121,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/qwen2_5_vl.py CHANGED
@@ -6,17 +6,11 @@ from typing import Union
  import torch

  from torch.nn import CrossEntropyLoss
- from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _CONFIG_FOR_DOC
- from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import QWEN2_5_VL_INPUTS_DOCSTRING
  from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/qwen2_vl.py CHANGED
@@ -8,17 +8,11 @@ import torch
  from packaging import version
  from torch.nn import CrossEntropyLoss
  from transformers import __version__ as transformers_version
- from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
- from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
  from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/qwen3.py CHANGED
@@ -5,16 +5,10 @@ from typing import Union
  import torch

  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.qwen3.modeling_qwen3 import _CONFIG_FOR_DOC
- from transformers.models.qwen3.modeling_qwen3 import QWEN3_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: Optional[torch.LongTensor] = None,
liger_kernel/transformers/model/qwen3_moe.py ADDED
@@ -0,0 +1,128 @@
+ from typing import List
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+ from transformers.modeling_outputs import MoeModelOutputWithPast
+ from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ def lce_forward(
+     self,
+     input_ids: Optional[torch.LongTensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     output_router_logits: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     **loss_kwargs,
+ ) -> MoeCausalLMOutputWithPast:
+     r"""
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     logits_to_keep (`int` or `torch.Tensor`, *optional*):
+         If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+         `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+         token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+         If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+         This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
+
+     >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+     >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_router_logits = (
+         output_router_logits if output_router_logits is not None else self.config.output_router_logits
+     )
+
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs: MoeModelOutputWithPast = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         output_router_logits=output_router_logits,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs.last_hidden_state
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
+
+     shift_labels = loss_kwargs.pop("shift_labels", None)
+     logits = None
+     loss = None
+
+     # if in training mode, do not materialize logits
+     if self.training and (labels is not None or shift_labels is not None):
+         loss = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             shift_labels=shift_labels,
+             hidden_size=self.config.hidden_size,
+             **loss_kwargs,
+         )
+     else:  # if in inference model materialize logits
+         logits = self.lm_head(kept_hidden_states)
+         if labels is not None:
+             loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+     aux_loss = None
+     if output_router_logits:
+         aux_loss = load_balancing_loss_func(
+             outputs.router_logits,
+             self.num_experts,
+             self.num_experts_per_tok,
+             attention_mask,
+         )
+         if labels is not None:
+             loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+     return MoeCausalLMOutputWithPast(
+         loss=loss,
+         aux_loss=aux_loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         router_logits=outputs.router_logits,
+     )
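The training branch above hands `kept_hidden_states` and `lm_head.weight` straight to `LigerForCausalLMLoss`, so the full `[batch, seq, vocab]` logits tensor is never materialized. A rough back-of-envelope sketch of what that skips; the vocabulary size is an assumed Qwen-scale placeholder, not read from this diff:

```python
# Rough arithmetic only: size of the logits tensor the fused-loss path avoids allocating.
batch, seq_len, vocab = 4, 4096, 151_936   # vocab is illustrative, not taken from the package
logits_bytes = batch * seq_len * vocab * 4  # fp32 logits
print(f"~{logits_bytes / 2**30:.1f} GiB")   # ≈ 9.3 GiB, before gradient and softmax buffers
```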
liger_kernel/transformers/monkey_patch.py CHANGED
@@ -35,6 +35,13 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

+ try:
+     import peft
+
+     PEFT_AVAILABLE = True
+ except ImportError:
+     PEFT_AVAILABLE = False
+
  transformer_version = version.parse(transformers.__version__)

  logger = logging.getLogger(__name__)
@@ -48,22 +55,68 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):


  def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
-     module.offset = offset
-     module.casting_mode = casting_mode
-     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-     module.in_place = in_place
-     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
-     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-     module.__class__.__name__ = LigerRMSNorm.__name__
+     # Check if the module is a PEFT ModulesToSaveWrapper
+     # If it is, we need to patch the modules_to_save.default and original_modules
+     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+         module.modules_to_save.default.offset = offset
+         module.modules_to_save.default.casting_mode = casting_mode
+         module.modules_to_save.default.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.modules_to_save.default.in_place = in_place
+         module.original_module.offset = offset
+         module.original_module.casting_mode = casting_mode
+         module.original_module.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.in_place = in_place
+         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+         module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
+         module.original_module.__class__.__name__ = LigerRMSNorm.__name__
+     else:
+         module.offset = offset
+         module.casting_mode = casting_mode
+         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         module.in_place = in_place
+         _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+         module.__class__.__name__ = LigerRMSNorm.__name__


  def _patch_layer_norm_module(module, eps=1e-6):
-     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-     module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
-
-     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
-     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-     module.__class__.__name__ = LigerLayerNorm.__name__
+     # Check if the module is a PEFT ModulesToSaveWrapper
+     # If it is, we need to patch the modules_to_save.default and original_modules
+     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
+         module.hidden_size = module.normalized_shape
+         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+         module.modules_to_save.default.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+             module, "normalized_shape", None
+         )
+         module.original_module.variance_epsilon = (
+             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         )
+         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
+             module, "normalized_shape", None
+         )
+         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
+         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
+         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
+         module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
+         module.original_module.__class__.__name__ = LigerLayerNorm.__name__
+     else:
+         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+         module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+         module.__class__.__name__ = LigerLayerNorm.__name__


  def _patch_swiglu_module(module, liger_module):
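For context, the `ModulesToSaveWrapper` branch above matters when a PEFT config lists norm layers in `modules_to_save`: the wrapper keeps a trainable copy under `modules_to_save.default` plus the frozen `original_module`, and both need the Liger forward. An illustrative sketch only; the model id and attribute chain are placeholders assuming a Llama-style layout, not taken from this diff:

```python
# Illustrative sketch: produce a ModulesToSaveWrapper-wrapped norm and patch it.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import _patch_rms_norm_module

base = AutoModelForCausalLM.from_pretrained("some-org/llama-style-model")  # placeholder checkpoint
peft_model = get_peft_model(base, LoraConfig(target_modules=["q_proj", "v_proj"], modules_to_save=["norm"]))

wrapped_norm = peft_model.base_model.model.model.norm  # a ModulesToSaveWrapper, not a bare RMSNorm
_patch_rms_norm_module(wrapped_norm)  # patches both modules_to_save.default and original_module
```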
@@ -1102,6 +1155,61 @@ def apply_liger_kernel_to_qwen3(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_qwen3_moe(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.qwen3_moe import modeling_qwen3_moe
+     from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel
+
+     from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
+     from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+     if rope:
+         modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+     if rms_norm:
+         modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
+
+     if cross_entropy:
+         from transformers.loss.loss_utils import nn
+
+         nn.functional.cross_entropy = liger_cross_entropy
+
+     if fused_linear_cross_entropy:
+         modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+
+     if swiglu:
+         modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         # get the base model from the model instance
+         base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_qwen2_vl(
      rope: bool = True,
      cross_entropy: bool = False,
@@ -1455,6 +1563,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "olmo2": apply_liger_kernel_to_olmo2,
      "qwen2": apply_liger_kernel_to_qwen2,
      "qwen3": apply_liger_kernel_to_qwen3,
+     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
      "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
      "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
      "phi3": apply_liger_kernel_to_phi3,
liger_kernel/transformers/sparsemax.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
+
+
+ class LigerSparsemax(nn.Module):
+     def __init__(self, dim: int = -1):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return LigerSparsemaxFunction.apply(x, self.dim)
+
+     def extra_repr(self) -> str:
+         return f"dim={self.dim}"
liger_kernel/transformers/swiglu.py CHANGED
@@ -56,3 +56,24 @@ class LigerPhi3SwiGLUMLP(nn.Module):
          up_states = self.gate_up_proj(x)
          gate, up_states = up_states.chunk(2, dim=-1)
          return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
+
+
+ class LigerQwen3MoeSwiGLUMLP(nn.Module):
+     """
+     Patch Qwen3MoeMLP to use LigerSiLUMulFunction.
+     https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L57
+     """
+
+     def __init__(self, config, intermediate_size=None):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         if config.hidden_act not in ["silu", "swish"]:
+             raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+     def forward(self, x):
+         return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
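`LigerQwen3MoeSwiGLUMLP` mirrors `Qwen3MoeMLP`'s three projections but routes the SiLU-gate multiply through `LigerSiLUMulFunction`. A standalone sketch with a stand-in config object (real usage goes through `apply_liger_kernel_to_qwen3_moe`, which swaps it in for `Qwen3MoeMLP`):

```python
# Sketch only: exercising the new MLP with a minimal stand-in config on a CUDA device,
# since LigerSiLUMulFunction is a Triton-backed autograd function.
from types import SimpleNamespace

import torch

from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

cfg = SimpleNamespace(hidden_size=64, intermediate_size=128, hidden_act="silu")
mlp = LigerQwen3MoeSwiGLUMLP(cfg, intermediate_size=128).to("cuda")
out = mlp(torch.randn(2, 16, 64, device="cuda"))  # down_proj(SiLU(gate_proj(x)) * up_proj(x))
```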
liger_kernel/transformers/trainer/orpo_trainer.py CHANGED
@@ -1,5 +1,3 @@
- from typing import Any
- from typing import Callable
  from typing import Dict
  from typing import List
  from typing import Literal
@@ -13,57 +11,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel
  from trl.trainer import ORPOTrainer

  from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
-
-
- class _FSDPForwardRedirection:
-     """
-     Modified based on
-     https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
-     Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
-     post-forward can be properly executed around the method call.
-     This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
-     the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
-     GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
-     will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
-     the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
-     its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
-     the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
-     """
-
-     def __call__(
-         self,
-         wrapper_module: FullyShardedDataParallel,
-         method: Callable,
-         *args: Any,
-         **kwargs: Any,
-     ):
-         """Reroutes a method call through the `wrapper_module`'s `forward` method.
-         Args:
-             wrapper_module: The module that has `original_module` wrapped.
-             original_module: The module that was wrapped inside `wrapper_module`.
-             method_name: The name of the method that should be called on the `original_module` after inputs get
-                 redirected through the `wrapper_module`'s `forward` method.
-             *args: The positional arguments to the method `method_name`. They will get passed to a patched
-                 `forward` method instead.
-             **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
-                 `forward` method instead.
-         """
-         assert isinstance(wrapper_module, FullyShardedDataParallel)
-         original_module = wrapper_module._fsdp_wrapped_module
-         original_forward = original_module.forward
-
-         def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
-             # Unpatch ourselves immediately before calling the method `method_name`
-             # because itself may want to call the real `forward`
-             original_module.forward = original_forward  # type: ignore[method-assign]
-             # Call the actual method e.g. `.training_step(...)`
-             out = method(*_args, **_kwargs)
-             return out
-
-         # Patch the original_module's forward so we can redirect the arguments back to the real method
-         original_module.forward = wrapped_forward  # type: ignore[method-assign]
-         wrapper_output = wrapper_module(*args, **kwargs)
-         return wrapper_output
+ from liger_kernel.transformers.fsdp import _FSDPForwardRedirection


  class LigerORPOTrainer(ORPOTrainer):
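The redirection helper itself now lives in the new `liger_kernel/transformers/fsdp.py` (file 9 in the list above) and is only imported here. A hedged sketch of the pattern it enables, with the wrapped model and inputs as illustrative names:

```python
# Sketch of the redirection pattern, assuming `fsdp_model` wraps a *ForCausalLM whose backbone
# sits at `.model`; going through the FSDP root forward keeps the flat params all-gathered.
from liger_kernel.transformers.fsdp import _FSDPForwardRedirection


def backbone_outputs(fsdp_model, input_ids, attention_mask):
    unwrapped = fsdp_model._fsdp_wrapped_module  # e.g. a LlamaForCausalLM
    return _FSDPForwardRedirection()(
        fsdp_model,                # FSDP root whose pre-/post-forward hooks must run
        unwrapped.model.forward,   # run only the backbone, skipping lm_head and the CE loss
        input_ids=input_ids,
        attention_mask=attention_mask,
    )
```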
liger_kernel/utils.py CHANGED
@@ -1,6 +1,17 @@
+ try:
+     import peft  # noqa: F401
+
+     PEFT_AVAILABLE = True
+ except ImportError:
+     PEFT_AVAILABLE = False
+
  import torch


+ def is_peft_available():
+     return PEFT_AVAILABLE
+
+
  def infer_device():
      """
      Get current device name based on available devices
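A small sketch of how the new helper can gate optional-PEFT code paths; the `unwrap_norm` function is hypothetical, shown only to mirror the `ModulesToSaveWrapper` handling added in `monkey_patch.py`:

```python
# Hedged sketch: guard PEFT-only branches behind the new helper; `unwrap_norm` is illustrative.
from liger_kernel.utils import is_peft_available

if is_peft_available():
    from peft.utils.other import ModulesToSaveWrapper

    def unwrap_norm(module):
        # PEFT keeps the trainable copy under modules_to_save.default
        return module.modules_to_save.default if isinstance(module, ModulesToSaveWrapper) else module
else:
    def unwrap_norm(module):
        return module
```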