liger-kernel-nightly 0.5.3.dev20250220195514__py3-none-any.whl → 0.5.3.dev20250220230230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/kto_loss.py +9 -9
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.3.dev20250220195514.dist-info → liger_kernel_nightly-0.5.3.dev20250220230230.dist-info}/top_level.txt +0 -0
@@ -43,20 +43,20 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
|
|
43
43
|
3. Maintain reasonable distance from the reference model
|
44
44
|
|
45
45
|
Args:
|
46
|
-
|
47
|
-
|
46
|
+
average_log_prob_chunk: Log probabilities for the chunk (batch_size,)
|
47
|
+
preference_labels_chunk: Preference labels for the chunk (batch_size,)
|
48
48
|
full_target: Non chunked full target tensor
|
49
|
-
|
50
|
-
|
51
|
-
beta: Weight for the direct preference loss
|
49
|
+
ref_average_log_prob_chunk: Reference log probs for the chunk (batch_size,)
|
50
|
+
beta: Weight for the KTO loss
|
52
51
|
kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
|
53
52
|
Returns:
|
54
|
-
Tuple of (loss, chosen_rewards, rejected_rewards):
|
55
53
|
- loss: The KTO loss value
|
56
|
-
- chosen_rewards: Reward signals for chosen responses (detached)
|
57
|
-
- rejected_rewards: Reward signals for rejected responses (detached)
|
58
54
|
"""
|
59
|
-
|
55
|
+
if ref_average_log_prob_chunk is not None:
|
56
|
+
logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
|
57
|
+
else:
|
58
|
+
logratios_chunk = average_log_prob_chunk
|
59
|
+
|
60
60
|
multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
|
61
61
|
if kl is not None:
|
62
62
|
losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
|
@@ -12,7 +12,7 @@ liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=sAApL4GQ3YL2F-ymIAF61GCpFf
|
|
12
12
|
liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=ZqYlXXhIphkJPxOS7iI70avgrr6x0skEtgpckZTYau0,9819
|
13
13
|
liger_kernel/chunked_loss/grpo_loss.py,sha256=M5qlQR-v5Rh8N3P3dPGNhOKygDFJ4516_rJaVPzU_-c,4980
|
14
14
|
liger_kernel/chunked_loss/jsd_loss.py,sha256=yRCQdvd3ruTWP4A_BfU8VcZ6LepSUfO0Ob7stGnueQY,6052
|
15
|
-
liger_kernel/chunked_loss/kto_loss.py,sha256=
|
15
|
+
liger_kernel/chunked_loss/kto_loss.py,sha256=b3ffJyk97e-6XdXd4HFrYyx8wW4A-CU4gOaJSimKLtA,5476
|
16
16
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
|
17
17
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
|
18
18
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -63,9 +63,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
63
63
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
|
64
64
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
65
65
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
66
|
-
liger_kernel_nightly-0.5.3.
|
67
|
-
liger_kernel_nightly-0.5.3.
|
68
|
-
liger_kernel_nightly-0.5.3.
|
69
|
-
liger_kernel_nightly-0.5.3.
|
70
|
-
liger_kernel_nightly-0.5.3.
|
71
|
-
liger_kernel_nightly-0.5.3.
|
66
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
67
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/METADATA,sha256=xtathj_pY7bV0Pkw0qNpzJ-cDVUXRy3AsSemRtaTRYY,21766
|
68
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
69
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
70
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
71
|
+
liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|