liger-kernel-nightly 0.5.10.dev20250526154218__py3-none-any.whl → 0.5.10.dev20250528223223__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/transformers/model/llava.py +37 -1
- liger_kernel/transformers/monkey_patch.py +3 -4
- {liger_kernel_nightly-0.5.10.dev20250526154218.dist-info → liger_kernel_nightly-0.5.10.dev20250528223223.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250526154218.dist-info → liger_kernel_nightly-0.5.10.dev20250528223223.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.5.10.dev20250526154218.dist-info → liger_kernel_nightly-0.5.10.dev20250528223223.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250526154218.dist-info → liger_kernel_nightly-0.5.10.dev20250528223223.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250526154218.dist-info → liger_kernel_nightly-0.5.10.dev20250528223223.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250526154218.dist-info → liger_kernel_nightly-0.5.10.dev20250528223223.dist-info}/top_level.txt +0 -0
--- a/liger_kernel/transformers/model/llava.py
+++ b/liger_kernel/transformers/model/llava.py
@@ -5,6 +5,7 @@ from typing import Union
 
 import torch
 
+from torch.nn import CrossEntropyLoss
 from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
 from transformers.utils import is_torchdynamo_compiling
 from transformers.utils.deprecation import deprecate_kwarg
@@ -189,7 +190,20 @@ def lce_forward_deprecated(
 
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
-
+    else:
+        logits = self.language_model.lm_head(hidden_states)
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            if attention_mask is not None:
+                shift_attention_mask = attention_mask[..., 1:]
+                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+            else:
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))
     if not return_dict:
         # NOTE: This part has not been tested.
         output = outputs[1:]
@@ -349,6 +363,28 @@ def lce_forward(
             shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
             shift_labels.view(-1).to(shift_hidden_states.device),
         )
+    else:
+        logits = self.language_model.lm_head(hidden_states)
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+            else:
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+
+            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+            loss = loss_fct(flat_logits, flat_labels)
 
     if not return_dict:
         # NOTE: This part has not been tested.
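The new `else:` branches restore a plain (non-fused) logits path for the case where the fused linear cross-entropy loss is not used, computing the loss with `torch.nn.CrossEntropyLoss` on shifted, attention-mask-filtered logits and labels. A minimal sketch of that shift-and-mask computation on toy tensors (shapes, values, and everything outside the diff are illustrative only):

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy shapes for illustration: batch=2, seq_len=5, vocab=11.
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# Shift so that tokens < n predict n, dropping padded positions via the mask.
shift_attention_mask = attention_mask[..., 1:]
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()

# Flatten the tokens and compute the standard cross-entropy loss.
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(loss.item())
```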
--- a/liger_kernel/transformers/monkey_patch.py
+++ b/liger_kernel/transformers/monkey_patch.py
@@ -776,7 +776,7 @@ def apply_liger_kernel_to_gemma3_text(
 
     from transformers.models.gemma3 import modeling_gemma3
     from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
-    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM, Gemma3TextModel
 
     from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
@@ -807,9 +807,9 @@ def apply_liger_kernel_to_gemma3_text(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if isinstance(model, Gemma3ForCausalLM):
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
             # get the base model from the model instance
-            base_model = model.model
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
 
             if rms_norm:
                 _patch_rms_norm_module_for_gemma3(base_model.norm)
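With this hunk, `apply_liger_kernel_to_gemma3_text` accepts an already-instantiated `Gemma3TextModel` in addition to `Gemma3ForCausalLM`, treating the text model itself as the base model to patch. A hedged usage sketch, assuming the function takes a `model` keyword argument as the surrounding code suggests, and using an illustrative checkpoint name:

```python
from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text

# Illustrative checkpoint; any Gemma 3 text checkpoint would do here.
model = Gemma3TextModel.from_pretrained("google/gemma-3-1b-pt")

# With the widened isinstance check, base_model resolves to the Gemma3TextModel
# itself, so its already-instantiated modules (e.g. RMSNorm) are patched in place.
apply_liger_kernel_to_gemma3_text(model=model)
```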
@@ -1625,7 +1625,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
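For context, `_apply_liger_kernel_to_instance` looks up a per-model apply function and, as the trailing context line notes, filters the incoming keyword arguments against that function's signature before calling it. A generic sketch of that filtering pattern (the `apply_fn` below is a stand-in, not the library's code):

```python
import inspect

def apply_fn(rope: bool = True, rms_norm: bool = True, model=None) -> None:
    # Stand-in for a per-model apply_liger_kernel_to_* function.
    print(f"patching: rope={rope}, rms_norm={rms_norm}")

kwargs = {"rope": False, "rms_norm": True, "not_a_real_option": 42}

# Keep only the keyword arguments the apply function actually accepts.
apply_fn_signature = inspect.signature(apply_fn)
applicable_kwargs = {k: v for k, v in kwargs.items() if k in apply_fn_signature.parameters}

apply_fn(model=None, **applicable_kwargs)  # prints: patching: rope=False, rms_norm=True
```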
--- a/liger_kernel_nightly-0.5.10.dev20250526154218.dist-info/RECORD
+++ b/liger_kernel_nightly-0.5.10.dev20250528223223.dist-info/RECORD
@@ -52,7 +52,7 @@ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-
 liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
-liger_kernel/transformers/monkey_patch.py,sha256=
+liger_kernel/transformers/monkey_patch.py,sha256=a0CXSC8BwZg3vok-ns0udZLUOBkegGQgPDod3H8ilP4,74610
 liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
 liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
@@ -69,7 +69,7 @@ liger_kernel/transformers/model/gemma2.py,sha256=JxPTXVkuFtiaZYkaBM8bZF-ObyatHmA
 liger_kernel/transformers/model/gemma3.py,sha256=JI4jj9K660HeRsofB6cpkCHBQ0OsazElArRtKUehUmw,15945
 liger_kernel/transformers/model/glm4.py,sha256=3YJiGdZ0nNSdZidPFlXdUad8mlFwyfq44yd11OcdNns,5259
 liger_kernel/transformers/model/llama.py,sha256=cAWTCY0bk67lFXNtAVEXIWl9WNgn4JyU25Q7nhpKjE0,12505
-liger_kernel/transformers/model/llava.py,sha256=
+liger_kernel/transformers/model/llava.py,sha256=ONdpx96AVbbL8QDQvHSm08jMJPz3tzkbeO92IRbAb1A,19270
 liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
 liger_kernel/transformers/model/mistral.py,sha256=vFFZD5VAwpx6Bs4gXoXDRmyU9-7Dp50w3jIcj0q0sIo,5567
 liger_kernel/transformers/model/mixtral.py,sha256=vSmgBc91WMu9_iWkAHUJPzo0-WDkTJK5SEVYNaDRT_Y,11398
@@ -86,9 +86,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
+liger_kernel_nightly-0.5.10.dev20250528223223.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.10.dev20250528223223.dist-info/METADATA,sha256=peyDncCLhsNKI0sXe4Fg-cjTiGK_5NFaM7vdiRwjaZY,24113
+liger_kernel_nightly-0.5.10.dev20250528223223.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.10.dev20250528223223.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.10.dev20250528223223.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.10.dev20250528223223.dist-info/RECORD,,