liger-kernel-nightly 0.6.2.dev20250822031344__py3-none-any.whl → 0.6.2.dev20250826142826__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/jsd_loss.py +5 -2
- liger_kernel/ops/fused_linear_cross_entropy.py +15 -2
- liger_kernel/transformers/model/loss_utils.py +1 -0
- {liger_kernel_nightly-0.6.2.dev20250822031344.dist-info → liger_kernel_nightly-0.6.2.dev20250826142826.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.2.dev20250822031344.dist-info → liger_kernel_nightly-0.6.2.dev20250826142826.dist-info}/RECORD +9 -9
- {liger_kernel_nightly-0.6.2.dev20250822031344.dist-info → liger_kernel_nightly-0.6.2.dev20250826142826.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250822031344.dist-info → liger_kernel_nightly-0.6.2.dev20250826142826.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250822031344.dist-info → liger_kernel_nightly-0.6.2.dev20250826142826.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20250822031344.dist-info → liger_kernel_nightly-0.6.2.dev20250826142826.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,5 @@
|
|
1
|
+
import math
|
2
|
+
|
1
3
|
import torch
|
2
4
|
import torch.nn.functional as F
|
3
5
|
|
@@ -25,8 +27,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
|
|
25
27
|
jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
|
26
28
|
else:
|
27
29
|
# Compute probabilities (only required for mean calculation)
|
28
|
-
|
29
|
-
|
30
|
+
log_mean_probs = torch.logsumexp(
|
31
|
+
torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
|
32
|
+
)
|
30
33
|
|
31
34
|
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
|
32
35
|
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
|
@@ -101,8 +101,21 @@ def fused_linear_cross_entropy_forward(
|
|
101
101
|
# Compute softmax to get predicted probabilities
|
102
102
|
probs = torch.softmax(logits_for_softmax, dim=-1)
|
103
103
|
|
104
|
-
# Get
|
105
|
-
|
104
|
+
# Get predicted probabilities for token scaling, handling ignored targets
|
105
|
+
valid_target_mask = target_chunk != ignore_index
|
106
|
+
valid_targets = target_chunk[valid_target_mask]
|
107
|
+
|
108
|
+
if len(valid_targets) > 0:
|
109
|
+
# Gather probabilities only for valid targets
|
110
|
+
valid_probs = probs[valid_target_mask]
|
111
|
+
pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
|
112
|
+
|
113
|
+
# Create full tensor with zeros for ignored targets
|
114
|
+
pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
|
115
|
+
pred_probs[valid_target_mask] = pred_probs_valid
|
116
|
+
else:
|
117
|
+
# All targets are ignored
|
118
|
+
pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
|
106
119
|
|
107
120
|
# Store the scaling factors
|
108
121
|
scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow
|
@@ -12,7 +12,7 @@ liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSn
|
|
12
12
|
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=FIH85uUXAOgYx5Ax8MjFhJHVu-2pKtY7wSegd0zSyyY,18336
|
13
13
|
liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
|
14
14
|
liger_kernel/chunked_loss/grpo_loss.py,sha256=kuqHkYV383sUxqJN-DMsfADHi2hxHVyKx5S24TNc8bQ,10866
|
15
|
-
liger_kernel/chunked_loss/jsd_loss.py,sha256=
|
15
|
+
liger_kernel/chunked_loss/jsd_loss.py,sha256=gRhnmB8xwuz7FcMJi5v5eyBsq01owaCbcyyrF4rYtY0,7133
|
16
16
|
liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
|
17
17
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
|
18
18
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
|
@@ -20,7 +20,7 @@ liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
|
|
20
20
|
liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCRpQKM,19123
|
21
21
|
liger_kernel/ops/dyt.py,sha256=gCLz4S8aul8SY9nvIGaoK67aGb7U9MJRQdo3ONqmQYs,5417
|
22
22
|
liger_kernel/ops/fused_add_rms_norm.py,sha256=UBqmlqFCmhSAIpkNKd8rrfXatX7Z4J9bp2dX9A0lrJQ,14017
|
23
|
-
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
|
23
|
+
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=6rB3pdwU97Ivl2IHndPJjzhP28E9Fd0pUQcPHLiuCjc,14290
|
24
24
|
liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
|
25
25
|
liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
|
26
26
|
liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
|
@@ -79,7 +79,7 @@ liger_kernel/transformers/model/glm4v.py,sha256=zbV3agptEYpGAD0eeCRwIpJAhJUviTT5
|
|
79
79
|
liger_kernel/transformers/model/llama.py,sha256=i8jJgyZsMKWQ-zKloETLugtwFpUOdaWxLDceciFXKd4,12832
|
80
80
|
liger_kernel/transformers/model/llama4.py,sha256=IgbB8sTh3dlETQnaNNy1bZLuXy-Nt7qmeAjF27ydGpg,4210
|
81
81
|
liger_kernel/transformers/model/llava.py,sha256=bLCioday_SOm69ogMDBhy_4UsVkH2-BSl93-EXY6-7I,15076
|
82
|
-
liger_kernel/transformers/model/loss_utils.py,sha256=
|
82
|
+
liger_kernel/transformers/model/loss_utils.py,sha256=02RVkPI7Qs4ZP4yU_udCAvD_2hgIaHmxremRKe3N7EE,1885
|
83
83
|
liger_kernel/transformers/model/mistral.py,sha256=syYNL8dLThX2-4uC13Lu0krEZ5zw3InviDUR3AJmc-I,5500
|
84
84
|
liger_kernel/transformers/model/mixtral.py,sha256=VY-y73IyjcCyWyI7ahxXLw0fJrhgjYfr1xwRYtsHX0o,11396
|
85
85
|
liger_kernel/transformers/model/mllama.py,sha256=NhJtlXiuszJHo5YSJOvSGYH47ly7Hse8r-5BKznBg9s,11522
|
@@ -96,9 +96,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
96
96
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
97
97
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
98
98
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
99
|
-
liger_kernel_nightly-0.6.2.
|
100
|
-
liger_kernel_nightly-0.6.2.
|
101
|
-
liger_kernel_nightly-0.6.2.
|
102
|
-
liger_kernel_nightly-0.6.2.
|
103
|
-
liger_kernel_nightly-0.6.2.
|
104
|
-
liger_kernel_nightly-0.6.2.
|
99
|
+
liger_kernel_nightly-0.6.2.dev20250826142826.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
100
|
+
liger_kernel_nightly-0.6.2.dev20250826142826.dist-info/METADATA,sha256=6CBXUT-5ztpSjUlafCBwOBVsFxMpb8pkwBCyPFscKIE,24504
|
101
|
+
liger_kernel_nightly-0.6.2.dev20250826142826.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
102
|
+
liger_kernel_nightly-0.6.2.dev20250826142826.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
103
|
+
liger_kernel_nightly-0.6.2.dev20250826142826.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
104
|
+
liger_kernel_nightly-0.6.2.dev20250826142826.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|