liger-kernel 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. liger_kernel/chunked_loss/dpo_loss.py +8 -1
  2. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  3. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  4. liger_kernel/ops/cross_entropy.py +4 -1
  5. liger_kernel/ops/dyt.py +113 -179
  6. liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
  7. liger_kernel/ops/grpo_loss.py +310 -0
  8. liger_kernel/ops/sparsemax.py +167 -0
  9. liger_kernel/transformers/__init__.py +11 -0
  10. liger_kernel/transformers/dyt.py +5 -3
  11. liger_kernel/transformers/fsdp.py +55 -0
  12. liger_kernel/transformers/functional.py +8 -0
  13. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
  14. liger_kernel/transformers/grpo_loss.py +98 -0
  15. liger_kernel/transformers/model/gemma.py +8 -12
  16. liger_kernel/transformers/model/gemma2.py +8 -10
  17. liger_kernel/transformers/model/gemma3.py +3 -9
  18. liger_kernel/transformers/model/glm4.py +119 -0
  19. liger_kernel/transformers/model/llama.py +64 -15
  20. liger_kernel/transformers/model/llava.py +0 -8
  21. liger_kernel/transformers/model/mistral.py +8 -10
  22. liger_kernel/transformers/model/mixtral.py +8 -12
  23. liger_kernel/transformers/model/mllama.py +8 -11
  24. liger_kernel/transformers/model/olmo2.py +8 -10
  25. liger_kernel/transformers/model/paligemma.py +0 -8
  26. liger_kernel/transformers/model/phi3.py +8 -12
  27. liger_kernel/transformers/model/qwen2.py +8 -12
  28. liger_kernel/transformers/model/qwen2_5_vl.py +3 -7
  29. liger_kernel/transformers/model/qwen2_vl.py +3 -7
  30. liger_kernel/transformers/model/qwen3.py +112 -0
  31. liger_kernel/transformers/model/qwen3_moe.py +128 -0
  32. liger_kernel/transformers/monkey_patch.py +243 -13
  33. liger_kernel/transformers/sparsemax.py +16 -0
  34. liger_kernel/transformers/swiglu.py +21 -0
  35. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  36. liger_kernel/utils.py +11 -0
  37. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +36 -20
  38. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +42 -34
  39. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
  40. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
  41. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
  42. {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/mllama.py

@@ -8,17 +8,12 @@ import torch
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
  def lce_forward_deprecated(
  self,
  input_ids: torch.LongTensor = None,
@@ -135,8 +130,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -215,22 +208,26 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
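The same refactor recurs in each model forward touched by this diff: the hidden states are sliced once with `logits_to_keep`, the sliced `kept_hidden_states` tensor feeds both the fused-loss path and the `lm_head` path, and an optional `shift_labels` kwarg now also routes training through the fused loss. A minimal sketch of the slicing semantics (toy shapes, illustrative only):

```python
import torch

# Illustrative shapes only; in the patched forwards `hidden_states` comes from
# the decoder. An int `logits_to_keep` keeps the last N positions (0 means all,
# since slice(-0, None) == slice(0, None)); a 1D tensor selects explicit indices.
hidden_states = torch.randn(2, 6, 4)  # (batch, seq_len, hidden)

def keep(hidden_states, logits_to_keep):
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

print(keep(hidden_states, 0).shape)                     # torch.Size([2, 6, 4])
print(keep(hidden_states, 1).shape)                     # torch.Size([2, 1, 4])
print(keep(hidden_states, torch.tensor([0, 5])).shape)  # torch.Size([2, 2, 4])
```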
liger_kernel/transformers/model/olmo2.py

@@ -6,18 +6,12 @@ from typing import Union
  import torch

  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
- from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -88,22 +82,26 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
liger_kernel/transformers/model/paligemma.py

@@ -7,13 +7,9 @@ import torch

  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
- from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
- from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
  from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
- from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import is_torchdynamo_compiling
  from transformers.utils import logging
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -21,8 +17,6 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinea
  logger = logging.get_logger(__name__)


- @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
  self,
  input_ids: torch.LongTensor = None,
@@ -206,8 +200,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
liger_kernel/transformers/model/phi3.py

@@ -7,18 +7,12 @@ import torch

  from torch.nn import CrossEntropyLoss
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
- from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
  self,
  input_ids: torch.LongTensor = None,
@@ -128,8 +122,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -213,22 +205,26 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
liger_kernel/transformers/model/qwen2.py

@@ -7,18 +7,12 @@ import torch

  from torch.nn import CrossEntropyLoss
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
- from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
  self,
  input_ids: torch.LongTensor = None,
@@ -127,8 +121,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -199,22 +191,26 @@ def lce_forward(
  )

  hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  logits = None
  loss = None
  # if in training mode, don't materialize logits
- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
- hidden_states=hidden_states,
+ hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )

  else: # if in inference mode materialize logits
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
- logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self.lm_head(kept_hidden_states)
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
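In the training branch, `LigerForCausalLMLoss` is handed the kept hidden states and `self.lm_head.weight` rather than logits, so the `(batch, seq_len, vocab)` logits tensor is never materialized; the fused linear cross-entropy evaluates the projection and the loss block by block. A rough sketch of that idea in plain PyTorch (the helper name and chunking strategy here are illustrative, not the shipped Triton kernel):

```python
import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(hidden, weight, labels, chunk_size=1024, ignore_index=-100):
    # Compute CE(hidden @ weight.T, labels) one block of tokens at a time, so only
    # a (chunk_size, vocab) slice of logits ever exists. The real fused kernel also
    # fuses the backward pass; this sketch only demonstrates the memory argument.
    hidden = hidden.reshape(-1, hidden.shape[-1])
    labels = labels.reshape(-1)
    total, count = hidden.new_zeros(()), 0
    for start in range(0, hidden.shape[0], chunk_size):
        h, y = hidden[start:start + chunk_size], labels[start:start + chunk_size]
        mask = y != ignore_index
        if mask.any():
            logits = h[mask] @ weight.T
            total = total + F.cross_entropy(logits, y[mask], reduction="sum")
            count += int(mask.sum())
    return total / max(count, 1)
```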
liger_kernel/transformers/model/qwen2_5_vl.py

@@ -6,17 +6,11 @@ from typing import Union
  import torch

  from torch.nn import CrossEntropyLoss
- from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _CONFIG_FOR_DOC
- from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import QWEN2_5_VL_INPUTS_DOCSTRING
  from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -163,14 +157,16 @@ def lce_forward(

  hidden_states = outputs[0]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  loss = None
  logits = None

- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
  hidden_states=hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )
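The vision-language forwards gain the same `shift_labels` pass-through: a caller that has already computed next-token targets (for example across packed sequences) can supply `shift_labels` and still hit the fused-loss path even when `labels` is `None`. A small sketch of the assumed relationship between the two; the padding convention below mirrors Hugging Face's causal-LM loss utilities and is an assumption, not something stated in this diff:

```python
import torch
import torch.nn.functional as F

# Assumed convention: shift_labels[i, t] is the target for position t, i.e. the
# labels shifted left by one and padded with the ignore index at the end.
IGNORE_INDEX = -100
labels = torch.tensor([[11, 12, 13, IGNORE_INDEX]])
shift_labels = F.pad(labels, (0, 1), value=IGNORE_INDEX)[..., 1:]
print(shift_labels)  # tensor([[  12,   13, -100, -100]])
```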
liger_kernel/transformers/model/qwen2_vl.py

@@ -8,17 +8,11 @@ import torch
  from packaging import version
  from torch.nn import CrossEntropyLoss
  from transformers import __version__ as transformers_version
- from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
- from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
  from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
  self,
  input_ids: torch.LongTensor = None,
@@ -167,14 +161,16 @@ def lce_forward(

  hidden_states = outputs[0]

+ shift_labels = loss_kwargs.pop("shift_labels", None)
  loss = None
  logits = None

- if self.training and (labels is not None):
+ if self.training and (labels is not None or shift_labels is not None):
  loss = LigerForCausalLMLoss(
  hidden_states=hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
+ shift_labels=shift_labels,
  hidden_size=self.config.hidden_size,
  **loss_kwargs,
  )
liger_kernel/transformers/model/qwen3.py

@@ -0,0 +1,112 @@
+ from typing import List
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ def lce_forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs,
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
+
+ >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]
+
+ shift_labels = kwargs.pop("shift_labels", None)
+ logits = None
+ loss = None
+ # if in training mode, don't materialize logits
+ if self.training and (labels is not None or shift_labels is not None):
+ loss = LigerForCausalLMLoss(
+ hidden_states=kept_hidden_states,
+ lm_head_weight=self.lm_head.weight,
+ labels=labels,
+ shift_labels=shift_labels,
+ hidden_size=self.config.hidden_size,
+ **kwargs,
+ )
+
+ else: # if in inference mode materialize logits
+ logits = self.lm_head(kept_hidden_states)
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
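The new Qwen3 forward follows the same pattern: during training the loss is computed directly from the kept hidden states, so the full logits tensor is never allocated. For a sense of scale, a back-of-the-envelope calculation (the batch size, sequence length, and vocabulary size below are illustrative assumptions, not measurements):

```python
# bf16 logits for batch 8, sequence 4096, and a ~152k-entry vocabulary:
batch, seq_len, vocab, bytes_per_elem = 8, 4096, 151_936, 2
logits_bytes = batch * seq_len * vocab * bytes_per_elem
print(f"{logits_bytes / 2**30:.1f} GiB")  # ≈ 9.3 GiB for the logits alone; an fp32 upcast would double it
```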
liger_kernel/transformers/model/qwen3_moe.py

@@ -0,0 +1,128 @@
+ from typing import List
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+ from transformers.modeling_outputs import MoeModelOutputWithPast
+ from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ def lce_forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **loss_kwargs,
+ ) -> MoeCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
+
+ >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: MoeModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]
+
+ shift_labels = loss_kwargs.pop("shift_labels", None)
+ logits = None
+ loss = None
+
+ # if in training mode, do not materialize logits
+ if self.training and (labels is not None or shift_labels is not None):
+ loss = LigerForCausalLMLoss(
+ hidden_states=kept_hidden_states,
+ lm_head_weight=self.lm_head.weight,
+ labels=labels,
+ shift_labels=shift_labels,
+ hidden_size=self.config.hidden_size,
+ **loss_kwargs,
+ )
+ else: # if in inference model materialize logits
+ logits = self.lm_head(kept_hidden_states)
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits,
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
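The MoE variant layers the router load-balancing term on top of the same structure. One property all of these forwards rely on is that the `lm_head` projection commutes with the sequence slice, so computing logits only for the kept positions is purely a memory optimization. A toy check (random weights and shapes, illustrative only):

```python
import torch

torch.manual_seed(0)
lm_head = torch.nn.Linear(16, 32, bias=False)  # stand-in for the model's lm_head
hidden_states = torch.randn(2, 10, 16)         # (batch, seq_len, hidden)

logits_to_keep = 3
slice_indices = slice(-logits_to_keep, None)
kept_hidden_states = hidden_states[:, slice_indices, :]

# Projecting everything and slicing afterwards gives the same logits as
# projecting only the kept positions.
full_then_slice = lm_head(hidden_states)[:, slice_indices, :]
slice_then_project = lm_head(kept_hidden_states)
print(torch.allclose(full_then_slice, slice_then_project))  # True
```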