liger-kernel 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. liger_kernel/chunked_loss/dpo_loss.py +8 -1
  2. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  3. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  4. liger_kernel/ops/cross_entropy.py +4 -1
  5. liger_kernel/ops/dyt.py +113 -179
  6. liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
  7. liger_kernel/ops/grpo_loss.py +310 -0
  8. liger_kernel/ops/sparsemax.py +167 -0
  9. liger_kernel/transformers/__init__.py +11 -0
  10. liger_kernel/transformers/dyt.py +5 -3
  11. liger_kernel/transformers/fsdp.py +55 -0
  12. liger_kernel/transformers/functional.py +8 -0
  13. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
  14. liger_kernel/transformers/grpo_loss.py +98 -0
  15. liger_kernel/transformers/model/gemma.py +8 -12
  16. liger_kernel/transformers/model/gemma2.py +8 -10
  17. liger_kernel/transformers/model/gemma3.py +3 -9
  18. liger_kernel/transformers/model/glm4.py +119 -0
  19. liger_kernel/transformers/model/llama.py +64 -15
  20. liger_kernel/transformers/model/llava.py +0 -8
  21. liger_kernel/transformers/model/mistral.py +8 -10
  22. liger_kernel/transformers/model/mixtral.py +8 -12
  23. liger_kernel/transformers/model/mllama.py +8 -11
  24. liger_kernel/transformers/model/olmo2.py +8 -10
  25. liger_kernel/transformers/model/paligemma.py +0 -8
  26. liger_kernel/transformers/model/phi3.py +8 -12
  27. liger_kernel/transformers/model/qwen2.py +8 -12
  28. liger_kernel/transformers/model/qwen2_5_vl.py +3 -7
  29. liger_kernel/transformers/model/qwen2_vl.py +3 -7
  30. liger_kernel/transformers/model/qwen3.py +112 -0
  31. liger_kernel/transformers/model/qwen3_moe.py +128 -0
  32. liger_kernel/transformers/monkey_patch.py +243 -13
  33. liger_kernel/transformers/sparsemax.py +16 -0
  34. liger_kernel/transformers/swiglu.py +21 -0
  35. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  36. liger_kernel/utils.py +11 -0
  37. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +36 -20
  38. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +42 -34
  39. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
  40. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
  41. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
  42. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,98 @@
1
+ from liger_kernel.ops.grpo_loss import GrpoLossFunction
2
+
3
+
4
def triton_grpo_loss(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask=None,
    temperature=0.9,
    beta=0.04,
    eps_low=0.2,
    eps_high=0.4,
    inplace=True,
):
    """Compute the GRPO loss through the Triton-backed ``GrpoLossFunction``.

    This is a thin validating wrapper: after checking the required inputs, every
    argument is forwarded positionally, in declaration order, to
    ``GrpoLossFunction.apply``.

    Args:
        logits: Policy model logits. Required.
        old_logp: Per-token log-probs of the old policy. Not checked here; whether
            ``None`` is accepted is up to ``GrpoLossFunction`` (TODO confirm).
        ref_logp: Per-token log-probs of the reference model. Not checked here;
            whether ``None`` is accepted is up to ``GrpoLossFunction`` (TODO confirm).
        completion_ids: Token ids of the sampled completions. Required.
        advantages: Advantage values. Required.
        completion_mask: Optional mask over completion tokens, forwarded as-is.
        temperature: Sampling temperature forwarded to the kernel (default 0.9).
        beta: KL coefficient forwarded to the kernel (default 0.04).
        eps_low: Lower clipping epsilon forwarded to the kernel (default 0.2).
        eps_high: Upper clipping epsilon forwarded to the kernel (default 0.4).
        inplace: Flag forwarded to the kernel (default True); presumably lets it
            reuse an input buffer -- confirm in ``liger_kernel.ops.grpo_loss``.

    Returns:
        Whatever ``GrpoLossFunction.apply`` returns. The usage example in this
        module unpacks it as ``(per_token_loss, per_token_kl, is_clipped)``.

    Raises:
        ValueError: If ``logits``, ``completion_ids`` or ``advantages`` is ``None``.
    """
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # which would let a None reach GrpoLossFunction.apply unchecked.
    missing = [
        name
        for name, value in (
            ("logits", logits),
            ("completion_ids", completion_ids),
            ("advantages", advantages),
        )
        if value is None
    ]
    if missing:
        raise ValueError(
            f"must provide logits, completion_ids and advantages (missing: {', '.join(missing)})"
        )

    return GrpoLossFunction.apply(
        logits,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
        inplace,
    )
