liger-kernel 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +1 -1
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/dyt.py +113 -179
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/sparsemax.py +167 -0
- liger_kernel/transformers/__init__.py +5 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +8 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +0 -8
- liger_kernel/transformers/model/gemma2.py +0 -6
- liger_kernel/transformers/model/gemma3.py +0 -8
- liger_kernel/transformers/model/glm4.py +0 -6
- liger_kernel/transformers/model/llama.py +56 -11
- liger_kernel/transformers/model/llava.py +0 -8
- liger_kernel/transformers/model/mistral.py +0 -6
- liger_kernel/transformers/model/mixtral.py +0 -8
- liger_kernel/transformers/model/mllama.py +0 -7
- liger_kernel/transformers/model/olmo2.py +0 -6
- liger_kernel/transformers/model/paligemma.py +0 -8
- liger_kernel/transformers/model/phi3.py +0 -8
- liger_kernel/transformers/model/qwen2.py +0 -8
- liger_kernel/transformers/model/qwen2_5_vl.py +0 -6
- liger_kernel/transformers/model/qwen2_vl.py +0 -6
- liger_kernel/transformers/model/qwen3.py +0 -6
- liger_kernel/transformers/model/qwen3_moe.py +128 -0
- liger_kernel/transformers/monkey_patch.py +122 -13
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +34 -20
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +39 -33
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
|
@@ -7,19 +7,13 @@ import torch
|
|
|
7
7
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
9
|
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
|
10
|
-
from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
|
|
11
|
-
from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
|
|
12
10
|
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
|
|
13
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
14
|
-
from transformers.utils import replace_return_docstrings
|
|
15
11
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
16
12
|
|
|
17
13
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
18
14
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
19
15
|
|
|
20
16
|
|
|
21
|
-
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
|
|
22
|
-
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
23
17
|
def lce_forward_deprecated(
|
|
24
18
|
self,
|
|
25
19
|
input_ids: torch.LongTensor = None,
|
|
@@ -146,8 +140,6 @@ def lce_forward_deprecated(
|
|
|
146
140
|
|
|
147
141
|
|
|
148
142
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
149
|
-
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
|
|
150
|
-
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
151
143
|
# Ignore copy
|
|
152
144
|
def lce_forward(
|
|
153
145
|
self,
|
|
@@ -8,17 +8,12 @@ import torch
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
9
|
from transformers.cache_utils import Cache
|
|
10
10
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
11
|
-
from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
|
|
12
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
|
-
from transformers.utils import replace_return_docstrings
|
|
14
11
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
15
12
|
|
|
16
13
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
17
14
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
15
|
|
|
19
16
|
|
|
20
|
-
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
|
21
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
|
|
22
17
|
def lce_forward_deprecated(
|
|
23
18
|
self,
|
|
24
19
|
input_ids: torch.LongTensor = None,
|
|
@@ -135,8 +130,6 @@ def lce_forward_deprecated(
|
|
|
135
130
|
|
|
136
131
|
|
|
137
132
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
138
|
-
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
|
139
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
|
|
140
133
|
def lce_forward(
|
|
141
134
|
self,
|
|
142
135
|
input_ids: torch.LongTensor = None,
|
|
@@ -6,18 +6,12 @@ from typing import Union
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
8
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
9
|
-
from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
|
|
10
|
-
from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
|
|
11
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
12
|
-
from transformers.utils import replace_return_docstrings
|
|
13
9
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
14
10
|
|
|
15
11
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
16
12
|
|
|
17
13
|
|
|
18
14
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
19
|
-
@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
|
|
20
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
21
15
|
def lce_forward(
|
|
22
16
|
self,
|
|
23
17
|
input_ids: torch.LongTensor = None,
|
|
@@ -7,13 +7,9 @@ import torch
|
|
|
7
7
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
9
|
from transformers.cache_utils import Cache
|
|
10
|
-
from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
|
|
11
|
-
from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
|
|
12
10
|
from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
|
|
13
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
14
11
|
from transformers.utils import is_torchdynamo_compiling
|
|
15
12
|
from transformers.utils import logging
|
|
16
|
-
from transformers.utils import replace_return_docstrings
|
|
17
13
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
18
14
|
|
|
19
15
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
@@ -21,8 +17,6 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinea
|
|
|
21
17
|
logger = logging.get_logger(__name__)
|
|
22
18
|
|
|
23
19
|
|
|
24
|
-
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
|
25
|
-
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
26
20
|
def lce_forward_deprecated(
|
|
27
21
|
self,
|
|
28
22
|
input_ids: torch.LongTensor = None,
|
|
@@ -206,8 +200,6 @@ def lce_forward_deprecated(
|
|
|
206
200
|
|
|
207
201
|
|
|
208
202
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
209
|
-
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
|
210
|
-
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
211
203
|
def lce_forward(
|
|
212
204
|
self,
|
|
213
205
|
input_ids: torch.LongTensor = None,
|
|
@@ -7,18 +7,12 @@ import torch
|
|
|
7
7
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
9
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
10
|
-
from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
|
|
11
|
-
from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
|
|
12
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
|
-
from transformers.utils import replace_return_docstrings
|
|
14
10
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
15
11
|
|
|
16
12
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
17
13
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
14
|
|
|
19
15
|
|
|
20
|
-
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
|
|
21
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
22
16
|
def lce_forward_deprecated(
|
|
23
17
|
self,
|
|
24
18
|
input_ids: torch.LongTensor = None,
|
|
@@ -128,8 +122,6 @@ def lce_forward_deprecated(
|
|
|
128
122
|
|
|
129
123
|
|
|
130
124
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
131
|
-
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
|
|
132
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
133
125
|
def lce_forward(
|
|
134
126
|
self,
|
|
135
127
|
input_ids: torch.LongTensor = None,
|
|
@@ -7,18 +7,12 @@ import torch
|
|
|
7
7
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
9
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
10
|
-
from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
|
|
11
|
-
from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
|
|
12
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
|
-
from transformers.utils import replace_return_docstrings
|
|
14
10
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
15
11
|
|
|
16
12
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
17
13
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
14
|
|
|
19
15
|
|
|
20
|
-
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
|
21
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
22
16
|
def lce_forward_deprecated(
|
|
23
17
|
self,
|
|
24
18
|
input_ids: torch.LongTensor = None,
|
|
@@ -127,8 +121,6 @@ def lce_forward_deprecated(
|
|
|
127
121
|
|
|
128
122
|
|
|
129
123
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
130
|
-
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
|
131
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
132
124
|
def lce_forward(
|
|
133
125
|
self,
|
|
134
126
|
input_ids: torch.LongTensor = None,
|
|
@@ -6,17 +6,11 @@ from typing import Union
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
|
-
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _CONFIG_FOR_DOC
|
|
10
|
-
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import QWEN2_5_VL_INPUTS_DOCSTRING
|
|
11
9
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
|
|
12
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
|
-
from transformers.utils import replace_return_docstrings
|
|
14
10
|
|
|
15
11
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
16
12
|
|
|
17
13
|
|
|
18
|
-
@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
|
|
19
|
-
@replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
20
14
|
def lce_forward(
|
|
21
15
|
self,
|
|
22
16
|
input_ids: torch.LongTensor = None,
|
|
@@ -8,17 +8,11 @@ import torch
|
|
|
8
8
|
from packaging import version
|
|
9
9
|
from torch.nn import CrossEntropyLoss
|
|
10
10
|
from transformers import __version__ as transformers_version
|
|
11
|
-
from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
|
|
12
|
-
from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
|
|
13
11
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
|
|
14
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
15
|
-
from transformers.utils import replace_return_docstrings
|
|
16
12
|
|
|
17
13
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
14
|
|
|
19
15
|
|
|
20
|
-
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
|
|
21
|
-
@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
22
16
|
def lce_forward(
|
|
23
17
|
self,
|
|
24
18
|
input_ids: torch.LongTensor = None,
|
|
@@ -5,16 +5,10 @@ from typing import Union
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
7
7
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
8
|
-
from transformers.models.qwen3.modeling_qwen3 import _CONFIG_FOR_DOC
|
|
9
|
-
from transformers.models.qwen3.modeling_qwen3 import QWEN3_INPUTS_DOCSTRING
|
|
10
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
11
|
-
from transformers.utils import replace_return_docstrings
|
|
12
8
|
|
|
13
9
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
14
10
|
|
|
15
11
|
|
|
16
|
-
@add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
|
|
17
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
18
12
|
def lce_forward(
|
|
19
13
|
self,
|
|
20
14
|
input_ids: Optional[torch.LongTensor] = None,
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
|
8
|
+
from transformers.modeling_outputs import MoeModelOutputWithPast
|
|
9
|
+
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
|
|
10
|
+
|
|
11
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **loss_kwargs,
) -> MoeCausalLMOutputWithPast:
    r"""
    Liger replacement for ``Qwen3MoeForCausalLM.forward``.

    In training mode with ``labels`` (or a ``shift_labels`` entry in ``loss_kwargs``) the loss is computed by
    ``LigerForCausalLMLoss`` straight from the hidden states and ``lm_head.weight``, so the logits tensor is never
    materialized and ``logits`` is returned as ``None``. Otherwise logits are computed with ``lm_head`` and, when
    ``labels`` is given, the loss comes from ``self.loss_function``.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:
        `MoeCausalLMOutputWithPast` with ``loss`` (LM loss, plus ``router_aux_loss_coef * aux_loss`` when router
        logits are requested), ``aux_loss``, ``logits`` (``None`` on the fused training path), and the base model's
        ``past_key_values``, ``hidden_states``, ``attentions`` and ``router_logits``.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM

    >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_router_logits = (
        output_router_logits if output_router_logits is not None else self.config.output_router_logits
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    # Run the decoder stack; the LM head and loss are handled below.
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
    )

    hidden_states = outputs.last_hidden_state
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
    # Note: slice(-0, None) == slice(0, None), so logits_to_keep=0 keeps every position.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    # Pop `shift_labels` so it is passed explicitly below and not a second time through **loss_kwargs.
    shift_labels = loss_kwargs.pop("shift_labels", None)
    logits = None
    loss = None

    # In training mode, do not materialize logits: fuse lm_head + cross entropy.
    if self.training and (labels is not None or shift_labels is not None):
        loss = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **loss_kwargs,
        )
    else:  # In inference mode, materialize logits.
        logits = self.lm_head(kept_hidden_states)
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        # Add the router load-balancing term whenever an LM loss was computed. Checking `labels`
        # here would skip it on the fused training path driven only by `shift_labels`.
        if loss is not None:
            # Out-of-place add; move aux_loss to the loss's device first.
            loss = loss + self.router_aux_loss_coef * aux_loss.to(loss.device)

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )
|
|
@@ -35,6 +35,13 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
|
|
|
35
35
|
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
|
|
36
36
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|
37
37
|
|
|
38
|
+
try:
|
|
39
|
+
import peft
|
|
40
|
+
|
|
41
|
+
PEFT_AVAILABLE = True
|
|
42
|
+
except ImportError:
|
|
43
|
+
PEFT_AVAILABLE = False
|
|
44
|
+
|
|
38
45
|
transformer_version = version.parse(transformers.__version__)
|
|
39
46
|
|
|
40
47
|
logger = logging.getLogger(__name__)
|
|
@@ -48,22 +55,68 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
|
|
|
48
55
|
|
|
49
56
|
|
|
50
57
|
def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
    """
    Patch an already-instantiated RMSNorm module in place so it runs ``LigerRMSNorm``'s forward.

    The module keeps its own parameters. Only ``forward``/``extra_repr`` are rebound to
    ``LigerRMSNorm``'s implementations, and the attributes they read are set on the module.

    If ``module`` is a PEFT ``ModulesToSaveWrapper``, the wrapper itself is left untouched and every
    norm it wraps is patched instead: ``original_module`` plus each entry of ``modules_to_save``.
    Iterating ``modules_to_save`` also covers adapters that are not named ``"default"``.

    Args:
        module: The norm module, or a PEFT ``ModulesToSaveWrapper`` around one, to patch.
        offset: Stored as ``offset`` on each patched module.
        eps: Fallback epsilon, used only when a module has neither a truthy ``variance_epsilon``
            nor a truthy ``eps`` attribute.
        casting_mode: Stored as ``casting_mode`` on each patched module.
        in_place: Stored as ``in_place`` on each patched module.
    """
    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
        targets = [module.original_module, *module.modules_to_save.values()]
    else:
        targets = [module]

    for target in targets:
        target.offset = offset
        target.casting_mode = casting_mode
        # Prefer an epsilon already stored on this norm (`variance_epsilon`, else `eps`).
        # It is read from the wrapped norm itself, not through the PEFT wrapper.
        target.variance_epsilon = getattr(target, "variance_epsilon", None) or getattr(target, "eps", None) or eps
        target.in_place = in_place
        _bind_method_to_module(target, "forward", LigerRMSNorm.forward)
        _bind_method_to_module(target, "extra_repr", LigerRMSNorm.extra_repr)
        # NOTE: this renames the target's class object, so it affects every instance of that class.
        target.__class__.__name__ = LigerRMSNorm.__name__
|
|
58
87
|
|
|
59
88
|
|
|
60
89
|
def _patch_layer_norm_module(module, eps=1e-6):
    """
    Patch an already-instantiated LayerNorm module in place so it runs ``LigerLayerNorm``'s forward.

    The module keeps its own parameters. Only ``forward``/``extra_repr`` are rebound to
    ``LigerLayerNorm``'s implementations, and ``variance_epsilon``/``hidden_size`` are set on the module.

    If ``module`` is a PEFT ``ModulesToSaveWrapper``, the wrapper itself is left untouched and every
    norm it wraps is patched instead: ``original_module`` plus each entry of ``modules_to_save``.
    The wrapper's own forward therefore still dispatches to the active copy.

    Previously, in the PEFT branch:
    - the wrapped norms were given ``LigerRMSNorm``'s forward/extra_repr;
    - ``modules_to_save.default`` never received ``hidden_size``;
    - the wrapper's forward was overwritten.

    Args:
        module: The LayerNorm module, or a PEFT ``ModulesToSaveWrapper`` around one, to patch.
        eps: Fallback epsilon, used only when a module has neither a truthy ``variance_epsilon``
            nor a truthy ``eps`` attribute.
    """
    if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
        targets = [module.original_module, *module.modules_to_save.values()]
    else:
        targets = [module]

    for target in targets:
        # Prefer an epsilon already stored on this norm (`variance_epsilon`, else `eps`).
        target.variance_epsilon = getattr(target, "variance_epsilon", None) or getattr(target, "eps", None) or eps
        # Fall back to `normalized_shape` when the norm does not define `hidden_size`.
        target.hidden_size = getattr(target, "hidden_size", None) or getattr(target, "normalized_shape", None)
        _bind_method_to_module(target, "forward", LigerLayerNorm.forward)
        _bind_method_to_module(target, "extra_repr", LigerLayerNorm.extra_repr)
        # NOTE: this renames the target's class object, so it affects every instance of that class.
        target.__class__.__name__ = LigerLayerNorm.__name__
|
|
67
120
|
|
|
68
121
|
|
|
69
122
|
def _patch_swiglu_module(module, liger_module):
|
|
@@ -1102,6 +1155,61 @@ def apply_liger_kernel_to_qwen3(
|
|
|
1102
1155
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1103
1156
|
|
|
1104
1157
|
|
|
1158
|
+
def apply_liger_kernel_to_qwen3_moe(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen3 MoE models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP to the expert/dense MLPs. Default is True.
        model (PreTrainedModel): An already-instantiated model to patch in place as well. Default is None,
            in which case only the module-level classes/functions are replaced (affecting models built afterwards).

    Raises:
        AssertionError: If both `cross_entropy` and `fused_linear_cross_entropy` are True.
    """
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    from transformers.models.qwen3_moe import modeling_qwen3_moe
    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeModel

    from liger_kernel.transformers.model.qwen3_moe import lce_forward as qwen3_lce_forward
    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

    if rope:
        modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb

    if rms_norm:
        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward

    if swiglu:
        modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules.

        # Get the base model from the model instance.
        base_model: Qwen3MoeModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)
        for decoder_layer in base_model.layers:
            if swiglu:
                mlp = decoder_layer.mlp
                # A sparse MoE block is not itself a SwiGLU MLP (it has no gate_proj/up_proj);
                # its SwiGLU MLPs live in `mlp.experts`. Otherwise `mlp` is assumed to be a
                # plain (dense) SwiGLU MLP and is patched directly.
                experts = getattr(mlp, "experts", None)
                if experts is not None:
                    for expert in experts:
                        _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
                else:
                    _patch_swiglu_module(mlp, LigerQwen3MoeSwiGLUMLP)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
1211
|
+
|
|
1212
|
+
|
|
1105
1213
|
def apply_liger_kernel_to_qwen2_vl(
|
|
1106
1214
|
rope: bool = True,
|
|
1107
1215
|
cross_entropy: bool = False,
|
|
@@ -1455,6 +1563,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
|
1455
1563
|
"olmo2": apply_liger_kernel_to_olmo2,
|
|
1456
1564
|
"qwen2": apply_liger_kernel_to_qwen2,
|
|
1457
1565
|
"qwen3": apply_liger_kernel_to_qwen3,
|
|
1566
|
+
"qwen3_moe": apply_liger_kernel_to_qwen3_moe,
|
|
1458
1567
|
"qwen2_vl": apply_liger_kernel_to_qwen2_vl,
|
|
1459
1568
|
"qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
|
|
1460
1569
|
"phi3": apply_liger_kernel_to_phi3,
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LigerSparsemax(nn.Module):
    """
    ``nn.Module`` wrapper that applies ``LigerSparsemaxFunction`` along a fixed dimension.

    Args:
        dim: Dimension passed to ``LigerSparsemaxFunction``; defaults to the last one (-1).
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``LigerSparsemaxFunction`` to ``x`` along ``self.dim``."""
        result = LigerSparsemaxFunction.apply(x, self.dim)
        return result

    def extra_repr(self) -> str:
        """Expose the configured dimension in the module's printed representation."""
        return "dim={}".format(self.dim)
|
|
@@ -56,3 +56,24 @@ class LigerPhi3SwiGLUMLP(nn.Module):
|
|
|
56
56
|
up_states = self.gate_up_proj(x)
|
|
57
57
|
gate, up_states = up_states.chunk(2, dim=-1)
|
|
58
58
|
return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class LigerQwen3MoeSwiGLUMLP(nn.Module):
    """
    Qwen3MoeMLP replacement whose gated activation runs through ``LigerSiLUMulFunction``.

    Reference implementation:
    https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L57

    Args:
        config: Model config; ``hidden_size``, ``intermediate_size`` and ``hidden_act`` are read from it.
        intermediate_size: Optional override for ``config.intermediate_size``.

    Raises:
        ValueError: If ``config.hidden_act`` is neither ``"silu"`` nor ``"swish"``.
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        self.hidden_size = hidden
        inner = config.intermediate_size if intermediate_size is None else intermediate_size
        self.intermediate_size = inner
        # Bias-free projections: two hidden->inner (gate/up) and one inner->hidden (down).
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        # Only SiLU-family activations match the fused kernel.
        if config.hidden_act not in ("silu", "swish"):
            raise ValueError(f"Activation function {config.hidden_act} not supported.")

    def forward(self, x):
        """Return ``down_proj(LigerSiLUMulFunction(gate_proj(x), up_proj(x)))``."""
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        return self.down_proj(LigerSiLUMulFunction.apply(gate, up))
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Any
|
|
2
|
-
from typing import Callable
|
|
3
1
|
from typing import Dict
|
|
4
2
|
from typing import List
|
|
5
3
|
from typing import Literal
|
|
@@ -13,57 +11,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel
|
|
|
13
11
|
from trl.trainer import ORPOTrainer
|
|
14
12
|
|
|
15
13
|
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class _FSDPForwardRedirection:
|
|
19
|
-
"""
|
|
20
|
-
Modified based on
|
|
21
|
-
https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
|
|
22
|
-
Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
|
|
23
|
-
post-forward can be properly executed around the method call.
|
|
24
|
-
This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
|
|
25
|
-
the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
|
|
26
|
-
GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
|
|
27
|
-
will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
|
|
28
|
-
the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
|
|
29
|
-
its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
|
|
30
|
-
the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
|
|
31
|
-
"""
|
|
32
|
-
|
|
33
|
-
def __call__(
|
|
34
|
-
self,
|
|
35
|
-
wrapper_module: FullyShardedDataParallel,
|
|
36
|
-
method: Callable,
|
|
37
|
-
*args: Any,
|
|
38
|
-
**kwargs: Any,
|
|
39
|
-
):
|
|
40
|
-
"""Reroutes a method call through the `wrapper_module`'s `forward` method.
|
|
41
|
-
Args:
|
|
42
|
-
wrapper_module: The module that has `original_module` wrapped.
|
|
43
|
-
original_module: The module that was wrapped inside `wrapper_module`.
|
|
44
|
-
method_name: The name of the method that should be called on the `original_module` after inputs get
|
|
45
|
-
redirected through the `wrapper_module`'s `forward` method.
|
|
46
|
-
*args: The positional arguments to the method `method_name`. They will get passed to a patched
|
|
47
|
-
`forward` method instead.
|
|
48
|
-
**kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
|
|
49
|
-
`forward` method instead.
|
|
50
|
-
"""
|
|
51
|
-
assert isinstance(wrapper_module, FullyShardedDataParallel)
|
|
52
|
-
original_module = wrapper_module._fsdp_wrapped_module
|
|
53
|
-
original_forward = original_module.forward
|
|
54
|
-
|
|
55
|
-
def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
|
|
56
|
-
# Unpatch ourselves immediately before calling the method `method_name`
|
|
57
|
-
# because itself may want to call the real `forward`
|
|
58
|
-
original_module.forward = original_forward # type: ignore[method-assign]
|
|
59
|
-
# Call the actual method e.g. `.training_step(...)`
|
|
60
|
-
out = method(*_args, **_kwargs)
|
|
61
|
-
return out
|
|
62
|
-
|
|
63
|
-
# Patch the original_module's forward so we can redirect the arguments back to the real method
|
|
64
|
-
original_module.forward = wrapped_forward # type: ignore[method-assign]
|
|
65
|
-
wrapper_output = wrapper_module(*args, **kwargs)
|
|
66
|
-
return wrapper_output
|
|
14
|
+
from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
|
|
67
15
|
|
|
68
16
|
|
|
69
17
|
class LigerORPOTrainer(ORPOTrainer):
|
liger_kernel/utils.py
CHANGED
|
@@ -1,6 +1,17 @@
|
|
|
1
|
+
try:
|
|
2
|
+
import peft # noqa: F401
|
|
3
|
+
|
|
4
|
+
PEFT_AVAILABLE = True
|
|
5
|
+
except ImportError:
|
|
6
|
+
PEFT_AVAILABLE = False
|
|
7
|
+
|
|
1
8
|
import torch
|
|
2
9
|
|
|
3
10
|
|
|
11
|
+
def is_peft_available():
    """Report whether the optional ``peft`` package imported successfully when this module loaded."""
    return bool(PEFT_AVAILABLE)
|
|
13
|
+
|
|
14
|
+
|
|
4
15
|
def infer_device():
|
|
5
16
|
"""
|
|
6
17
|
Get current device name based on available devices
|