liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff compares two publicly released versions of the package and is provided for informational purposes only; it reflects the contents of those versions as they appear in the public registry.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/dpo_loss.py +1 -1
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/dyt.py +111 -179
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +265 -54
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +8 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +70 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +25 -16
- liger_kernel/transformers/model/gemma2.py +27 -14
- liger_kernel/transformers/model/gemma3.py +62 -106
- liger_kernel/transformers/model/glm4.py +16 -13
- liger_kernel/transformers/model/llama.py +81 -18
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -132
- liger_kernel/transformers/model/mistral.py +13 -14
- liger_kernel/transformers/model/mixtral.py +16 -15
- liger_kernel/transformers/model/mllama.py +16 -14
- liger_kernel/transformers/model/olmo2.py +16 -13
- liger_kernel/transformers/model/paligemma.py +8 -9
- liger_kernel/transformers/model/phi3.py +25 -16
- liger_kernel/transformers/model/qwen2.py +24 -15
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
- liger_kernel/transformers/model/qwen2_vl.py +38 -106
- liger_kernel/transformers/model/qwen3.py +11 -9
- liger_kernel/transformers/model/qwen3_moe.py +132 -0
- liger_kernel/transformers/monkey_patch.py +424 -81
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
- liger_kernel-0.6.0.dist-info/RECORD +97 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- liger_kernel-0.5.9.dist-info/RECORD +0 -84
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
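The hunks below repeat one change across the patched model forwards: each current `lce_forward` gains a `skip_logits` argument plus a `**kwargs` passthrough, and chooses between Liger's fused linear cross entropy path (no logits materialized) and the stock `lm_head` + `loss_function` path. The deprecated and multimodal variants (phi3/qwen2 `lce_forward_deprecated`, paligemma) key the default off `labels` only. A minimal sketch of that decision, distilled from the hunks below; the standalone function wrapper is illustrative and not part of the package API:

```python
from typing import Optional

import torch


def resolve_skip_logits(
    training: bool,
    labels: Optional[torch.Tensor],
    shift_labels: Optional[torch.Tensor],
    skip_logits: Optional[bool],
) -> bool:
    """Same gating as the new lce_forward bodies below, factored out for reading."""
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = training and (labels is not None or shift_labels is not None)
    return bool(skip_logits)
```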
liger_kernel/transformers/model/mistral.py

```diff
@@ -7,18 +7,12 @@ import torch
 
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
-from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -33,7 +27,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -88,6 +83,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -95,18 +91,24 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
-    if
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
 
     else:
@@ -118,7 +120,7 @@ def lce_forward(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -131,6 +133,3 @@ def lce_forward(
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
     )
-
-
-# Note: Grad Acc is not fixed in mistral at transformer 4.46.1
```
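A caller-side sketch of what the Mistral change above enables once the monkey patch is applied. The no-argument `apply_liger_kernel_to_mistral()` call, the model id, and the dummy batch are illustrative assumptions rather than content of the diff:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_mistral

# Patch Mistral's forward with the Liger lce_forward shown above.
apply_liger_kernel_to_mistral()
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (2, 128))
labels = input_ids.clone()

# skip_logits=True takes the LigerForCausalLMLoss branch: the loss is computed
# without materializing the (batch, seq, vocab) logits tensor.
out = model(input_ids=input_ids, labels=labels, skip_logits=True)
print(out.loss)    # fused-loss scalar
print(out.logits)  # None on this branch (loss only)

# skip_logits=False forces the stock lm_head + loss_function path (logits kept).
out_eval = model(input_ids=input_ids, labels=labels, skip_logits=False)
print(out_eval.logits.shape)
```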
liger_kernel/transformers/model/mixtral.py

```diff
@@ -7,19 +7,13 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import _CONFIG_FOR_DOC
-from transformers.models.mixtral.modeling_mixtral import MIXTRAL_INPUTS_DOCSTRING
 from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -146,8 +140,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 # Ignore copy
 def lce_forward(
     self,
@@ -164,7 +156,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
     Args:
@@ -222,6 +215,7 @@ def lce_forward(
         output_router_logits=output_router_logits,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -229,26 +223,33 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
 
-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
     aux_loss = None
     if output_router_logits:
         aux_loss = load_balancing_loss_func(
```
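The `shift_labels = kwargs.pop("shift_labels", None)` line in the Mixtral hunk (and its siblings) lets a training loop hand over targets that are already shifted by one position instead of raw `labels`. A hedged sketch of that call pattern; the checkpoint name, batch shapes, and padding choice are illustrative only:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_mixtral

apply_liger_kernel_to_mixtral()
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (2, 64))
labels = input_ids.clone()

# Pre-shift the labels once (next-token targets, -100 masking the final slot)
# instead of letting the loss shift `labels` internally.
shift_labels = torch.full_like(labels, -100)
shift_labels[..., :-1] = labels[..., 1:]

# The kwarg is popped by the patched forward and handed to LigerForCausalLMLoss.
out = model(input_ids=input_ids, shift_labels=shift_labels)
print(out.loss)
```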
liger_kernel/transformers/model/mllama.py

```diff
@@ -8,17 +8,12 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -135,8 +130,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -154,7 +147,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -212,6 +206,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -219,28 +214,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
 
-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )
 
     if not return_dict:
```
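The reason `skip_logits` defaults to the fused path during training is the size of the logits tensor itself; a back-of-the-envelope sketch with illustrative shapes (the vocabulary size is roughly that of the Llama-3/mllama family, the batch geometry is made up):

```python
# Size of the full logits tensor the fused-loss path avoids materializing.
batch, seq_len, vocab = 8, 4096, 128_256   # illustrative values
bytes_per_elem = 2                         # bf16

logits_bytes = batch * seq_len * vocab * bytes_per_elem
print(f"{logits_bytes / 2**30:.1f} GiB")   # ~7.8 GiB, before autograd keeps copies
```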
liger_kernel/transformers/model/olmo2.py

```diff
@@ -6,18 +6,12 @@ from typing import Union
 import torch
 
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
-from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -32,7 +26,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -85,6 +80,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -92,28 +88,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
        )
 
-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )
 
     return CausalLMOutputWithPast(
```
liger_kernel/transformers/model/paligemma.py

```diff
@@ -7,13 +7,9 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
-from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
-from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
 from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -21,8 +17,6 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinea
 logger = logging.get_logger(__name__)
 
 
-@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -206,8 +200,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -224,6 +216,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **lm_kwargs,
 ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
     r"""
@@ -339,7 +332,13 @@ def lce_forward(
     loss = None
     logits = None
 
-    if
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None)
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :]
         shift_labels = labels[..., 1:]
 
```
liger_kernel/transformers/model/phi3.py

```diff
@@ -7,18 +7,12 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.phi3.modeling_phi3 import _CONFIG_FOR_DOC
-from transformers.models.phi3.modeling_phi3 import PHI3_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -32,6 +26,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
@@ -86,7 +81,14 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
-    if
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
 
@@ -128,8 +130,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -144,7 +144,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -210,6 +211,7 @@ def lce_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -217,28 +219,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
 
-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )
 
     if not return_dict:
```
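A quick way to sanity-check the two branches introduced above is to verify that the fused loss and the materialized-logits loss agree to floating-point tolerance. A hedged test sketch; the checkpoint, dtype, tolerances, and use of `skip_logits=False` to force the logits path are illustrative choices, not taken from the package's own test suite:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_phi3

apply_liger_kernel_to_phi3()
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.float32
)
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 64))
labels = input_ids.clone()

with torch.no_grad():
    fused = model(input_ids=input_ids, labels=labels, skip_logits=True).loss
    unfused = model(input_ids=input_ids, labels=labels, skip_logits=False).loss

# Both branches should compute the same causal LM loss.
torch.testing.assert_close(fused, unfused, rtol=1e-4, atol=1e-4)
```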
liger_kernel/transformers/model/qwen2.py

```diff
@@ -7,18 +7,12 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
-from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -32,6 +26,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -86,6 +81,13 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
     if self.training and (labels is not None):
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
@@ -127,8 +129,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -143,7 +143,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -196,6 +197,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -203,28 +205,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
        )
 
-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )
 
     return CausalLMOutputWithPast(
```