34
+
35
+
36
# Usage example (not executed): how to plug `triton_grpo_loss` into trl's GRPOTrainer.
# The snippet lives in a string literal, so importing this module has no side effects.
# It targets trl 0.16 and asserts that version when run.
# NOTE(review): if this snippet is executed inside this module, `triton_grpo_loss` resolves,
# but `fused_selective_log_softmax` is not imported anywhere in this file. It presumably
# comes from liger_kernel.ops.grpo_loss -- confirm, and import it before using the snippet.
"""
import torch
import trl
assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
from trl.extras.profiling import profiling_decorator

@profiling_decorator
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
    # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
    return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)

@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    if return_outputs:
        raise ValueError("The GRPOTrainer does not support returning outputs")
    # Compute the per-token log probabilities for the model

    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits

    ref_per_token_logps = inputs["ref_per_token_logps"]
    advantages = inputs["advantages"]
    old_per_token_logps = inputs["old_per_token_logps"]


    per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(logits,
                                                                old_per_token_logps,
                                                                ref_per_token_logps,
                                                                completion_ids,
                                                                advantages,
                                                                completion_mask,
                                                                self.temperature,
                                                                self.beta,
                                                                self.epsilon_low,
                                                                self.epsilon_high,)
    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

    # Log the metrics
    mode = "eval" if self.control.should_evaluate else "train"

    if self.beta != 0.0:
        mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
        self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

    clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
    self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
    return loss

trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
trl.GRPOTrainer.compute_loss = compute_loss
trigger = None
"""

# To enable the patch in open-r1, add the line below at the very top of grpo.py.
# NOTE(review): `trigger` is only assigned inside the string literal above, so this module
# does not define it. The import raises ImportError unless the example code is first moved
# out of the string.
"""
from liger_kernel.transformers.grpo_loss import trigger
"""
@@ -8,18 +8,12 @@ import torch
8
8
  from torch.nn import CrossEntropyLoss
9
9
  from transformers.cache_utils import Cache
10
10
  from transformers.modeling_outputs import CausalLMOutputWithPast
11
- from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
12
- from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
13
- from transformers.utils import add_start_docstrings_to_model_forward
14
- from transformers.utils import replace_return_docstrings
15
11
  from transformers.utils.deprecation import deprecate_kwarg
16
12
 
17
13
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
18
14
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
19
15
 
20
16
 
