liger-kernel-nightly 0.5.10.dev20250526154149-py3-none-any.whl → 0.5.10.dev20250527002824-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/transformers/model/llava.py +37 -1
- {liger_kernel_nightly-0.5.10.dev20250526154149.dist-info → liger_kernel_nightly-0.5.10.dev20250527002824.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250526154149.dist-info → liger_kernel_nightly-0.5.10.dev20250527002824.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.10.dev20250526154149.dist-info → liger_kernel_nightly-0.5.10.dev20250527002824.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250526154149.dist-info → liger_kernel_nightly-0.5.10.dev20250527002824.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250526154149.dist-info → liger_kernel_nightly-0.5.10.dev20250527002824.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250526154149.dist-info → liger_kernel_nightly-0.5.10.dev20250527002824.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/llava.py
@@ -5,6 +5,7 @@ from typing import Union
 
 import torch
 
+from torch.nn import CrossEntropyLoss
 from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
 from transformers.utils import is_torchdynamo_compiling
 from transformers.utils.deprecation import deprecate_kwarg
@@ -189,7 +190,20 @@ def lce_forward_deprecated(
 
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
-
+    else:
+        logits = self.language_model.lm_head(hidden_states)
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            if attention_mask is not None:
+                shift_attention_mask = attention_mask[..., 1:]
+                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+            else:
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))
     if not return_dict:
         # NOTE: This part has not been tested.
         output = outputs[1:]
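Taken on its own, the else branch added here restores the stock eager-mode loss: logits come from lm_head, are shifted one position so that tokens < n predict token n, padding positions are dropped through the attention mask, and a plain CrossEntropyLoss is applied. A minimal standalone sketch of that pattern (the helper name and tensor sizes are illustrative, not part of the patch):

import torch
from torch.nn import CrossEntropyLoss

def shifted_masked_ce(logits, labels, attention_mask=None):
    # logits: (batch, seq_len, vocab_size); labels: (batch, seq_len)
    if attention_mask is not None:
        # Align the mask with the shifted targets by dropping its first position
        shift_attention_mask = attention_mask[..., 1:]
        shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
        shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
    else:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss()
    return loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1).to(shift_logits.device),
    )

logits = torch.randn(2, 5, 11)         # batch of 2, 5 tokens, vocab of 11
labels = torch.randint(0, 11, (2, 5))
mask = torch.tensor([[1, 1, 1, 1, 0],  # last position of sample 1 is padding
                     [1, 1, 1, 1, 1]])
print(shifted_masked_ce(logits, labels, mask))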
@@ -349,6 +363,28 @@ def lce_forward(
 
             shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
             shift_labels.view(-1).to(shift_hidden_states.device),
         )
+    else:
+        logits = self.language_model.lm_head(hidden_states)
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+            else:
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+
+            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+            loss = loss_fct(flat_logits, flat_labels)
 
     if not return_dict:
         # NOTE: This part has not been tested.
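The point of guarding this with an else is memory: the Liger branch hands the flattened hidden states and the lm_head weight straight to LigerFusedLinearCrossEntropyLoss, which avoids materializing the full (tokens, vocab_size) logits tensor, while the new fallback materializes it before applying CrossEntropyLoss. A rough sketch of the two computations side by side (sizes are made up, and the fused loss assumes a CUDA device with liger-kernel installed):

import torch
from torch.nn import CrossEntropyLoss
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

hidden = torch.randn(8, 4096, device="cuda", requires_grad=True)  # (tokens, hidden_size)
lm_head_weight = torch.randn(32000, 4096, device="cuda")          # (vocab_size, hidden_size)
targets = torch.randint(0, 32000, (8,), device="cuda")

# Fused path, as in the Liger branch: (weight, input, target) argument order,
# matching lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
lce = LigerFusedLinearCrossEntropyLoss()
loss_fused = lce(lm_head_weight, hidden, targets)

# Fallback path, as in the new else branch: explicit projection, then eager CE
logits = (hidden @ lm_head_weight.T).float()
loss_eager = CrossEntropyLoss()(logits, targets)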
{liger_kernel_nightly-0.5.10.dev20250526154149.dist-info → liger_kernel_nightly-0.5.10.dev20250527002824.dist-info}/RECORD
@@ -69,7 +69,7 @@ liger_kernel/transformers/model/gemma2.py,sha256=JxPTXVkuFtiaZYkaBM8bZF-ObyatHmA
 liger_kernel/transformers/model/gemma3.py,sha256=JI4jj9K660HeRsofB6cpkCHBQ0OsazElArRtKUehUmw,15945
 liger_kernel/transformers/model/glm4.py,sha256=3YJiGdZ0nNSdZidPFlXdUad8mlFwyfq44yd11OcdNns,5259
 liger_kernel/transformers/model/llama.py,sha256=cAWTCY0bk67lFXNtAVEXIWl9WNgn4JyU25Q7nhpKjE0,12505
-liger_kernel/transformers/model/llava.py,sha256=
+liger_kernel/transformers/model/llava.py,sha256=ONdpx96AVbbL8QDQvHSm08jMJPz3tzkbeO92IRbAb1A,19270
 liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
 liger_kernel/transformers/model/mistral.py,sha256=vFFZD5VAwpx6Bs4gXoXDRmyU9-7Dp50w3jIcj0q0sIo,5567
 liger_kernel/transformers/model/mixtral.py,sha256=vSmgBc91WMu9_iWkAHUJPzo0-WDkTJK5SEVYNaDRT_Y,11398
@@ -86,9 +86,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
+liger_kernel_nightly-0.5.10.dev20250527002824.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.10.dev20250527002824.dist-info/METADATA,sha256=HqGgAlbm4h44NFbWrGIOmYBHC04k4wADYVhjp1WnyFQ,24113
+liger_kernel_nightly-0.5.10.dev20250527002824.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.10.dev20250527002824.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.10.dev20250527002824.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.10.dev20250527002824.dist-info/RECORD,,
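For reference, each RECORD entry above has the form path,sha256=<digest>,<size>, where the digest is the unpadded URL-safe base64 encoding of the file's SHA-256 hash. A small sketch of recomputing such a digest for comparison (the path is illustrative):

import base64
import hashlib

def record_digest(path: str) -> str:
    # Hash the file and encode as unpadded URL-safe base64, as in RECORD
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# Compare against the entry above, e.g. expect ONdpx96AVbbL8QDQvHSm08jMJPz3tzkbeO92IRbAb1A
print(record_digest("liger_kernel/transformers/model/llava.py"))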