liger-kernel-nightly 0.5.3.dev20250220195514__py3-none-any.whl → 0.5.3.dev20250220230230__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -43,20 +43,20 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
43
43
  3. Maintain reasonable distance from the reference model
44
44
 
45
45
  Args:
46
- chosen_logps: Log probabilities of chosen tokens (batch_size,)
47
- rejected_logps: Log probabilities of rejected tokens (batch_size,)
46
+ average_log_prob_chunk: Log probabilities for the chunk (batch_size,)
47
+ preference_labels_chunk: Preference labels for the chunk (batch_size,)
48
48
  full_target: Non chunked full target tensor
49
- ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
50
- ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
51
- beta: Weight for the direct preference loss
49
+ ref_average_log_prob_chunk: Reference log probs for the chunk (batch_size,)
50
+ beta: Weight for the KTO loss
52
51
  kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
53
52
  Returns:
54
- Tuple of (loss, chosen_rewards, rejected_rewards):
55
53
  - loss: The KTO loss value
56
- - chosen_rewards: Reward signals for chosen responses (detached)
57
- - rejected_rewards: Reward signals for rejected responses (detached)
58
54
  """
59
- logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
55
+ if ref_average_log_prob_chunk is not None:
56
+ logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
57
+ else:
58
+ logratios_chunk = average_log_prob_chunk
59
+
60
60
  multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
61
61
  if kl is not None:
62
62
  losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250220195514
3
+ Version: 0.5.3.dev20250220230230
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -12,7 +12,7 @@ liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=sAApL4GQ3YL2F-ymIAF61GCpFf
12
12
  liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=ZqYlXXhIphkJPxOS7iI70avgrr6x0skEtgpckZTYau0,9819
13
13
  liger_kernel/chunked_loss/grpo_loss.py,sha256=M5qlQR-v5Rh8N3P3dPGNhOKygDFJ4516_rJaVPzU_-c,4980
14
14
  liger_kernel/chunked_loss/jsd_loss.py,sha256=yRCQdvd3ruTWP4A_BfU8VcZ6LepSUfO0Ob7stGnueQY,6052
15
- liger_kernel/chunked_loss/kto_loss.py,sha256=eVNW6HVCAm32shpfhbRlk92Flnjd7G32v0gK9DUUSOQ,5655
15
+ liger_kernel/chunked_loss/kto_loss.py,sha256=b3ffJyk97e-6XdXd4HFrYyx8wW4A-CU4gOaJSimKLtA,5476
16
16
  liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
17
17
  liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
18
18
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -63,9 +63,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
63
63
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
64
64
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
65
65
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
66
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
67
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/METADATA,sha256=WkDgh3E1y7TYWCDttILZqikHz2S5b2kxLKoJ7JiWMd8,21766
68
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
69
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
70
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
71
- liger_kernel_nightly-0.5.3.dev20250220195514.dist-info/RECORD,,
66
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
67
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/METADATA,sha256=xtathj_pY7bV0Pkw0qNpzJ-cDVUXRy3AsSemRtaTRYY,21766
68
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
69
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
70
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
71
+ liger_kernel_nightly-0.5.3.dev20250220230230.dist-info/RECORD,,