liger-kernel 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +8 -1
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/cross_entropy.py +4 -1
- liger_kernel/ops/dyt.py +113 -179
- liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/sparsemax.py +167 -0
- liger_kernel/transformers/__init__.py +11 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +8 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +8 -12
- liger_kernel/transformers/model/gemma2.py +8 -10
- liger_kernel/transformers/model/gemma3.py +3 -9
- liger_kernel/transformers/model/glm4.py +119 -0
- liger_kernel/transformers/model/llama.py +64 -15
- liger_kernel/transformers/model/llava.py +0 -8
- liger_kernel/transformers/model/mistral.py +8 -10
- liger_kernel/transformers/model/mixtral.py +8 -12
- liger_kernel/transformers/model/mllama.py +8 -11
- liger_kernel/transformers/model/olmo2.py +8 -10
- liger_kernel/transformers/model/paligemma.py +0 -8
- liger_kernel/transformers/model/phi3.py +8 -12
- liger_kernel/transformers/model/qwen2.py +8 -12
- liger_kernel/transformers/model/qwen2_5_vl.py +3 -7
- liger_kernel/transformers/model/qwen2_vl.py +3 -7
- liger_kernel/transformers/model/qwen3.py +112 -0
- liger_kernel/transformers/model/qwen3_moe.py +128 -0
- liger_kernel/transformers/monkey_patch.py +243 -13
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +36 -20
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +42 -34
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
|
@@ -8,17 +8,12 @@ import torch
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
9
|
from transformers.cache_utils import Cache
|
|
10
10
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
11
|
-
from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
|
|
12
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
|
-
from transformers.utils import replace_return_docstrings
|
|
14
11
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
15
12
|
|
|
16
13
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
17
14
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
15
|
|
|
19
16
|
|
|
20
|
-
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
|
21
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
|
|
22
17
|
def lce_forward_deprecated(
|
|
23
18
|
self,
|
|
24
19
|
input_ids: torch.LongTensor = None,
|
|
@@ -135,8 +130,6 @@ def lce_forward_deprecated(
|
|
|
135
130
|
|
|
136
131
|
|
|
137
132
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
138
|
-
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
|
139
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
|
|
140
133
|
def lce_forward(
|
|
141
134
|
self,
|
|
142
135
|
input_ids: torch.LongTensor = None,
|
|
@@ -215,22 +208,26 @@ def lce_forward(
|
|
|
215
208
|
)
|
|
216
209
|
|
|
217
210
|
hidden_states = outputs[0]
|
|
211
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
212
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
213
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
218
214
|
|
|
215
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
219
216
|
logits = None
|
|
220
217
|
loss = None
|
|
221
218
|
# if in training mode, don't materialize logits
|
|
222
|
-
if self.training and (labels is not None):
|
|
219
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
223
220
|
loss = LigerForCausalLMLoss(
|
|
224
|
-
hidden_states=
|
|
221
|
+
hidden_states=kept_hidden_states,
|
|
225
222
|
lm_head_weight=self.lm_head.weight,
|
|
226
223
|
labels=labels,
|
|
224
|
+
shift_labels=shift_labels,
|
|
227
225
|
hidden_size=self.config.hidden_size,
|
|
228
226
|
**loss_kwargs,
|
|
229
227
|
)
|
|
230
228
|
|
|
231
229
|
else: # if in inference mode materialize logits
|
|
232
|
-
|
|
233
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
230
|
+
logits = self.lm_head(kept_hidden_states)
|
|
234
231
|
if labels is not None:
|
|
235
232
|
loss = self.loss_function(
|
|
236
233
|
logits=logits,
|
|
@@ -6,18 +6,12 @@ from typing import Union
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
8
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
9
|
-
from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
|
|
10
|
-
from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
|
|
11
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
12
|
-
from transformers.utils import replace_return_docstrings
|
|
13
9
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
14
10
|
|
|
15
11
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
16
12
|
|
|
17
13
|
|
|
18
14
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
19
|
-
@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
|
|
20
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
21
15
|
def lce_forward(
|
|
22
16
|
self,
|
|
23
17
|
input_ids: torch.LongTensor = None,
|
|
@@ -88,22 +82,26 @@ def lce_forward(
|
|
|
88
82
|
)
|
|
89
83
|
|
|
90
84
|
hidden_states = outputs[0]
|
|
85
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
86
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
87
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
91
88
|
|
|
89
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
92
90
|
logits = None
|
|
93
91
|
loss = None
|
|
94
92
|
# if in training mode, don't materialize logits
|
|
95
|
-
if self.training and (labels is not None):
|
|
93
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
96
94
|
loss = LigerForCausalLMLoss(
|
|
97
|
-
hidden_states=
|
|
95
|
+
hidden_states=kept_hidden_states,
|
|
98
96
|
lm_head_weight=self.lm_head.weight,
|
|
99
97
|
labels=labels,
|
|
98
|
+
shift_labels=shift_labels,
|
|
100
99
|
hidden_size=self.config.hidden_size,
|
|
101
100
|
**loss_kwargs,
|
|
102
101
|
)
|
|
103
102
|
|
|
104
103
|
else: # if in inference mode materialize logits
|
|
105
|
-
|
|
106
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
104
|
+
logits = self.lm_head(kept_hidden_states)
|
|
107
105
|
if labels is not None:
|
|
108
106
|
loss = self.loss_function(
|
|
109
107
|
logits=logits,
|
|
@@ -7,13 +7,9 @@ import torch
|
|
|
7
7
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
9
|
from transformers.cache_utils import Cache
|
|
10
|
-
from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
|
|
11
|
-
from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
|
|
12
10
|
from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
|
|
13
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
14
11
|
from transformers.utils import is_torchdynamo_compiling
|
|
15
12
|
from transformers.utils import logging
|
|
16
|
-
from transformers.utils import replace_return_docstrings
|
|
17
13
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
18
14
|
|
|
19
15
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
@@ -21,8 +17,6 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinea
|
|
|
21
17
|
logger = logging.get_logger(__name__)
|
|
22
18
|
|
|
23
19
|
|
|
24
|
-
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
|
25
|
-
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
26
20
|
def lce_forward_deprecated(
|
|
27
21
|
self,
|
|
28
22
|
input_ids: torch.LongTensor = None,
|
|
@@ -206,8 +200,6 @@ def lce_forward_deprecated(
|
|
|
206
200
|
|
|
207
201
|
|
|
208
202
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
209
|
-
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
|
210
|
-
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
211
203
|
def lce_forward(
|
|
212
204
|
self,
|
|
213
205
|
input_ids: torch.LongTensor = None,
|
|
@@ -7,18 +7,12 @@ import torch
|
|
|
7
7
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
9
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
10
|
-
from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
|
|
11
|
-
from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
|
|
12
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
|
-
from transformers.utils import replace_return_docstrings
|
|
14
10
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
15
11
|
|
|
16
12
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
17
13
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
14
|
|
|
19
15
|
|
|
20
|
-
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
|
|
21
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
22
16
|
def lce_forward_deprecated(
|
|
23
17
|
self,
|
|
24
18
|
input_ids: torch.LongTensor = None,
|
|
@@ -128,8 +122,6 @@ def lce_forward_deprecated(
|
|
|
128
122
|
|
|
129
123
|
|
|
130
124
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
131
|
-
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
|
|
132
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
133
125
|
def lce_forward(
|
|
134
126
|
self,
|
|
135
127
|
input_ids: torch.LongTensor = None,
|
|
@@ -213,22 +205,26 @@ def lce_forward(
|
|
|
213
205
|
)
|
|
214
206
|
|
|
215
207
|
hidden_states = outputs[0]
|
|
208
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
209
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
210
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
216
211
|
|
|
212
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
217
213
|
logits = None
|
|
218
214
|
loss = None
|
|
219
215
|
# if in training mode, don't materialize logits
|
|
220
|
-
if self.training and (labels is not None):
|
|
216
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
221
217
|
loss = LigerForCausalLMLoss(
|
|
222
|
-
hidden_states=
|
|
218
|
+
hidden_states=kept_hidden_states,
|
|
223
219
|
lm_head_weight=self.lm_head.weight,
|
|
224
220
|
labels=labels,
|
|
221
|
+
shift_labels=shift_labels,
|
|
225
222
|
hidden_size=self.config.hidden_size,
|
|
226
223
|
**loss_kwargs,
|
|
227
224
|
)
|
|
228
225
|
|
|
229
226
|
else: # if in inference mode materialize logits
|
|
230
|
-
|
|
231
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
227
|
+
logits = self.lm_head(kept_hidden_states)
|
|
232
228
|
if labels is not None:
|
|
233
229
|
loss = self.loss_function(
|
|
234
230
|
logits=logits,
|
|
@@ -7,18 +7,12 @@ import torch
|
|
|
7
7
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
9
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
10
|
-
from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
|
|
11
|
-
from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
|
|
12
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
|
-
from transformers.utils import replace_return_docstrings
|
|
14
10
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
15
11
|
|
|
16
12
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
17
13
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
14
|
|
|
19
15
|
|
|
20
|
-
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
|
21
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
22
16
|
def lce_forward_deprecated(
|
|
23
17
|
self,
|
|
24
18
|
input_ids: torch.LongTensor = None,
|
|
@@ -127,8 +121,6 @@ def lce_forward_deprecated(
|
|
|
127
121
|
|
|
128
122
|
|
|
129
123
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
130
|
-
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
|
131
|
-
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
132
124
|
def lce_forward(
|
|
133
125
|
self,
|
|
134
126
|
input_ids: torch.LongTensor = None,
|
|
@@ -199,22 +191,26 @@ def lce_forward(
|
|
|
199
191
|
)
|
|
200
192
|
|
|
201
193
|
hidden_states = outputs[0]
|
|
194
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
195
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
196
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
202
197
|
|
|
198
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
203
199
|
logits = None
|
|
204
200
|
loss = None
|
|
205
201
|
# if in training mode, don't materialize logits
|
|
206
|
-
if self.training and (labels is not None):
|
|
202
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
207
203
|
loss = LigerForCausalLMLoss(
|
|
208
|
-
hidden_states=
|
|
204
|
+
hidden_states=kept_hidden_states,
|
|
209
205
|
lm_head_weight=self.lm_head.weight,
|
|
210
206
|
labels=labels,
|
|
207
|
+
shift_labels=shift_labels,
|
|
211
208
|
hidden_size=self.config.hidden_size,
|
|
212
209
|
**loss_kwargs,
|
|
213
210
|
)
|
|
214
211
|
|
|
215
212
|
else: # if in inference mode materialize logits
|
|
216
|
-
|
|
217
|
-
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
213
|
+
logits = self.lm_head(kept_hidden_states)
|
|
218
214
|
if labels is not None:
|
|
219
215
|
loss = self.loss_function(
|
|
220
216
|
logits=logits,
|
|
@@ -6,17 +6,11 @@ from typing import Union
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
8
|
from torch.nn import CrossEntropyLoss
|
|
9
|
-
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _CONFIG_FOR_DOC
|
|
10
|
-
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import QWEN2_5_VL_INPUTS_DOCSTRING
|
|
11
9
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
|
|
12
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
|
-
from transformers.utils import replace_return_docstrings
|
|
14
10
|
|
|
15
11
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
16
12
|
|
|
17
13
|
|
|
18
|
-
@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
|
|
19
|
-
@replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
20
14
|
def lce_forward(
|
|
21
15
|
self,
|
|
22
16
|
input_ids: torch.LongTensor = None,
|
|
@@ -163,14 +157,16 @@ def lce_forward(
|
|
|
163
157
|
|
|
164
158
|
hidden_states = outputs[0]
|
|
165
159
|
|
|
160
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
166
161
|
loss = None
|
|
167
162
|
logits = None
|
|
168
163
|
|
|
169
|
-
if self.training and (labels is not None):
|
|
164
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
170
165
|
loss = LigerForCausalLMLoss(
|
|
171
166
|
hidden_states=hidden_states,
|
|
172
167
|
lm_head_weight=self.lm_head.weight,
|
|
173
168
|
labels=labels,
|
|
169
|
+
shift_labels=shift_labels,
|
|
174
170
|
hidden_size=self.config.hidden_size,
|
|
175
171
|
**loss_kwargs,
|
|
176
172
|
)
|
|
@@ -8,17 +8,11 @@ import torch
|
|
|
8
8
|
from packaging import version
|
|
9
9
|
from torch.nn import CrossEntropyLoss
|
|
10
10
|
from transformers import __version__ as transformers_version
|
|
11
|
-
from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
|
|
12
|
-
from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
|
|
13
11
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
|
|
14
|
-
from transformers.utils import add_start_docstrings_to_model_forward
|
|
15
|
-
from transformers.utils import replace_return_docstrings
|
|
16
12
|
|
|
17
13
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
18
14
|
|
|
19
15
|
|
|
20
|
-
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
|
|
21
|
-
@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
22
16
|
def lce_forward(
|
|
23
17
|
self,
|
|
24
18
|
input_ids: torch.LongTensor = None,
|
|
@@ -167,14 +161,16 @@ def lce_forward(
|
|
|
167
161
|
|
|
168
162
|
hidden_states = outputs[0]
|
|
169
163
|
|
|
164
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
170
165
|
loss = None
|
|
171
166
|
logits = None
|
|
172
167
|
|
|
173
|
-
if self.training and (labels is not None):
|
|
168
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
174
169
|
loss = LigerForCausalLMLoss(
|
|
175
170
|
hidden_states=hidden_states,
|
|
176
171
|
lm_head_weight=self.lm_head.weight,
|
|
177
172
|
labels=labels,
|
|
173
|
+
shift_labels=shift_labels,
|
|
178
174
|
hidden_size=self.config.hidden_size,
|
|
179
175
|
**loss_kwargs,
|
|
180
176
|
)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
8
|
+
|
|
9
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def lce_forward(
|
|
13
|
+
self,
|
|
14
|
+
input_ids: Optional[torch.LongTensor] = None,
|
|
15
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
16
|
+
position_ids: Optional[torch.LongTensor] = None,
|
|
17
|
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
18
|
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
19
|
+
labels: Optional[torch.LongTensor] = None,
|
|
20
|
+
use_cache: Optional[bool] = None,
|
|
21
|
+
output_attentions: Optional[bool] = None,
|
|
22
|
+
output_hidden_states: Optional[bool] = None,
|
|
23
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
24
|
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
25
|
+
**kwargs,
|
|
26
|
+
) -> CausalLMOutputWithPast:
|
|
27
|
+
r"""
|
|
28
|
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
29
|
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
30
|
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
31
|
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
32
|
+
|
|
33
|
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
34
|
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
35
|
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
36
|
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
37
|
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
38
|
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
39
|
+
|
|
40
|
+
Returns:
|
|
41
|
+
|
|
42
|
+
Example:
|
|
43
|
+
|
|
44
|
+
```python
|
|
45
|
+
>>> from transformers import AutoTokenizer, Qwen3ForCausalLM
|
|
46
|
+
|
|
47
|
+
>>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
|
|
48
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
|
49
|
+
|
|
50
|
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
51
|
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
52
|
+
|
|
53
|
+
>>> # Generate
|
|
54
|
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
55
|
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
56
|
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
57
|
+
```"""
|
|
58
|
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
59
|
+
output_hidden_states = (
|
|
60
|
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
64
|
+
outputs = self.model(
|
|
65
|
+
input_ids=input_ids,
|
|
66
|
+
attention_mask=attention_mask,
|
|
67
|
+
position_ids=position_ids,
|
|
68
|
+
past_key_values=past_key_values,
|
|
69
|
+
inputs_embeds=inputs_embeds,
|
|
70
|
+
use_cache=use_cache,
|
|
71
|
+
output_attentions=output_attentions,
|
|
72
|
+
output_hidden_states=output_hidden_states,
|
|
73
|
+
cache_position=cache_position,
|
|
74
|
+
**kwargs,
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
hidden_states = outputs[0]
|
|
78
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
79
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
80
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
81
|
+
|
|
82
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
83
|
+
logits = None
|
|
84
|
+
loss = None
|
|
85
|
+
# if in training mode, don't materialize logits
|
|
86
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
87
|
+
loss = LigerForCausalLMLoss(
|
|
88
|
+
hidden_states=kept_hidden_states,
|
|
89
|
+
lm_head_weight=self.lm_head.weight,
|
|
90
|
+
labels=labels,
|
|
91
|
+
shift_labels=shift_labels,
|
|
92
|
+
hidden_size=self.config.hidden_size,
|
|
93
|
+
**kwargs,
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
else: # if in inference mode materialize logits
|
|
97
|
+
logits = self.lm_head(kept_hidden_states)
|
|
98
|
+
if labels is not None:
|
|
99
|
+
loss = self.loss_function(
|
|
100
|
+
logits=logits,
|
|
101
|
+
labels=labels,
|
|
102
|
+
vocab_size=self.config.vocab_size,
|
|
103
|
+
**kwargs,
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
return CausalLMOutputWithPast(
|
|
107
|
+
loss=loss,
|
|
108
|
+
logits=logits,
|
|
109
|
+
past_key_values=outputs.past_key_values,
|
|
110
|
+
hidden_states=outputs.hidden_states,
|
|
111
|
+
attentions=outputs.attentions,
|
|
112
|
+
)
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
|
8
|
+
from transformers.modeling_outputs import MoeModelOutputWithPast
|
|
9
|
+
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
|
|
10
|
+
|
|
11
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def lce_forward(
|
|
15
|
+
self,
|
|
16
|
+
input_ids: Optional[torch.LongTensor] = None,
|
|
17
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
18
|
+
position_ids: Optional[torch.LongTensor] = None,
|
|
19
|
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
20
|
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
21
|
+
labels: Optional[torch.LongTensor] = None,
|
|
22
|
+
use_cache: Optional[bool] = None,
|
|
23
|
+
output_attentions: Optional[bool] = None,
|
|
24
|
+
output_hidden_states: Optional[bool] = None,
|
|
25
|
+
output_router_logits: Optional[bool] = None,
|
|
26
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
27
|
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
28
|
+
**loss_kwargs,
|
|
29
|
+
) -> MoeCausalLMOutputWithPast:
|
|
30
|
+
r"""
|
|
31
|
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
32
|
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
33
|
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
34
|
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
35
|
+
|
|
36
|
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
37
|
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
38
|
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
39
|
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
40
|
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
41
|
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
42
|
+
|
|
43
|
+
Returns:
|
|
44
|
+
|
|
45
|
+
Example:
|
|
46
|
+
|
|
47
|
+
```python
|
|
48
|
+
>>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
|
|
49
|
+
|
|
50
|
+
>>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
|
51
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
|
52
|
+
|
|
53
|
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
54
|
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
55
|
+
|
|
56
|
+
>>> # Generate
|
|
57
|
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
58
|
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
59
|
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
60
|
+
```"""
|
|
61
|
+
|
|
62
|
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
63
|
+
output_router_logits = (
|
|
64
|
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
output_hidden_states = (
|
|
68
|
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
72
|
+
outputs: MoeModelOutputWithPast = self.model(
|
|
73
|
+
input_ids=input_ids,
|
|
74
|
+
attention_mask=attention_mask,
|
|
75
|
+
position_ids=position_ids,
|
|
76
|
+
past_key_values=past_key_values,
|
|
77
|
+
inputs_embeds=inputs_embeds,
|
|
78
|
+
use_cache=use_cache,
|
|
79
|
+
output_attentions=output_attentions,
|
|
80
|
+
output_hidden_states=output_hidden_states,
|
|
81
|
+
output_router_logits=output_router_logits,
|
|
82
|
+
cache_position=cache_position,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
hidden_states = outputs.last_hidden_state
|
|
86
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
87
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
88
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
89
|
+
|
|
90
|
+
shift_labels = loss_kwargs.pop("shift_labels", None)
|
|
91
|
+
logits = None
|
|
92
|
+
loss = None
|
|
93
|
+
|
|
94
|
+
# if in training mode, do not materialize logits
|
|
95
|
+
if self.training and (labels is not None or shift_labels is not None):
|
|
96
|
+
loss = LigerForCausalLMLoss(
|
|
97
|
+
hidden_states=kept_hidden_states,
|
|
98
|
+
lm_head_weight=self.lm_head.weight,
|
|
99
|
+
labels=labels,
|
|
100
|
+
shift_labels=shift_labels,
|
|
101
|
+
hidden_size=self.config.hidden_size,
|
|
102
|
+
**loss_kwargs,
|
|
103
|
+
)
|
|
104
|
+
else: # if in inference model materialize logits
|
|
105
|
+
logits = self.lm_head(kept_hidden_states)
|
|
106
|
+
if labels is not None:
|
|
107
|
+
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
|
108
|
+
|
|
109
|
+
aux_loss = None
|
|
110
|
+
if output_router_logits:
|
|
111
|
+
aux_loss = load_balancing_loss_func(
|
|
112
|
+
outputs.router_logits,
|
|
113
|
+
self.num_experts,
|
|
114
|
+
self.num_experts_per_tok,
|
|
115
|
+
attention_mask,
|
|
116
|
+
)
|
|
117
|
+
if labels is not None:
|
|
118
|
+
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
|
|
119
|
+
|
|
120
|
+
return MoeCausalLMOutputWithPast(
|
|
121
|
+
loss=loss,
|
|
122
|
+
aux_loss=aux_loss,
|
|
123
|
+
logits=logits,
|
|
124
|
+
past_key_values=outputs.past_key_values,
|
|
125
|
+
hidden_states=outputs.hidden_states,
|
|
126
|
+
attentions=outputs.attentions,
|
|
127
|
+
router_logits=outputs.router_logits,
|
|
128
|
+
)
|