liger_kernel-0.5.8-py3-none-any.whl → liger_kernel-0.5.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +8 -1
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/cross_entropy.py +4 -1
- liger_kernel/ops/dyt.py +113 -179
- liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/sparsemax.py +167 -0
- liger_kernel/transformers/__init__.py +11 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +8 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +8 -12
- liger_kernel/transformers/model/gemma2.py +8 -10
- liger_kernel/transformers/model/gemma3.py +3 -9
- liger_kernel/transformers/model/glm4.py +119 -0
- liger_kernel/transformers/model/llama.py +64 -15
- liger_kernel/transformers/model/llava.py +0 -8
- liger_kernel/transformers/model/mistral.py +8 -10
- liger_kernel/transformers/model/mixtral.py +8 -12
- liger_kernel/transformers/model/mllama.py +8 -11
- liger_kernel/transformers/model/olmo2.py +8 -10
- liger_kernel/transformers/model/paligemma.py +0 -8
- liger_kernel/transformers/model/phi3.py +8 -12
- liger_kernel/transformers/model/qwen2.py +8 -12
- liger_kernel/transformers/model/qwen2_5_vl.py +3 -7
- liger_kernel/transformers/model/qwen2_vl.py +3 -7
- liger_kernel/transformers/model/qwen3.py +112 -0
- liger_kernel/transformers/model/qwen3_moe.py +128 -0
- liger_kernel/transformers/monkey_patch.py +243 -13
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +36 -20
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +42 -34
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
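The new modules in the file list above are importable directly once the 0.5.10 wheel is installed. A quick smoke test is sketched below; the module names are taken from the paths above, and it assumes an environment where liger-kernel (and its torch/triton dependencies) already imports cleanly.

```python
import importlib

# New modules in 0.5.10 according to the file list above
for name in (
    "liger_kernel.ops.grpo_loss",
    "liger_kernel.ops.sparsemax",
    "liger_kernel.transformers.fsdp",
    "liger_kernel.transformers.grpo_loss",
    "liger_kernel.transformers.sparsemax",
    "liger_kernel.transformers.model.glm4",
    "liger_kernel.transformers.model.qwen3",
    "liger_kernel.transformers.model.qwen3_moe",
):
    importlib.import_module(name)
    print("ok", name)
```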
liger_kernel/transformers/grpo_loss.py

```diff
@@ -0,0 +1,98 @@
+from liger_kernel.ops.grpo_loss import GrpoLossFunction
+
+
+def triton_grpo_loss(
+    logits,
+    old_logp,
+    ref_logp,
+    completion_ids,
+    advantages,
+    completion_mask=None,
+    temperature=0.9,
+    beta=0.04,
+    eps_low=0.2,
+    eps_high=0.4,
+    inplace=True,
+):
+    assert logits is not None and completion_ids is not None and advantages is not None, (
+        "must provide logits、completion_ids and advantages"
+    )
+
+    return GrpoLossFunction.apply(
+        logits,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace,
+    )
+
+
+# This is a demo how to use grpo_loss in GRPOTrainer. The Trl version must be 0.16
+"""
+import torch
+import trl
+assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
+from trl.extras.profiling import profiling_decorator
+
+@profiling_decorator
+def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+    # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+    return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)
+
+@profiling_decorator
+def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    if return_outputs:
+        raise ValueError("The GRPOTrainer does not support returning outputs")
+    # Compute the per-token log probabilities for the model
+
+    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+    logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+
+    ref_per_token_logps = inputs["ref_per_token_logps"]
+    advantages = inputs["advantages"]
+    old_per_token_logps = inputs["old_per_token_logps"]
+
+
+    per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(logits,
+                                                                old_per_token_logps,
+                                                                ref_per_token_logps,
+                                                                completion_ids,
+                                                                advantages,
+                                                                completion_mask,
+                                                                self.temperature,
+                                                                self.beta,
+                                                                self.epsilon_low,
+                                                                self.epsilon_high,)
+    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
+
+    # Log the metrics
+    mode = "eval" if self.control.should_evaluate else "train"
+
+    if self.beta != 0.0:
+        mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+        self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+    clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+    self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+    return loss
+
+trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
+trl.GRPOTrainer.compute_loss = compute_loss
+trigger = None
+"""
+
+# add this line at the first line of grpo.py in open-r1
+"""
+from liger_kernel.transformers.grpo_loss import trigger
+"""
```
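For a quick sanity check of the new wrapper, something like the following should work. The shapes (one extra logit position, per-token outputs of shape `(batch, completion_len)`) and the tolerance for `old_logp=None` are assumptions read off the commented GRPOTrainer demo above, not a documented contract of the kernel.

```python
import torch

from liger_kernel.transformers.grpo_loss import triton_grpo_loss

B, T, V = 2, 16, 1024  # toy batch size, completion length, vocab size
device = "cuda"        # the Triton kernel presumably requires a GPU

logits = torch.randn(B, T + 1, V, device=device)           # one extra position, as in the demo
completion_ids = torch.randint(0, V, (B, T), device=device)
advantages = torch.randn(B, device=device)
ref_logp = torch.randn(B, T, device=device)
completion_mask = torch.ones(B, T, dtype=torch.int64, device=device)

per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
    logits,
    None,            # old_logp: the demo forwards inputs["old_per_token_logps"], which TRL may leave as None
    ref_logp,
    completion_ids,
    advantages,
    completion_mask=completion_mask,
    temperature=1.0,
    beta=0.04,
)
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
```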
liger_kernel/transformers/model/gemma.py

```diff
@@ -8,18 +8,12 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
-from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -129,8 +123,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -200,21 +192,25 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
     else:  # if in inference mode materialize logits
-
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
```
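The recurring `kept_hidden_states` refactor above slices the hidden states once and reuses the slice in both the training and inference branches. A small standalone sketch of what the `slice_indices` expression selects (toy shapes, nothing Gemma-specific):

```python
import torch

hidden_states = torch.randn(2, 10, 8)  # (batch, seq_len, hidden_size), toy sizes

for logits_to_keep in (0, 1, 4):
    # 0 keeps everything because slice(-0, None) == slice(0, None)
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]
    print(logits_to_keep, tuple(kept_hidden_states.shape))  # (2, 10, 8), (2, 1, 8), (2, 4, 8)

# A 1D index tensor keeps explicit positions instead (useful for packed sequences)
keep = torch.tensor([0, 3, 9])
print(tuple(hidden_states[:, keep, :].shape))  # (2, 3, 8)
```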
liger_kernel/transformers/model/gemma2.py

```diff
@@ -9,10 +9,6 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
-from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -136,8 +132,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -212,23 +206,27 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
             **loss_kwargs,
         )
 
     else:  # if in inference mode materialize logits
-
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
             logits = torch.tanh(logits)
```
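The inference branch above applies Gemma 2's final-logit soft-capping; the trailing multiplication by the cap sits just outside the hunk shown, so take this as a sketch of the standard formulation rather than a literal excerpt of the file:

```python
import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly squashes logits into (-cap, cap) instead of hard clipping them
    return cap * torch.tanh(logits / cap)

logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(softcap(logits, cap=30.0))  # large magnitudes saturate near ±30, small ones pass almost unchanged
```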
liger_kernel/transformers/model/gemma3.py

```diff
@@ -9,13 +9,9 @@ import torch.nn as nn
 from transformers.cache_utils import Cache
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma3.modeling_gemma3 import _CONFIG_FOR_DOC
-from transformers.models.gemma3.modeling_gemma3 import GEMMA3_INPUTS_DOCSTRING
 from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -25,8 +21,6 @@ logger = logging.get_logger(__name__)
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def causal_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -104,13 +98,15 @@ def causal_forward(
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
             **loss_kwargs,
@@ -139,8 +135,6 @@ def causal_forward(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def multimodal_forward(
     self,
     input_ids: torch.LongTensor = None,
```
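Every patched forward in this release now pops an optional `shift_labels` kwarg and threads it through to `LigerForCausalLMLoss`. The hunks only show the plumbing; the sketch below illustrates the usual causal-LM convention such pre-shifted labels follow (labels moved left one position and padded with the ignore index), which is an assumption about the caller, not something visible in this diff:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([[10, 11, 12, 13]])
ignore_index = -100

# Pre-shifted labels: position t is supervised by token t+1, the last position is ignored
shift_labels = F.pad(labels, (0, 1), value=ignore_index)[..., 1:]
print(shift_labels)  # tensor([[  11,   12,   13, -100]])
```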
liger_kernel/transformers/model/glm4.py

```diff
@@ -0,0 +1,119 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Glm4ForCausalLM
+
+    >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
+    >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+    ```
+    """
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = loss_kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None or shift_labels is not None):
+        loss = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
```
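With the new GLM-4 forward in place (and `monkey_patch.py` growing by roughly 240 lines in the same release), the natural entry point is the usual `apply_liger_kernel_to_<model>` patch function. The exact name and call below are assumptions based on that naming convention; they are not shown in this diff.

```python
from transformers import AutoModelForCausalLM

# Assumed entry point following liger-kernel's apply_liger_kernel_to_<model>() convention
from liger_kernel.transformers import apply_liger_kernel_to_glm4

apply_liger_kernel_to_glm4()  # patch the modeling code before instantiating the model
model = AutoModelForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
```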
liger_kernel/transformers/model/llama.py

```diff
@@ -7,23 +7,23 @@ from typing import Union
 import torch
 import torch.nn.functional as F
 
+from torch.distributed.fsdp import FullyShardedDataParallel
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
-from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
+from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.utils import PEFT_AVAILABLE
 
 if TYPE_CHECKING:
     from transformers.cache_utils import Cache
 
+if PEFT_AVAILABLE:
+    from peft.utils.other import ModulesToSaveWrapper
+
 
-@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -137,8 +137,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -209,25 +207,29 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     if self.config.pretraining_tp > 1:
         raise Exception("Liger Kernel does not support pretraining_tp!!")
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss =
-
-
-            labels=labels,
+    if self.training and (labels is not None or shift_labels is not None):
+        loss = lce_maybe_trainable_lm_head(
+            self,
+            hidden_states=kept_hidden_states,
             hidden_size=self.config.hidden_size,
+            labels=labels,
+            shift_labels=shift_labels,
             **loss_kwargs,
         )
 
     else:  # if in inference mode materialize logits
-
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
@@ -247,3 +249,50 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
+
+
+def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
+    lm_head = self.lm_head
+
+    # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
+    # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
+    # from the unwrapped module.
+    # See https://huggingface.co/docs/peft/package_reference/lora for reference.
+    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
+        lm_head = lm_head.modules_to_save.default
+
+    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
+    # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
+    # so the module entire parameters are summoned and kept in memory during the kernel execution.
+    if isinstance(lm_head, FullyShardedDataParallel):
+        return _FSDPForwardRedirection()(
+            lm_head,
+            _liger_for_causal_lm_loss,
+            lm_head.module,
+            hidden_states,
+            hidden_size,
+            labels,
+            shift_labels,
+            **loss_kwargs,
+        )
+
+    # FSDP is not used so we can read the lm_head weights and call the kernel directly
+    return _liger_for_causal_lm_loss(
+        lm_head=self.lm_head,
+        hidden_states=hidden_states,
+        hidden_size=hidden_size,
+        labels=labels,
+        shift_labels=shift_labels,
+        **loss_kwargs,
+    )
+
+
+def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
+    return LigerForCausalLMLoss(
+        hidden_states=hidden_states,
+        lm_head_weight=lm_head.weight,
+        labels=labels,
+        hidden_size=hidden_size,
+        shift_labels=shift_labels,
+        **loss_kwargs,
+    )
```
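The new `lce_maybe_trainable_lm_head` helper has to look through two possible wrappers before it can hand `lm_head.weight` to the fused kernel. A minimal sketch of that unwrapping, restating only what the diff shows (the attribute names come from the hunk above; the helper name `resolve_lm_head` is illustrative):

```python
from peft.utils.other import ModulesToSaveWrapper
from torch.distributed.fsdp import FullyShardedDataParallel


def resolve_lm_head(lm_head):
    """Mirror of the unwrapping in lce_maybe_trainable_lm_head above."""
    # PEFT case: lm_head listed in LoraConfig.modules_to_save -> trainable copy under "default"
    if isinstance(lm_head, ModulesToSaveWrapper):
        lm_head = lm_head.modules_to_save.default
    # FSDP case: the wrapped module sits under .module, but its full weights are only
    # materialized inside an FSDP forward, which is why the patched forward goes through
    # _FSDPForwardRedirection instead of touching lm_head.module.weight directly.
    if isinstance(lm_head, FullyShardedDataParallel):
        lm_head = lm_head.module
    return lm_head

# e.g. weight = resolve_lm_head(model.lm_head).weight
```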
liger_kernel/transformers/model/llava.py

```diff
@@ -5,19 +5,13 @@ from typing import Union
 
 import torch
 
-from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
-from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
 from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import is_torchdynamo_compiling
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 
 
-@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -210,9 +204,7 @@ def lce_forward_deprecated(
     )
 
 
-@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
```
liger_kernel/transformers/model/mistral.py

```diff
@@ -7,18 +7,12 @@ import torch
 
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
-from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -91,22 +85,26 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
        )
 
     else:
-
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
 
     loss = None
     if labels is not None:
```
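All of these patched forwards share the same motivation: in training, the full logits tensor is never materialized and `lm_head` is fused with the cross-entropy loss instead. A rough back-of-the-envelope for what that avoids (toy numbers for illustration, not tied to any particular model in this diff):

```python
# Full logits are (batch, seq_len, vocab) and get upcast to fp32 inside the loss.
batch, seq_len, vocab = 8, 4096, 32768
bytes_per_fp32 = 4

logits_gib = batch * seq_len * vocab * bytes_per_fp32 / 2**30
print(f"{logits_gib:.1f} GiB for a single materialized logits tensor")  # 4.0 GiB
```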
liger_kernel/transformers/model/mixtral.py

```diff
@@ -7,19 +7,13 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
-from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -146,8 +140,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 # Ignore copy
 def lce_forward(
     self,
@@ -225,22 +217,26 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=
+            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
 
     else:  # if in inference mode materialize logits
-
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
 
     loss = None
     if labels is not None:
```