liger-kernel-nightly 0.5.3.dev20250220195514__py3-none-any.whl → 0.5.3.dev20250220230230__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- liger_kernel/chunked_loss/kto_loss.py +9 -9
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/top_level.txt +0 -0
@@ -43,20 +43,20 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
|
|
43
43
|
3. Maintain reasonable distance from the reference model
|
44
44
|
|
45
45
|
Args:
|
46
|
-
|
47
|
-
|
46
|
+
average_log_prob_chunk: Log probabilities for the chunk (batch_size,)
|
47
|
+
preference_labels_chunk: Preference labels for the chunk (batch_size,)
|
48
48
|
full_target: Non chunked full target tensor
|
49
|
-
|
50
|
-
|
51
|
-
beta: Weight for the direct preference loss
|
49
|
+
ref_average_log_prob_chunk: Reference log probs for the chunk (batch_size,)
|
50
|
+
beta: Weight for the KTO loss
|
52
51
|
kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
|
53
52
|
Returns:
|
54
|
-
Tuple of (loss, chosen_rewards, rejected_rewards):
|
55
53
|
- loss: The KTO loss value
|
56
|
-
- chosen_rewards: Reward signals for chosen responses (detached)
|
57
|
-
- rejected_rewards: Reward signals for rejected responses (detached)
|
58
54
|
"""
|
59
|
-
|
55
|
+
if ref_average_log_prob_chunk is not None:
|
56
|
+
logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
|
57
|
+
else:
|
58
|
+
logratios_chunk = average_log_prob_chunk
|
59
|
+
|
60
60
|
multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
|
61
61
|
if kl is not None:
|
62
62
|
losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
|
@@ -12,7 +12,7 @@ liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=sAApL4GQ3YL2F-ymIAF61GCpFf
|
|
12
12
|
liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=ZqYlXXhIphkJPxOS7iI70avgrr6x0skEtgpckZTYau0,9819
|
13
13
|
liger_kernel/chunked_loss/grpo_loss.py,sha256=M5qlQR-v5Rh8N3P3dPGNhOKygDFJ4516_rJaVPzU_-c,4980
|
14
14
|
liger_kernel/chunked_loss/jsd_loss.py,sha256=yRCQdvd3ruTWP4A_BfU8VcZ6LepSUfO0Ob7stGnueQY,6052
|
15
|
-
liger_kernel/chunked_loss/kto_loss.py,sha256=
|
15
|
+
liger_kernel/chunked_loss/kto_loss.py,sha256=b3ffJyk97e-6XdXd4HFrYyx8wW4A-CU4gOaJSimKLtA,5476
|
16
16
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
|
17
17
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
|
18
18
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -63,9 +63,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
63
63
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
|
64
64
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
65
65
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
66
|
-
liger_kernel_nightly-0.5.3.
|
67
|
-
liger_kernel_nightly-0.5.3.
|
68
|
-
liger_kernel_nightly-0.5.3.
|
69
|
-
liger_kernel_nightly-0.5.3.
|
70
|
-
liger_kernel_nightly-0.5.3.
|
71
|
-
liger_kernel_nightly-0.5.3.
|
66
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
67
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/METADATA,sha256=xtathj_pY7bV0Pkw0qNpzJ-cDVUXRy3AsSemRtaTRYY,21766
|
68
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
69
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
70
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
71
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|