21
- @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
22
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
23
17
  def lce_forward_deprecated(
24
18
  self,
25
19
  input_ids: torch.LongTensor = None,
@@ -129,8 +123,6 @@ def lce_forward_deprecated(
129
123
 
130
124
 
131
125
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
132
- @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
133
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
134
126
  def lce_forward(
135
127
  self,
136
128
  input_ids: torch.LongTensor = None,
@@ -200,21 +192,25 @@ def lce_forward(
200
192
  )
201
193
 
202
194
  hidden_states = outputs[0]
195
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
196
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
197
+ kept_hidden_states = hidden_states[:, slice_indices, :]
203
198
 
199
+ shift_labels = loss_kwargs.pop("shift_labels", None)
204
200
  logits = None
205
201
  loss = None
206
202
  # if in training mode, don't materialize logits
207
- if self.training and (labels is not None):
203
+ if self.training and (labels is not None or shift_labels is not None):
208
204
  loss = LigerForCausalLMLoss(
209
- hidden_states=hidden_states,
205
+ hidden_states=kept_hidden_states,
210
206
  lm_head_weight=self.lm_head.weight,
211
207
  labels=labels,
208
+ shift_labels=shift_labels,
212
209
  hidden_size=self.config.hidden_size,
213
210
  **loss_kwargs,
214
211
  )
215
212
  else: # if in inference mode materialize logits
216
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
217
- logits = self.lm_head(hidden_states[:, slice_indices, :])
213
+ logits = self.lm_head(kept_hidden_states)
218
214
  if labels is not None:
219
215
  loss = self.loss_function(
220
216
  logits=logits,
@@ -9,10 +9,6 @@ import torch
9
9
  from torch.nn import CrossEntropyLoss
10
10
  from transformers.cache_utils import HybridCache
11
11
  from transformers.modeling_outputs import CausalLMOutputWithPast
12
- from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
13
- from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
14
- from transformers.utils import add_start_docstrings_to_model_forward
15
- from transformers.utils import replace_return_docstrings
16
12
  from transformers.utils.deprecation import deprecate_kwarg
17
13
 
18
14
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -136,8 +132,6 @@ def lce_forward_deprecated(
136
132
 
137
133
 
138
134
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
139
- @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
140
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
141
135
  def lce_forward(
142
136
  self,
143
137
  input_ids: torch.LongTensor = None,
@@ -212,23 +206,27 @@ def lce_forward(
212
206
  )
213
207
 
214
208
  hidden_states = outputs[0]
209
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
210
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
211
+ kept_hidden_states = hidden_states[:, slice_indices, :]
215
212
 
213
+ shift_labels = loss_kwargs.pop("shift_labels", None)
216
214
  logits = None
217
215
  loss = None
218
216
  # if in training mode, don't materialize logits
219
- if self.training and (labels is not None):
217
+ if self.training and (labels is not None or shift_labels is not None):
220
218
  loss = LigerForCausalLMLoss(
221
- hidden_states=hidden_states,
219
+ hidden_states=kept_hidden_states,
222
220
  lm_head_weight=self.lm_head.weight,
223
221
  labels=labels,
222
+ shift_labels=shift_labels,
224
223
  hidden_size=self.config.hidden_size,
225
224
  final_logit_softcapping=self.config.final_logit_softcapping,
226
225
  **loss_kwargs,
227
226
  )
228
227
 
229
228
  else: # if in inference mode materialize logits
230
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
231
- logits = self.lm_head(hidden_states[:, slice_indices, :])
229
+ logits = self.lm_head(kept_hidden_states)
232
230
  if self.config.final_logit_softcapping is not None:
233
231
  logits = logits / self.config.final_logit_softcapping
234
232
  logits = torch.tanh(logits)
@@ -9,13 +9,9 @@ import torch.nn as nn
9
9
  from transformers.cache_utils import Cache
10
10
  from transformers.cache_utils import HybridCache
11
11
  from transformers.modeling_outputs import CausalLMOutputWithPast
12
- from transformers.models.gemma3.modeling_gemma3 import _CONFIG_FOR_DOC
13
- from transformers.models.gemma3.modeling_gemma3 import GEMMA3_INPUTS_DOCSTRING
14
12
  from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
15
- from transformers.utils import add_start_docstrings_to_model_forward
16
13
  from transformers.utils import is_torchdynamo_compiling
17
14
  from transformers.utils import logging
18
- from transformers.utils import replace_return_docstrings
19
15
  from transformers.utils.deprecation import deprecate_kwarg
20
16
 
21
17
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -25,8 +21,6 @@ logger = logging.get_logger(__name__)
25
21
 
26
22
 
27
23
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
28
- @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
29
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
30
24
  def causal_forward(
31
25
  self,
32
26
  input_ids: torch.LongTensor = None,
@@ -104,13 +98,15 @@ def causal_forward(
104
98
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
105
99
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
106
100
  kept_hidden_states = hidden_states[:, slice_indices, :]
101
+ shift_labels = loss_kwargs.pop("shift_labels", None)
107
102
  loss = None
108
103
  logits = None
109
- if self.training and (labels is not None):
104
+ if self.training and (labels is not None or shift_labels is not None):
110
105
  loss = LigerForCausalLMLoss(
111
106
  hidden_states=kept_hidden_states,
112
107
  lm_head_weight=self.lm_head.weight,
113
108
  labels=labels,
109
+ shift_labels=shift_labels,
114
110
  hidden_size=self.config.hidden_size,
115
111
  final_logit_softcapping=self.config.final_logit_softcapping,
116
112
  **loss_kwargs,
@@ -139,8 +135,6 @@ def causal_forward(
139
135
 
140
136
 
141
137
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
142
- @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
143
- @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
144
138
  def multimodal_forward(
145
139
  self,
146
140
  input_ids: torch.LongTensor = None,
@@ -0,0 +1,119 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ from transformers.utils.deprecation import deprecate_kwarg
10
+
11
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
12
+
13
+
14
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Liger replacement for the GLM-4 causal-LM forward that skips logit materialization in training.

    In training mode, when ``labels`` or ``shift_labels`` is given, the loss is computed by
    ``LigerForCausalLMLoss`` directly from the kept hidden states and the ``lm_head`` weight,
    and ``logits`` is returned as ``None``. Otherwise the logits are materialized with
    ``self.lm_head``, and the loss, if any labels are given, comes from ``self.loss_function``.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        **loss_kwargs:
            Extra keyword arguments forwarded to the loss. A `shift_labels` entry, if present, is popped
            and passed to the loss explicitly; presumably these are labels that are already shifted
            (e.g. for sequence parallelism) -- confirm with the caller.

    Returns:
        `CausalLMOutputWithPast` holding `loss`, `logits` (`None` on the fused training path), and
        `past_key_values`, `hidden_states` and `attentions` read from the decoder outputs. A
        `CausalLMOutputWithPast` is returned regardless of `return_dict`.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Glm4ForCausalLM

    >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
    >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
    ```
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    # NOTE(review): attribute access on `outputs` below assumes the decoder returns a model-output
    # object, not a plain tuple -- confirm behavior when return_dict resolves to False.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    # Popped so it is passed explicitly (and not twice via **loss_kwargs) on both loss paths.
    shift_labels = loss_kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None or shift_labels is not None):
        loss = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **loss_kwargs,
        )

    else:  # if in inference mode materialize logits
        logits = self.lm_head(kept_hidden_states)
        # Forward the popped `shift_labels` too. Otherwise it would be silently dropped here,
        # and no loss would be computed when only `shift_labels` is supplied.
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
@@ -7,23 +7,23 @@ from typing import Union
7
7
  import torch
8
8
  import torch.nn.functional as F
9
9
 
10
+ from torch.distributed.fsdp import FullyShardedDataParallel
10
11
  from torch.nn import CrossEntropyLoss
11
12
  from transformers.modeling_outputs import CausalLMOutputWithPast
12
- from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
13
- from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
14
- from transformers.utils import add_start_docstrings_to_model_forward
15
- from transformers.utils import replace_return_docstrings
16
13
  from transformers.utils.deprecation import deprecate_kwarg
17
14
 
15
+ from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
18
16
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
19
17
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
18
+ from liger_kernel.utils import PEFT_AVAILABLE
20
19
 
21
20
  if TYPE_CHECKING:
22
21
  from transformers.cache_utils import Cache
23
22
 
23
+ if PEFT_AVAILABLE:
24
+ from peft.utils.other import ModulesToSaveWrapper
25
+
24
26
 
25
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
26
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
27
27
  def lce_forward_deprecated(
28
28
  self,
29
29
  input_ids: torch.LongTensor = None,
@@ -137,8 +137,6 @@ def lce_forward_deprecated(
137
137
 
138
138
 
139
139
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
140
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
141
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
142
140
  def lce_forward(
143
141
  self,
144
142
  input_ids: torch.LongTensor = None,
@@ -209,25 +207,29 @@ def lce_forward(
209
207
  )
210
208
 
211
209
  hidden_states = outputs[0]
210
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
211
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
212
+ kept_hidden_states = hidden_states[:, slice_indices, :]
212
213
 
213
214
  if self.config.pretraining_tp > 1:
214
215
  raise Exception("Liger Kernel does not support pretraining_tp!!")
215
216
 
217
+ shift_labels = loss_kwargs.pop("shift_labels", None)
216
218
  logits = None
217
219
  loss = None
218
220
  # if in training mode, don't materialize logits
219
- if self.training and (labels is not None):
220
- loss = LigerForCausalLMLoss(
221
- hidden_states=hidden_states,
222
- lm_head_weight=self.lm_head.weight,
223
- labels=labels,
221
+ if self.training and (labels is not None or shift_labels is not None):
222
+ loss = lce_maybe_trainable_lm_head(
223
+ self,
224
+ hidden_states=kept_hidden_states,
224
225
  hidden_size=self.config.hidden_size,
226
+ labels=labels,
227
+ shift_labels=shift_labels,
225
228
  **loss_kwargs,
226
229
  )
227
230
 
228
231
  else: # if in inference mode materialize logits
229
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
230
- logits = self.lm_head(hidden_states[:, slice_indices, :])
232
+ logits = self.lm_head(kept_hidden_states)
231
233
  if labels is not None:
232
234
  loss = self.loss_function(
233
235
  logits=logits,
@@ -247,3 +249,50 @@ def lce_forward(
247
249
  hidden_states=outputs.hidden_states,
248
250
  attentions=outputs.attentions,
249
251
  )
252
+
253
+
254
def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """Compute the fused linear cross-entropy loss against this model's ``lm_head``.

    The head is first unwrapped from a PEFT ``ModulesToSaveWrapper`` when PEFT is
    installed and the head is wrapped. If the (possibly unwrapped) head is an FSDP
    module, the loss computation runs through ``_FSDPForwardRedirection``, so it
    executes inside the FSDP-wrapped module's forward. Otherwise the loss is computed
    directly.

    Args:
        self: The causal-LM module whose ``lm_head`` attribute is used.
        hidden_states: Hidden states to project through the head and score.
        hidden_size: Hidden dimension, forwarded to the loss.
        labels: Target labels, forwarded to the loss (may be ``None``).
        shift_labels: Pre-shifted target labels, forwarded to the loss (may be ``None``).
        **loss_kwargs: Extra keyword arguments forwarded to the loss.

    Returns:
        The value returned by ``_liger_for_causal_lm_loss``, either directly or via the
        FSDP forward redirection.
    """
    lm_head = self.lm_head

    # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
    # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
    # from the unwrapped module (the copy registered under the adapter name "default").
    # See https://huggingface.co/docs/peft/package_reference/lora for reference.
    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
        lm_head = lm_head.modules_to_save.default

    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
    # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
    # so the module entire parameters are summoned and kept in memory during the kernel execution.
    if isinstance(lm_head, FullyShardedDataParallel):
        return _FSDPForwardRedirection()(
            lm_head,
            _liger_for_causal_lm_loss,
            lm_head.module,
            hidden_states,
            hidden_size,
            labels,
            shift_labels,
            **loss_kwargs,
        )

    # FSDP is not used, so we can read the lm_head weights and call the kernel directly.
    # Use the local (possibly unwrapped) `lm_head`, not `self.lm_head`. Passing `self.lm_head`
    # would discard the PEFT unwrapping done above.
    return _liger_for_causal_lm_loss(
        lm_head=lm_head,
        hidden_states=hidden_states,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
288
+
289
+
290
def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    """Run ``LigerForCausalLMLoss`` with ``lm_head.weight`` as the output projection.

    The head module is taken as an explicit argument, which lets this function be
    handed around as a plain callable (for example to a forward-redirection helper).

    Args:
        lm_head: Module whose ``weight`` attribute is used as the projection weight.
        hidden_states: Hidden states to score.
        hidden_size: Hidden dimension, forwarded unchanged.
        labels: Target labels, forwarded unchanged.
        shift_labels: Pre-shifted target labels, forwarded unchanged.
        **loss_kwargs: Any further keyword arguments for ``LigerForCausalLMLoss``.

    Returns:
        Whatever ``LigerForCausalLMLoss`` returns.
    """
    projection_weight = lm_head.weight
    return LigerForCausalLMLoss(
        hidden_states=hidden_states,
        lm_head_weight=projection_weight,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
@@ -5,19 +5,13 @@ from typing import Union
5
5
 
6
6
  import torch
7
7
 
8
- from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
9
- from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
10
8
  from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
11
- from transformers.utils import add_start_docstrings_to_model_forward
12
9
  from transformers.utils import is_torchdynamo_compiling
13
- from transformers.utils import replace_return_docstrings
14
10
  from transformers.utils.deprecation import deprecate_kwarg
15
11
 
16
12
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
17
13
 
18
14
 
19
- @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
20
- @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
21
15
  def lce_forward_deprecated(
22
16
  self,
23
17
  input_ids: torch.LongTensor = None,
@@ -210,9 +204,7 @@ def lce_forward_deprecated(
210
204
  )
211
205
 
212
206
 
213
- @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
214
207
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
215
- @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
216
208
  def lce_forward(
217
209
  self,
218
210
  input_ids: torch.LongTensor = None,
@@ -7,18 +7,12 @@ import torch
7
7
 
8
8
  from transformers.cache_utils import Cache
9
9
  from transformers.modeling_outputs import CausalLMOutputWithPast
10
- from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
11
- from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
12
- from transformers.utils import add_start_docstrings_to_model_forward
13
- from transformers.utils import replace_return_docstrings
14
10
  from transformers.utils.deprecation import deprecate_kwarg
15
11
 
16
12
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
17
13
 
18
14
 
19
15
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
20
- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
21
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
22
16
  def lce_forward(
23
17
  self,
24
18
  input_ids: torch.LongTensor = None,
@@ -91,22 +85,26 @@ def lce_forward(
91
85
  )
92
86
 
93
87
  hidden_states = outputs[0]
88
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
89
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
90
+ kept_hidden_states = hidden_states[:, slice_indices, :]
94
91
 
92
+ shift_labels = loss_kwargs.pop("shift_labels", None)
95
93
  loss = None
96
94
  logits = None
97
95
 
98
- if self.training and (labels is not None):
96
+ if self.training and (labels is not None or shift_labels is not None):
99
97
  loss = LigerForCausalLMLoss(
100
- hidden_states=hidden_states,
98
+ hidden_states=kept_hidden_states,
101
99
  lm_head_weight=self.lm_head.weight,
102
100
  labels=labels,
101
+ shift_labels=shift_labels,
103
102
  hidden_size=self.config.hidden_size,
104
103
  **loss_kwargs,
105
104
  )
106
105
 
107
106
  else:
108
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
109
- logits = self.lm_head(hidden_states[:, slice_indices, :])
107
+ logits = self.lm_head(kept_hidden_states)
110
108
 
111
109
  loss = None
112
110
  if labels is not None:
@@ -7,19 +7,13 @@ import torch
7
7
 
8
8
  from torch.nn import CrossEntropyLoss
9
9
  from transformers.modeling_outputs import MoeCausalLMOutputWithPast
10
- from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
11
- from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
12
10
  from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
13
- from transformers.utils import add_start_docstrings_to_model_forward
14
- from transformers.utils import replace_return_docstrings
15
11
  from transformers.utils.deprecation import deprecate_kwarg
16
12
 
17
13
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
18
14
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
19
15
 
20
16
 
21
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
22
- @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
23
17
  def lce_forward_deprecated(
24
18
  self,
25
19
  input_ids: torch.LongTensor = None,
@@ -146,8 +140,6 @@ def lce_forward_deprecated(
146
140
 
147
141
 
148
142
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
149
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
150
- @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
151
143
  # Ignore copy
152
144
  def lce_forward(
153
145
  self,
@@ -225,22 +217,26 @@ def lce_forward(
225
217
  )
226
218
 
227
219
  hidden_states = outputs[0]
220
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
221
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
222
+ kept_hidden_states = hidden_states[:, slice_indices, :]
228
223
 
224
+ shift_labels = loss_kwargs.pop("shift_labels", None)
229
225
  logits = None
230
226
  loss = None
231
227
  # if in training mode, don't materialize logits
232
- if self.training and (labels is not None):
228
+ if self.training and (labels is not None or shift_labels is not None):
233
229
  loss = LigerForCausalLMLoss(
234
- hidden_states=hidden_states,
230
+ hidden_states=kept_hidden_states,
235
231
  lm_head_weight=self.lm_head.weight,
236
232
  labels=labels,
233
+ shift_labels=shift_labels,
237
234
  hidden_size=self.config.hidden_size,
238
235
  **loss_kwargs,
239
236
  )
240
237
 
241
238
  else: # if in inference mode materialize logits
242
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
243
- logits = self.lm_head(hidden_states[:, slice_indices, :])
239
+ logits = self.lm_head(kept_hidden_states)
244
240
 
245
241
  loss = None
246
242
  if labels is not None: