liger-kernel-nightly 0.5.8.dev20250416185644__py3-none-any.whl → 0.5.8.dev20250422210723__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -68,6 +68,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
68
68
  compute_nll_loss=False,
69
69
  compiled=True,
70
70
  use_ref_model=True,
71
+ average_log_prob=False,
71
72
  chunk_size=1,
72
73
  ):
73
74
  """
@@ -85,6 +86,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
85
86
  compute_nll_loss (bool): Whether to compute the NLL loss
86
87
  compiled (bool): Whether to use torch compile
87
88
  use_ref_model (bool): Whether to use a reference model
89
+ average_log_prob (bool): Whether to average the log probability per non-masked token
88
90
  chunk_size (int): Size of chunks for processing.
89
91
  Returns:
90
92
  torch.Tensor: Computed loss
@@ -104,13 +106,14 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
104
106
  ref_input=ref_input,
105
107
  ref_weight=ref_weight,
106
108
  ref_bias=ref_bias,
109
+ average_log_prob=average_log_prob,
107
110
  chunk_size=chunk_size,
108
111
  )
109
112
 
110
113
  @staticmethod
111
114
  def backward(ctx, *grad_output):
112
115
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
113
- return *grads, None, None, None, None, None, None, None, None, None
116
+ return *grads, None, None, None, None, None, None, None, None, None, None
114
117
 
115
118
 
116
119
  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -125,6 +128,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
125
128
  compute_nll_loss: bool = False,
126
129
  compiled: bool = True,
127
130
  use_ref_model: bool = True,
131
+ average_log_prob: bool = True,
128
132
  chunk_size: int = 1,
129
133
  ):
130
134
  """
@@ -134,6 +138,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
134
138
  compute_nll_loss (bool): Whether to compute the NLL loss.
135
139
  compiled (bool): Whether to use the torch compiled kernel.
136
140
  use_ref_model (bool): Whether to use a reference model for the DPO loss.
141
+ average_log_prob (bool): Whether to average the log probability per non-masked token.
137
142
  chunk_size (int): Size of chunks for processing.
138
143
  """
139
144
  super().__init__()
@@ -142,6 +147,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
142
147
  self.compute_nll_loss = compute_nll_loss
143
148
  self.compiled = compiled
144
149
  self.use_ref_model = use_ref_model
150
+ self.average_log_prob = average_log_prob
145
151
  self.chunk_size = chunk_size
146
152
 
147
153
  def forward(
@@ -167,5 +173,6 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
167
173
  self.compute_nll_loss,
168
174
  self.compiled,
169
175
  self.use_ref_model,
176
+ self.average_log_prob,
170
177
  self.chunk_size,
171
178
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.8.dev20250416185644
3
+ Version: 0.5.8.dev20250422210723
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ liger_kernel/utils.py,sha256=178Hn8uD-VauDT6FjqMyXLbKLod8ObIpaTtapHwfEK0,1861
4
4
  liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
5
5
  liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
6
6
  liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
7
- liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
7
+ liger_kernel/chunked_loss/dpo_loss.py,sha256=Xypt4FoTSmAnJE4SWtsCv4aNHK4ToR1LonUQtCTEuHQ,6258
8
8
  liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
9
9
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
10
10
  liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
74
74
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
75
75
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
76
76
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
77
- liger_kernel_nightly-0.5.8.dev20250416185644.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
78
- liger_kernel_nightly-0.5.8.dev20250416185644.dist-info/METADATA,sha256=DMyDK7rTzTSE8a03KwKq6MmT6aHmPX3XIuhShff4Qgs,23297
79
- liger_kernel_nightly-0.5.8.dev20250416185644.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
80
- liger_kernel_nightly-0.5.8.dev20250416185644.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
81
- liger_kernel_nightly-0.5.8.dev20250416185644.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
82
- liger_kernel_nightly-0.5.8.dev20250416185644.dist-info/RECORD,,
77
+ liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
78
+ liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/METADATA,sha256=aSh18zXYcQy1fb3OW8Q-Q9_DYczeWXULpNDET3PCbfg,23297
79
+ liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
80
+ liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
81
+ liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
82
+ liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/RECORD,,