liger-kernel-nightly 0.5.3.dev20250220195514__py3-none-any.whl → 0.5.3.dev20250220230230__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -43,20 +43,20 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
43
43
  3. Maintain reasonable distance from the reference model
44
44
 
45
45
  Args:
46
- chosen_logps: Log probabilities of chosen tokens (batch_size,)
47
- rejected_logps: Log probabilities of rejected tokens (batch_size,)
46
+ average_log_prob_chunk: Log probabilities for the chunk (batch_size,)
47
+ preference_labels_chunk: Preference labels for the chunk (batch_size,)
48
48
  full_target: Non chunked full target tensor
49
- ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
50
- ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
51
- beta: Weight for the direct preference loss
49
+ ref_average_log_prob_chunk: Reference log probs for the chunk (batch_size,)
50
+ beta: Weight for the KTO loss
52
51
  kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
53
52
  Returns:
54
- Tuple of (loss, chosen_rewards, rejected_rewards):
55
53
  - loss: The KTO loss value
56
- - chosen_rewards: Reward signals for chosen responses (detached)
57
- - rejected_rewards: Reward signals for rejected responses (detached)
58
54
  """
59
- logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
55
+ if ref_average_log_prob_chunk is not None:
56
+ logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
57
+ else:
58
+ logratios_chunk = average_log_prob_chunk
59
+
60
60
  multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
61
61
  if kl is not None:
62
62
  losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250220195514
3
+ Version: 0.5.3.dev20250220230230
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -12,7 +12,7 @@ liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=sAApL4GQ3YL2F-ymIAF61GCpFf
12
12
  liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=ZqYlXXhIphkJPxOS7iI70avgrr6x0skEtgpckZTYau0,9819
13
13
  liger_kernel/chunked_loss/grpo_loss.py,sha256=M5qlQR-v5Rh8N3P3dPGNhOKygDFJ4516_rJaVPzU_-c,4980
14
14
  liger_kernel/chunked_loss/jsd_loss.py,sha256=yRCQdvd3ruTWP4A_BfU8VcZ6LepSUfO0Ob7stGnueQY,6052
15
- liger_kernel/chunked_loss/kto_loss.py,sha256=eVNW6HVCAm32shpfhbRlk92Flnjd7G32v0gK9DUUSOQ,5655
15
+ liger_kernel/chunked_loss/kto_loss.py,sha256=b3ffJyk97e-6XdXd4HFrYyx8wW4A-CU4gOaJSimKLtA,5476
16
16
  liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
17
17
  liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
18
18
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -63,9 +63,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
63
63
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
64
64
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
65
65
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
66
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
67
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/METADATA,sha256=WkDgh3E1y7TYWCDttILZqikHz2S5b2kxLKoJ7JiWMd8,21766
68
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
69
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
70
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
71
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/RECORD,,
66
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
67
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/METADATA,sha256=xtathj_pY7bV0Pkw0qNpzJ-cDVUXRy3AsSemRtaTRYY,21766
68
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
69
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
70
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
71
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/RECORD,,