liger-kernel 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/ops/dyt.py +0 -2
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +265 -54
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +62 -50
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/functional.py +62 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +25 -8
- liger_kernel/transformers/model/gemma2.py +27 -8
- liger_kernel/transformers/model/gemma3.py +62 -98
- liger_kernel/transformers/model/glm4.py +16 -7
- liger_kernel/transformers/model/llama.py +25 -7
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -124
- liger_kernel/transformers/model/mistral.py +13 -8
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +16 -7
- liger_kernel/transformers/model/olmo2.py +16 -7
- liger_kernel/transformers/model/paligemma.py +8 -1
- liger_kernel/transformers/model/phi3.py +25 -8
- liger_kernel/transformers/model/qwen2.py +24 -7
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
- liger_kernel/transformers/model/qwen2_vl.py +38 -100
- liger_kernel/transformers/model/qwen3.py +11 -3
- liger_kernel/transformers/model/qwen3_moe.py +10 -6
- liger_kernel/transformers/monkey_patch.py +304 -70
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
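Most of the model-file changes below follow one pattern: each patched `forward` gains a `skip_logits: Optional[bool] = None` keyword (and, where missing, a `**kwargs` passthrough), and the fused linear-cross-entropy path is taken whenever logits do not need to be materialized. A minimal sketch of that gating logic, distilled from the hunks that follow (the standalone helper and its name are illustrative, not part of the package):

from typing import Optional

import torch


def resolve_skip_logits(
    module: torch.nn.Module,
    skip_logits: Optional[bool],
    labels: Optional[torch.Tensor],
    shift_labels: Optional[torch.Tensor] = None,
) -> bool:
    """Illustrative helper mirroring the gating added across the patched forwards."""
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # By default, skip materializing logits only while training with labels present.
        skip_logits = module.training and (labels is not None or shift_labels is not None)
    return skip_logits

When the flag resolves to True, the loss is computed directly from the last hidden states and the `lm_head` weight; otherwise the full logits tensor is produced and the stock `loss_function` runs, exactly as the hunks below show.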
@@ -27,6 +27,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -81,7 +82,14 @@ def lce_forward_deprecated(
     loss = None
     logits = None

-    if
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()

@@ -137,7 +145,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -189,6 +198,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -196,27 +206,34 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )

     if not return_dict:
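The `skip_logits` branch above routes the loss through `LigerForCausalLMLoss`, which never materializes the full `(batch, seq, vocab)` logits tensor. Underneath sits the fused linear cross entropy that the gemma3 hunks further down call directly. A hedged, self-contained sketch of that call pattern; sizes are made up, the flattened `(tokens, hidden)` shape is an assumption to verify against the library docs, and the Triton kernel needs a GPU:

import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

# Toy sizes; real models use their own hidden_size / vocab_size.
hidden_size, vocab_size, num_tokens = 64, 1000, 16

# Liger's Triton kernels operate on GPU tensors (assumes a CUDA device is available).
hidden = torch.randn(num_tokens, hidden_size, device="cuda", requires_grad=True)
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

lce = LigerFusedLinearCrossEntropyLoss()
# Same call pattern as in the gemma3 hunk: (lm_head weight, flattened hidden states, flattened labels).
loss = lce(lm_head_weight, hidden, labels)
loss.backward()  # gradients flow to both the hidden states and the lm_head weight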
@@ -30,6 +30,8 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -76,6 +78,7 @@ def lce_forward_deprecated(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -83,7 +86,14 @@ def lce_forward_deprecated(
     loss = None
     logits = None

-    if
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()

@@ -146,7 +156,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -203,6 +214,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -210,11 +222,18 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -222,10 +241,10 @@ def lce_forward(
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
-            **
+            **kwargs,
         )

-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
@@ -234,7 +253,7 @@ def lce_forward(

     loss = None
     if labels is not None:
-        loss = self.loss_function(logits, labels, self.vocab_size, **
+        loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -1,4 +1,3 @@
-from typing import List
 from typing import Optional
 from typing import Tuple
 from typing import Union
@@ -10,9 +9,7 @@ from transformers.cache_utils import Cache
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
-from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -20,7 +17,6 @@ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 logger = logging.get_logger(__name__)


-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def causal_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -35,6 +31,7 @@ def causal_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -101,7 +98,11 @@ def causal_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
-
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -134,14 +135,13 @@ def causal_forward(
     )


-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def multimodal_forward(
     self,
     input_ids: torch.LongTensor = None,
     pixel_values: torch.FloatTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Union[
+    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
     token_type_ids: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -151,22 +151,14 @@ def multimodal_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **lm_kwargs,
-) -> Union[
+) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
     r"""
-
-
-
-
-
-        logits_to_keep (`int` or `torch.Tensor`, *optional*):
-            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-            This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

     Example:

@@ -175,23 +167,37 @@ def multimodal_forward(
     >>> import requests
     >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

-    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/
-    >>> processor = AutoProcessor.from_pretrained("google/
-
-    >>>
-
-
-
-
+    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
+    >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
+
+    >>> messages = [
+    ...     {
+    ...         "role": "system",
+    ...         "content": [
+    ...             {"type": "text", "text": "You are a helpful assistant."}
+    ...         ]
+    ...     },
+    ...     {
+    ...         "role": "user", "content": [
+    ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+    ...             {"type": "text", "text": "Where is the cat standing?"},
+    ...         ]
+    ...     },
+    ... ]
+
+    >>> inputs = processor.apply_chat_template(
+    ...     messages,
+    ...     tokenize=True,
+    ...     return_dict=True,
+    ...     return_tensors="pt",
+    ...     add_generation_prompt=True
+    ... )
     >>> # Generate
-    >>> generate_ids = model.generate(**inputs
+    >>> generate_ids = model.generate(**inputs)
     >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "
-    ```
-
-    if (input_ids is None) ^ (inputs_embeds is not None):
-        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+    "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+    ```
+    """

     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
@@ -199,81 +205,38 @@ def multimodal_forward(
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-
-
-
-
-
-        llm_input_ids = input_ids.clone()
-        llm_input_ids[special_image_mask] = 0
-    else:
-        llm_input_ids = input_ids
-
-    if inputs_embeds is None:
-        inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-
-    if cache_position is None:
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        cache_position = torch.arange(
-            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-        )
-
-    if position_ids is None:
-        position_ids = cache_position.unsqueeze(0) + 1  # Gemma3 positions are 1-indexed
-
-    # Merge text and images
-    if pixel_values is not None:
-        image_features = self.get_image_features(pixel_values)
-
-        if input_ids is None:
-            special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
-            )
-        else:
-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
-        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
-            raise ValueError(
-                f"Number of images does not match number of special image tokens in the input text. "
-                f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
-                "tokens from image embeddings."
-            )
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-    # mask out pad-token-ids in labels for BC
-    if labels is not None and self.pad_token_id in labels:
-        logger.warning_once(
-            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
-            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
-        )
-        labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
-
-    causal_mask = self._update_causal_mask(
-        attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
-    )
-    outputs = self.language_model.model(
-        attention_mask=causal_mask,
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        token_type_ids=token_type_ids,
+        attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
         use_cache=use_cache,
+        labels=labels,
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
-        logits_to_keep=logits_to_keep,
         **lm_kwargs,
     )

     hidden_states = outputs[0]
+
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
     loss = None
     logits = None
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None)

-    if
-        shift_hidden_states =
+    if skip_logits:
+        shift_hidden_states = kept_hidden_states[..., :-1, :]
         shift_labels = labels[..., 1:]

         hidden_device = shift_hidden_states.device
@@ -294,7 +257,7 @@ def multimodal_forward(
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
     else:
-        logits = self.
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
             logits = logits.float()
@@ -315,6 +278,7 @@ def multimodal_forward(
             flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)
+
     if not return_dict:
         output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
@@ -325,5 +289,5 @@ def multimodal_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        image_hidden_states=
+        image_hidden_states=outputs.image_hidden_states,
     )
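The multimodal forward above now performs the `logits_to_keep` slicing itself instead of delegating it to the language model. A small self-contained illustration of that slice expression, with dummy shapes chosen only for the example:

import torch

hidden_states = torch.randn(1, 10, 4)  # (batch, seq_len, hidden_size)

logits_to_keep = 2  # keep only the last two positions
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([1, 2, 4])

logits_to_keep = 0  # `0` keeps every position, since slice(-0, None) == slice(0, None)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([1, 10, 4])

logits_to_keep = torch.tensor([3, 7])  # a 1D tensor keeps exactly those sequence positions
print(hidden_states[:, logits_to_keep, :].shape)  # torch.Size([1, 2, 4])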
@@ -26,7 +26,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -79,6 +80,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -86,28 +88,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )

-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )

     return CausalLMOutputWithPast(
@@ -37,6 +37,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
@@ -91,7 +92,15 @@ def lce_forward_deprecated(
     loss = None
     logits = None

-    if
+    # if in training mode, don't materialize logits
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()

@@ -151,7 +160,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -204,6 +214,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -214,28 +225,35 @@ def lce_forward(
     if self.config.pretraining_tp > 1:
         raise Exception("Liger Kernel does not support pretraining_tp!!")

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = lce_maybe_trainable_lm_head(
             self,
             hidden_states=kept_hidden_states,
             hidden_size=self.config.hidden_size,
             labels=labels,
             shift_labels=shift_labels,
-            **
+            **kwargs,
         )

-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )

     if not return_dict: