liger-kernel 0.5.10__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/ops/dyt.py +0 -2
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/layer_norm.py +126 -89
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +267 -56
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +62 -50
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +8 -0
- liger_kernel/transformers/functional.py +67 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +25 -8
- liger_kernel/transformers/model/gemma2.py +27 -8
- liger_kernel/transformers/model/gemma3.py +63 -99
- liger_kernel/transformers/model/glm4.py +16 -7
- liger_kernel/transformers/model/llama.py +25 -7
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -124
- liger_kernel/transformers/model/mistral.py +13 -8
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +16 -7
- liger_kernel/transformers/model/olmo2.py +16 -7
- liger_kernel/transformers/model/paligemma.py +8 -1
- liger_kernel/transformers/model/phi3.py +25 -8
- liger_kernel/transformers/model/qwen2.py +24 -7
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
- liger_kernel/transformers/model/qwen2_vl.py +38 -100
- liger_kernel/transformers/model/qwen3.py +11 -3
- liger_kernel/transformers/model/qwen3_moe.py +10 -6
- liger_kernel/transformers/model/smollm3.py +189 -0
- liger_kernel/transformers/monkey_patch.py +389 -82
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/METADATA +18 -14
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/RECORD +47 -37
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/top_level.txt +0 -0
|
@@ -5,11 +5,12 @@ from typing import Union
|
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
+
from torch.nn import CrossEntropyLoss
|
|
8
9
|
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
|
|
9
10
|
from transformers.utils import is_torchdynamo_compiling
|
|
10
|
-
from transformers.utils.deprecation import deprecate_kwarg
|
|
11
11
|
|
|
12
12
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
13
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
def lce_forward_deprecated(
|
|
@@ -27,6 +28,11 @@ def lce_forward_deprecated(
|
|
|
27
28
|
output_attentions: Optional[bool] = None,
|
|
28
29
|
output_hidden_states: Optional[bool] = None,
|
|
29
30
|
return_dict: Optional[bool] = None,
|
|
31
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
32
|
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
33
|
+
image_sizes: torch.Tensor = None,
|
|
34
|
+
skip_logits: Optional[bool] = None,
|
|
35
|
+
**lm_kwargs,
|
|
30
36
|
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
|
|
31
37
|
r"""
|
|
32
38
|
Args:
|
|
@@ -35,10 +41,12 @@ def lce_forward_deprecated(
|
|
|
35
41
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
36
42
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
37
43
|
|
|
38
|
-
|
|
39
|
-
|
|
44
|
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
45
|
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
40
46
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
41
47
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
48
|
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
49
|
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
42
50
|
|
|
43
51
|
|
|
44
52
|
Returns:
|
|
@@ -64,7 +72,6 @@ def lce_forward_deprecated(
|
|
|
64
72
|
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
65
73
|
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
|
|
66
74
|
```"""
|
|
67
|
-
|
|
68
75
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
69
76
|
output_hidden_states = (
|
|
70
77
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
@@ -88,73 +95,24 @@ def lce_forward_deprecated(
|
|
|
88
95
|
)
|
|
89
96
|
|
|
90
97
|
if inputs_embeds is None:
|
|
91
|
-
# 1. Extra the input embeddings
|
|
92
98
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
93
99
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
selected_image_feature = selected_image_feature[:, 1:]
|
|
102
|
-
elif vision_feature_select_strategy == "full":
|
|
103
|
-
selected_image_feature = selected_image_feature
|
|
104
|
-
else:
|
|
105
|
-
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
|
106
|
-
|
|
107
|
-
image_features = self.multi_modal_projector(selected_image_feature)
|
|
108
|
-
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
|
109
|
-
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
|
110
|
-
image_features, inputs_embeds, input_ids, attention_mask, labels
|
|
111
|
-
)
|
|
112
|
-
|
|
113
|
-
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
|
114
|
-
# generation with cache
|
|
115
|
-
elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
|
116
|
-
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
|
117
|
-
# that are set to 0
|
|
118
|
-
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
|
119
|
-
|
|
120
|
-
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
|
121
|
-
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
|
122
|
-
|
|
123
|
-
# Get the target length
|
|
124
|
-
target_length = input_ids.shape[1]
|
|
125
|
-
past_length = first_layer_past_key_value.shape[-1]
|
|
126
|
-
|
|
127
|
-
extended_attention_mask = torch.ones(
|
|
128
|
-
(attention_mask.shape[0], past_length),
|
|
129
|
-
dtype=attention_mask.dtype,
|
|
130
|
-
device=attention_mask.device,
|
|
131
|
-
)
|
|
132
|
-
|
|
133
|
-
# Filter out only the tokens that can be un-attended, this can happen
|
|
134
|
-
# if one uses Llava + Fused modules where the cache on the
|
|
135
|
-
# first iteration is already big enough, or if one passes custom cache
|
|
136
|
-
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
|
137
|
-
new_batch_index = batch_index[valid_indices]
|
|
138
|
-
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
|
139
|
-
|
|
140
|
-
# Zero-out the places where we don't need to attend
|
|
141
|
-
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
|
142
|
-
|
|
143
|
-
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
|
144
|
-
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
|
145
|
-
|
|
146
|
-
# TODO: @raushan retain only the new behavior after v4.47
|
|
147
|
-
elif image_features is not None:
|
|
148
|
-
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
|
|
149
|
-
n_image_features = image_features.shape[0] * image_features.shape[1]
|
|
100
|
+
if pixel_values is not None:
|
|
101
|
+
image_features = self.get_image_features(
|
|
102
|
+
pixel_values=pixel_values,
|
|
103
|
+
vision_feature_layer=vision_feature_layer,
|
|
104
|
+
vision_feature_select_strategy=vision_feature_select_strategy,
|
|
105
|
+
image_sizes=image_sizes,
|
|
106
|
+
)
|
|
150
107
|
|
|
151
|
-
|
|
108
|
+
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
|
109
|
+
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
110
|
+
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
|
111
|
+
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
|
112
|
+
n_image_features = image_features.shape[0] * image_features.shape[1]
|
|
152
113
|
raise ValueError(
|
|
153
114
|
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
|
154
115
|
)
|
|
155
|
-
special_image_mask = (
|
|
156
|
-
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
157
|
-
)
|
|
158
116
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
159
117
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
|
160
118
|
|
|
@@ -167,13 +125,19 @@ def lce_forward_deprecated(
|
|
|
167
125
|
output_attentions=output_attentions,
|
|
168
126
|
output_hidden_states=output_hidden_states,
|
|
169
127
|
return_dict=return_dict,
|
|
128
|
+
cache_position=cache_position,
|
|
129
|
+
logits_to_keep=logits_to_keep,
|
|
130
|
+
**lm_kwargs,
|
|
170
131
|
)
|
|
171
132
|
hidden_states = outputs[0]
|
|
172
133
|
|
|
173
134
|
loss = None
|
|
174
135
|
logits = None
|
|
175
136
|
|
|
176
|
-
|
|
137
|
+
# Overwrite skip_logits, since llava never materializes logits
|
|
138
|
+
skip_logits = labels is not None
|
|
139
|
+
|
|
140
|
+
if skip_logits:
|
|
177
141
|
# Shift so that tokens < n predict n
|
|
178
142
|
if attention_mask is not None:
|
|
179
143
|
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
|
@@ -188,7 +152,33 @@ def lce_forward_deprecated(
|
|
|
188
152
|
shift_labels = labels[..., 1:].contiguous()
|
|
189
153
|
|
|
190
154
|
lce = LigerFusedLinearCrossEntropyLoss()
|
|
191
|
-
loss = lce(
|
|
155
|
+
loss = lce(
|
|
156
|
+
self.language_model.lm_head.weight,
|
|
157
|
+
shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
|
|
158
|
+
shift_labels.view(-1).to(shift_hidden_states.device),
|
|
159
|
+
)
|
|
160
|
+
else:
|
|
161
|
+
logits = self.language_model.lm_head(hidden_states)
|
|
162
|
+
if labels is not None:
|
|
163
|
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
|
164
|
+
logits = logits.float()
|
|
165
|
+
shift_logits = logits[..., :-1, :]
|
|
166
|
+
shift_labels = labels[..., 1:]
|
|
167
|
+
if attention_mask is not None:
|
|
168
|
+
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
|
169
|
+
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
|
170
|
+
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
|
|
171
|
+
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
|
|
172
|
+
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
|
|
173
|
+
else:
|
|
174
|
+
shift_logits = shift_logits.contiguous()
|
|
175
|
+
shift_labels = shift_labels.contiguous()
|
|
176
|
+
# Flatten the tokens
|
|
177
|
+
loss_fct = CrossEntropyLoss()
|
|
178
|
+
|
|
179
|
+
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
|
180
|
+
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
|
181
|
+
loss = loss_fct(flat_logits, flat_labels)
|
|
192
182
|
|
|
193
183
|
if not return_dict:
|
|
194
184
|
# NOTE: This part has not been tested.
|
|
@@ -201,10 +191,10 @@ def lce_forward_deprecated(
|
|
|
201
191
|
past_key_values=outputs.past_key_values,
|
|
202
192
|
hidden_states=outputs.hidden_states,
|
|
203
193
|
attentions=outputs.attentions,
|
|
194
|
+
image_hidden_states=image_features if pixel_values is not None else None,
|
|
204
195
|
)
|
|
205
196
|
|
|
206
197
|
|
|
207
|
-
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
208
198
|
def lce_forward(
|
|
209
199
|
self,
|
|
210
200
|
input_ids: torch.LongTensor = None,
|
|
@@ -223,6 +213,7 @@ def lce_forward(
|
|
|
223
213
|
cache_position: Optional[torch.LongTensor] = None,
|
|
224
214
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
225
215
|
image_sizes: torch.Tensor = None,
|
|
216
|
+
skip_logits: Optional[bool] = None,
|
|
226
217
|
**lm_kwargs,
|
|
227
218
|
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
|
|
228
219
|
r"""
|
|
@@ -277,78 +268,58 @@ def lce_forward(
|
|
|
277
268
|
else self.config.vision_feature_select_strategy
|
|
278
269
|
)
|
|
279
270
|
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
if pixel_values is not None and inputs_embeds is not None:
|
|
284
|
-
raise ValueError(
|
|
285
|
-
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
|
286
|
-
)
|
|
287
|
-
|
|
288
|
-
if inputs_embeds is None:
|
|
289
|
-
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
290
|
-
|
|
291
|
-
if pixel_values is not None:
|
|
292
|
-
image_features = self.get_image_features(
|
|
293
|
-
pixel_values=pixel_values,
|
|
294
|
-
vision_feature_layer=vision_feature_layer,
|
|
295
|
-
vision_feature_select_strategy=vision_feature_select_strategy,
|
|
296
|
-
image_sizes=image_sizes,
|
|
297
|
-
)
|
|
298
|
-
|
|
299
|
-
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
|
300
|
-
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
|
301
|
-
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
|
302
|
-
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
|
303
|
-
n_image_features = image_features.shape[0] * image_features.shape[1]
|
|
304
|
-
raise ValueError(
|
|
305
|
-
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
|
306
|
-
)
|
|
307
|
-
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
308
|
-
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
|
309
|
-
|
|
310
|
-
outputs = self.language_model.model(
|
|
271
|
+
outputs = self.model(
|
|
272
|
+
input_ids=input_ids,
|
|
273
|
+
pixel_values=pixel_values,
|
|
311
274
|
attention_mask=attention_mask,
|
|
312
275
|
position_ids=position_ids,
|
|
313
276
|
past_key_values=past_key_values,
|
|
314
277
|
inputs_embeds=inputs_embeds,
|
|
278
|
+
vision_feature_layer=vision_feature_layer,
|
|
279
|
+
vision_feature_select_strategy=vision_feature_select_strategy,
|
|
315
280
|
use_cache=use_cache,
|
|
316
281
|
output_attentions=output_attentions,
|
|
317
282
|
output_hidden_states=output_hidden_states,
|
|
318
|
-
return_dict=
|
|
283
|
+
return_dict=True,
|
|
319
284
|
cache_position=cache_position,
|
|
320
|
-
|
|
285
|
+
image_sizes=image_sizes,
|
|
321
286
|
**lm_kwargs,
|
|
322
287
|
)
|
|
323
288
|
hidden_states = outputs[0]
|
|
289
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
290
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
291
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
324
292
|
|
|
325
|
-
|
|
293
|
+
shift_labels = lm_kwargs.pop("shift_labels", None)
|
|
326
294
|
logits = None
|
|
295
|
+
loss = None
|
|
327
296
|
|
|
328
|
-
if
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
shift_labels
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
loss = lce(
|
|
344
|
-
self.language_model.lm_head.weight,
|
|
345
|
-
shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
|
|
346
|
-
shift_labels.view(-1).to(shift_hidden_states.device),
|
|
297
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
298
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
299
|
+
|
|
300
|
+
if skip_logits is None:
|
|
301
|
+
# By default, if in training mode, don't materialize logits
|
|
302
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
303
|
+
|
|
304
|
+
if skip_logits:
|
|
305
|
+
loss = LigerForCausalLMLoss(
|
|
306
|
+
hidden_states=kept_hidden_states,
|
|
307
|
+
lm_head_weight=self.lm_head.weight,
|
|
308
|
+
labels=labels,
|
|
309
|
+
shift_labels=shift_labels,
|
|
310
|
+
hidden_size=self.config.text_config.hidden_size,
|
|
311
|
+
**lm_kwargs,
|
|
347
312
|
)
|
|
348
313
|
|
|
314
|
+
else:
|
|
315
|
+
logits = self.lm_head(kept_hidden_states)
|
|
316
|
+
if labels is not None:
|
|
317
|
+
loss = self.loss_function(
|
|
318
|
+
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
|
|
319
|
+
)
|
|
320
|
+
|
|
349
321
|
if not return_dict:
|
|
350
|
-
|
|
351
|
-
output = outputs[1:]
|
|
322
|
+
output = (logits,) + outputs[1:]
|
|
352
323
|
return (loss,) + output if loss is not None else output
|
|
353
324
|
|
|
354
325
|
return LlavaCausalLMOutputWithPast(
|
|
@@ -357,5 +328,5 @@ def lce_forward(
|
|
|
357
328
|
past_key_values=outputs.past_key_values,
|
|
358
329
|
hidden_states=outputs.hidden_states,
|
|
359
330
|
attentions=outputs.attentions,
|
|
360
|
-
image_hidden_states=
|
|
331
|
+
image_hidden_states=outputs.image_hidden_states,
|
|
361
332
|
)
|
|
@@ -27,7 +27,8 @@ def lce_forward(
|
|
|
27
27
|
return_dict: Optional[bool] = None,
|
|
28
28
|
cache_position: Optional[torch.LongTensor] = None,
|
|
29
29
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
30
|
-
|
|
30
|
+
skip_logits: Optional[bool] = None,
|
|
31
|
+
**kwargs,
|
|
31
32
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
32
33
|
r"""
|
|
33
34
|
Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
|
|
@@ -82,6 +83,7 @@ def lce_forward(
|
|
|
82
83
|
output_hidden_states=output_hidden_states,
|
|
83
84
|
return_dict=return_dict,
|
|
84
85
|
cache_position=cache_position,
|
|
86
|
+
**kwargs,
|
|
85
87
|
)
|
|
86
88
|
|
|
87
89
|
hidden_states = outputs[0]
|
|
@@ -89,18 +91,24 @@ def lce_forward(
|
|
|
89
91
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
90
92
|
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
91
93
|
|
|
92
|
-
shift_labels =
|
|
94
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
93
95
|
loss = None
|
|
94
96
|
logits = None
|
|
95
97
|
|
|
96
|
-
if
|
|
98
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
99
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
100
|
+
|
|
101
|
+
if skip_logits is None:
|
|
102
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
103
|
+
|
|
104
|
+
if skip_logits:
|
|
97
105
|
loss = LigerForCausalLMLoss(
|
|
98
106
|
hidden_states=kept_hidden_states,
|
|
99
107
|
lm_head_weight=self.lm_head.weight,
|
|
100
108
|
labels=labels,
|
|
101
109
|
shift_labels=shift_labels,
|
|
102
110
|
hidden_size=self.config.hidden_size,
|
|
103
|
-
**
|
|
111
|
+
**kwargs,
|
|
104
112
|
)
|
|
105
113
|
|
|
106
114
|
else:
|
|
@@ -112,7 +120,7 @@ def lce_forward(
|
|
|
112
120
|
logits=logits,
|
|
113
121
|
labels=labels,
|
|
114
122
|
vocab_size=self.config.vocab_size,
|
|
115
|
-
**
|
|
123
|
+
**kwargs,
|
|
116
124
|
)
|
|
117
125
|
if not return_dict:
|
|
118
126
|
output = (logits,) + outputs[1:]
|
|
@@ -125,6 +133,3 @@ def lce_forward(
|
|
|
125
133
|
hidden_states=outputs.hidden_states,
|
|
126
134
|
attentions=outputs.attentions,
|
|
127
135
|
)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
# Note: Grad Acc is not fixed in mistral at transformer 4.46.1
|
|
@@ -156,7 +156,8 @@ def lce_forward(
|
|
|
156
156
|
return_dict: Optional[bool] = None,
|
|
157
157
|
cache_position: Optional[torch.LongTensor] = None,
|
|
158
158
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
159
|
-
|
|
159
|
+
skip_logits: Optional[bool] = None,
|
|
160
|
+
**kwargs,
|
|
160
161
|
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
|
|
161
162
|
r"""
|
|
162
163
|
Args:
|
|
@@ -214,6 +215,7 @@ def lce_forward(
|
|
|
214
215
|
output_router_logits=output_router_logits,
|
|
215
216
|
return_dict=return_dict,
|
|
216
217
|
cache_position=cache_position,
|
|
218
|
+
**kwargs,
|
|
217
219
|
)
|
|
218
220
|
|
|
219
221
|
hidden_states = outputs[0]
|
|
@@ -221,26 +223,33 @@ def lce_forward(
|
|
|
221
223
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
222
224
|
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
223
225
|
|
|
224
|
-
shift_labels =
|
|
226
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
225
227
|
logits = None
|
|
226
228
|
loss = None
|
|
227
|
-
|
|
228
|
-
if
|
|
229
|
+
|
|
230
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
231
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
232
|
+
|
|
233
|
+
if skip_logits is None:
|
|
234
|
+
# By default, if in training mode, don't materialize logits
|
|
235
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
236
|
+
|
|
237
|
+
if skip_logits:
|
|
229
238
|
loss = LigerForCausalLMLoss(
|
|
230
239
|
hidden_states=kept_hidden_states,
|
|
231
240
|
lm_head_weight=self.lm_head.weight,
|
|
232
241
|
labels=labels,
|
|
233
242
|
shift_labels=shift_labels,
|
|
234
243
|
hidden_size=self.config.hidden_size,
|
|
235
|
-
**
|
|
244
|
+
**kwargs,
|
|
236
245
|
)
|
|
237
246
|
|
|
238
|
-
else:
|
|
247
|
+
else:
|
|
239
248
|
logits = self.lm_head(kept_hidden_states)
|
|
240
249
|
|
|
241
250
|
loss = None
|
|
242
251
|
if labels is not None:
|
|
243
|
-
loss = self.loss_function(logits, labels, self.vocab_size, **
|
|
252
|
+
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
|
244
253
|
aux_loss = None
|
|
245
254
|
if output_router_logits:
|
|
246
255
|
aux_loss = load_balancing_loss_func(
|
|
@@ -147,7 +147,8 @@ def lce_forward(
|
|
|
147
147
|
return_dict: Optional[bool] = None,
|
|
148
148
|
cache_position: Optional[torch.LongTensor] = None,
|
|
149
149
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
150
|
-
|
|
150
|
+
skip_logits: Optional[bool] = None,
|
|
151
|
+
**kwargs,
|
|
151
152
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
152
153
|
r"""
|
|
153
154
|
Args:
|
|
@@ -205,6 +206,7 @@ def lce_forward(
|
|
|
205
206
|
output_hidden_states=output_hidden_states,
|
|
206
207
|
return_dict=return_dict,
|
|
207
208
|
cache_position=cache_position,
|
|
209
|
+
**kwargs,
|
|
208
210
|
)
|
|
209
211
|
|
|
210
212
|
hidden_states = outputs[0]
|
|
@@ -212,28 +214,35 @@ def lce_forward(
|
|
|
212
214
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
213
215
|
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
214
216
|
|
|
215
|
-
shift_labels =
|
|
217
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
216
218
|
logits = None
|
|
217
219
|
loss = None
|
|
218
|
-
|
|
219
|
-
if
|
|
220
|
+
|
|
221
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
222
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
223
|
+
|
|
224
|
+
if skip_logits is None:
|
|
225
|
+
# By default, if in training mode, don't materialize logits
|
|
226
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
227
|
+
|
|
228
|
+
if skip_logits:
|
|
220
229
|
loss = LigerForCausalLMLoss(
|
|
221
230
|
hidden_states=kept_hidden_states,
|
|
222
231
|
lm_head_weight=self.lm_head.weight,
|
|
223
232
|
labels=labels,
|
|
224
233
|
shift_labels=shift_labels,
|
|
225
234
|
hidden_size=self.config.hidden_size,
|
|
226
|
-
**
|
|
235
|
+
**kwargs,
|
|
227
236
|
)
|
|
228
237
|
|
|
229
|
-
else:
|
|
238
|
+
else:
|
|
230
239
|
logits = self.lm_head(kept_hidden_states)
|
|
231
240
|
if labels is not None:
|
|
232
241
|
loss = self.loss_function(
|
|
233
242
|
logits=logits,
|
|
234
243
|
labels=labels,
|
|
235
244
|
vocab_size=self.config.vocab_size,
|
|
236
|
-
**
|
|
245
|
+
**kwargs,
|
|
237
246
|
)
|
|
238
247
|
|
|
239
248
|
if not return_dict:
|
|
@@ -26,7 +26,8 @@ def lce_forward(
|
|
|
26
26
|
return_dict: Optional[bool] = None,
|
|
27
27
|
cache_position: Optional[torch.LongTensor] = None,
|
|
28
28
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
29
|
-
|
|
29
|
+
skip_logits: Optional[bool] = None,
|
|
30
|
+
**kwargs,
|
|
30
31
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
31
32
|
r"""
|
|
32
33
|
Args:
|
|
@@ -79,6 +80,7 @@ def lce_forward(
|
|
|
79
80
|
output_hidden_states=output_hidden_states,
|
|
80
81
|
return_dict=return_dict,
|
|
81
82
|
cache_position=cache_position,
|
|
83
|
+
**kwargs,
|
|
82
84
|
)
|
|
83
85
|
|
|
84
86
|
hidden_states = outputs[0]
|
|
@@ -86,28 +88,35 @@ def lce_forward(
|
|
|
86
88
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
87
89
|
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
88
90
|
|
|
89
|
-
shift_labels =
|
|
91
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
90
92
|
logits = None
|
|
91
93
|
loss = None
|
|
92
|
-
|
|
93
|
-
if
|
|
94
|
+
|
|
95
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
96
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
97
|
+
|
|
98
|
+
if skip_logits is None:
|
|
99
|
+
# By default, if in training mode, don't materialize logits
|
|
100
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
101
|
+
|
|
102
|
+
if skip_logits:
|
|
94
103
|
loss = LigerForCausalLMLoss(
|
|
95
104
|
hidden_states=kept_hidden_states,
|
|
96
105
|
lm_head_weight=self.lm_head.weight,
|
|
97
106
|
labels=labels,
|
|
98
107
|
shift_labels=shift_labels,
|
|
99
108
|
hidden_size=self.config.hidden_size,
|
|
100
|
-
**
|
|
109
|
+
**kwargs,
|
|
101
110
|
)
|
|
102
111
|
|
|
103
|
-
else:
|
|
112
|
+
else:
|
|
104
113
|
logits = self.lm_head(kept_hidden_states)
|
|
105
114
|
if labels is not None:
|
|
106
115
|
loss = self.loss_function(
|
|
107
116
|
logits=logits,
|
|
108
117
|
labels=labels,
|
|
109
118
|
vocab_size=self.config.vocab_size,
|
|
110
|
-
**
|
|
119
|
+
**kwargs,
|
|
111
120
|
)
|
|
112
121
|
|
|
113
122
|
return CausalLMOutputWithPast(
|
|
@@ -216,6 +216,7 @@ def lce_forward(
|
|
|
216
216
|
output_hidden_states: Optional[bool] = None,
|
|
217
217
|
return_dict: Optional[bool] = None,
|
|
218
218
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
219
|
+
skip_logits: Optional[bool] = None,
|
|
219
220
|
**lm_kwargs,
|
|
220
221
|
) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
|
|
221
222
|
r"""
|
|
@@ -331,7 +332,13 @@ def lce_forward(
|
|
|
331
332
|
loss = None
|
|
332
333
|
logits = None
|
|
333
334
|
|
|
334
|
-
if
|
|
335
|
+
if skip_logits and labels is None:
|
|
336
|
+
raise ValueError("skip_logits is True, but labels is None")
|
|
337
|
+
|
|
338
|
+
if skip_logits is None:
|
|
339
|
+
skip_logits = self.training and (labels is not None)
|
|
340
|
+
|
|
341
|
+
if skip_logits:
|
|
335
342
|
shift_hidden_states = hidden_states[..., :-1, :]
|
|
336
343
|
shift_labels = labels[..., 1:]
|
|
337
344
|
|
|
@@ -26,6 +26,7 @@ def lce_forward_deprecated(
|
|
|
26
26
|
output_hidden_states: Optional[bool] = None,
|
|
27
27
|
return_dict: Optional[bool] = None,
|
|
28
28
|
cache_position: Optional[torch.LongTensor] = None,
|
|
29
|
+
skip_logits: Optional[bool] = None,
|
|
29
30
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
30
31
|
r"""
|
|
31
32
|
Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
|
|
@@ -80,7 +81,14 @@ def lce_forward_deprecated(
|
|
|
80
81
|
loss = None
|
|
81
82
|
logits = None
|
|
82
83
|
|
|
83
|
-
if
|
|
84
|
+
if skip_logits and labels is None:
|
|
85
|
+
raise ValueError("skip_logits is True, but labels is None")
|
|
86
|
+
|
|
87
|
+
if skip_logits is None:
|
|
88
|
+
# By default, if in training mode, don't materialize logits
|
|
89
|
+
skip_logits = self.training and labels is not None
|
|
90
|
+
|
|
91
|
+
if skip_logits:
|
|
84
92
|
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
|
|
85
93
|
shift_labels = labels[..., 1:].contiguous()
|
|
86
94
|
|
|
@@ -136,7 +144,8 @@ def lce_forward(
|
|
|
136
144
|
return_dict: Optional[bool] = None,
|
|
137
145
|
cache_position: Optional[torch.LongTensor] = None,
|
|
138
146
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
139
|
-
|
|
147
|
+
skip_logits: Optional[bool] = None,
|
|
148
|
+
**kwargs,
|
|
140
149
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
141
150
|
r"""
|
|
142
151
|
Args:
|
|
@@ -202,6 +211,7 @@ def lce_forward(
|
|
|
202
211
|
output_attentions=output_attentions,
|
|
203
212
|
output_hidden_states=output_hidden_states,
|
|
204
213
|
return_dict=return_dict,
|
|
214
|
+
**kwargs,
|
|
205
215
|
)
|
|
206
216
|
|
|
207
217
|
hidden_states = outputs[0]
|
|
@@ -209,28 +219,35 @@ def lce_forward(
|
|
|
209
219
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
210
220
|
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
211
221
|
|
|
212
|
-
shift_labels =
|
|
222
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
213
223
|
logits = None
|
|
214
224
|
loss = None
|
|
215
|
-
|
|
216
|
-
if
|
|
225
|
+
|
|
226
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
227
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
228
|
+
|
|
229
|
+
if skip_logits is None:
|
|
230
|
+
# By default, if in training mode, don't materialize logits
|
|
231
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
232
|
+
|
|
233
|
+
if skip_logits:
|
|
217
234
|
loss = LigerForCausalLMLoss(
|
|
218
235
|
hidden_states=kept_hidden_states,
|
|
219
236
|
lm_head_weight=self.lm_head.weight,
|
|
220
237
|
labels=labels,
|
|
221
238
|
shift_labels=shift_labels,
|
|
222
239
|
hidden_size=self.config.hidden_size,
|
|
223
|
-
**
|
|
240
|
+
**kwargs,
|
|
224
241
|
)
|
|
225
242
|
|
|
226
|
-
else:
|
|
243
|
+
else:
|
|
227
244
|
logits = self.lm_head(kept_hidden_states)
|
|
228
245
|
if labels is not None:
|
|
229
246
|
loss = self.loss_function(
|
|
230
247
|
logits=logits,
|
|
231
248
|
labels=labels,
|
|
232
249
|
vocab_size=self.config.vocab_size,
|
|
233
|
-
**
|
|
250
|
+
**kwargs,
|
|
234
251
|
)
|
|
235
252
|
|
|
236
253
|
if not return_dict:
|