liger-kernel-nightly 0.5.8.dev20250422013410__py3-none-any.whl → 0.5.8.dev20250422210723__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +8 -1
- {liger_kernel_nightly-0.5.8.dev20250422013410.dist-info → liger_kernel_nightly-0.5.8.dev20250422210723.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.8.dev20250422013410.dist-info → liger_kernel_nightly-0.5.8.dev20250422210723.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.8.dev20250422013410.dist-info → liger_kernel_nightly-0.5.8.dev20250422210723.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.8.dev20250422013410.dist-info → liger_kernel_nightly-0.5.8.dev20250422210723.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.8.dev20250422013410.dist-info → liger_kernel_nightly-0.5.8.dev20250422210723.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.8.dev20250422013410.dist-info → liger_kernel_nightly-0.5.8.dev20250422210723.dist-info}/top_level.txt +0 -0
@@ -68,6 +68,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
68
68
|
compute_nll_loss=False,
|
69
69
|
compiled=True,
|
70
70
|
use_ref_model=True,
|
71
|
+
average_log_prob=False,
|
71
72
|
chunk_size=1,
|
72
73
|
):
|
73
74
|
"""
|
@@ -85,6 +86,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
85
86
|
compute_nll_loss (bool): Whether to compute the NLL loss
|
86
87
|
compiled (bool): Whether to use torch compile
|
87
88
|
use_ref_model (bool): Whether to use a reference model
|
89
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token
|
88
90
|
chunk_size (int): Size of chunks for processing.
|
89
91
|
Returns:
|
90
92
|
torch.Tensor: Computed loss
|
@@ -104,13 +106,14 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
104
106
|
ref_input=ref_input,
|
105
107
|
ref_weight=ref_weight,
|
106
108
|
ref_bias=ref_bias,
|
109
|
+
average_log_prob=average_log_prob,
|
107
110
|
chunk_size=chunk_size,
|
108
111
|
)
|
109
112
|
|
110
113
|
@staticmethod
|
111
114
|
def backward(ctx, *grad_output):
|
112
115
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
113
|
-
return *grads, None, None, None, None, None, None, None, None, None
|
116
|
+
return *grads, None, None, None, None, None, None, None, None, None, None
|
114
117
|
|
115
118
|
|
116
119
|
class LigerFusedLinearDPOLoss(torch.nn.Module):
|
@@ -125,6 +128,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
125
128
|
compute_nll_loss: bool = False,
|
126
129
|
compiled: bool = True,
|
127
130
|
use_ref_model: bool = True,
|
131
|
+
average_log_prob: bool = True,
|
128
132
|
chunk_size: int = 1,
|
129
133
|
):
|
130
134
|
"""
|
@@ -134,6 +138,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
134
138
|
compute_nll_loss (bool): Whether to compute the NLL loss.
|
135
139
|
compiled (bool): Whether to use the torch compiled kernel.
|
136
140
|
use_ref_model (bool): Whether to use a reference model for the DPO loss.
|
141
|
+
average_log_prob (bool): Whether to average the log probability per non-masked token.
|
137
142
|
chunk_size (int): Size of chunks for processing.
|
138
143
|
"""
|
139
144
|
super().__init__()
|
@@ -142,6 +147,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
142
147
|
self.compute_nll_loss = compute_nll_loss
|
143
148
|
self.compiled = compiled
|
144
149
|
self.use_ref_model = use_ref_model
|
150
|
+
self.average_log_prob = average_log_prob
|
145
151
|
self.chunk_size = chunk_size
|
146
152
|
|
147
153
|
def forward(
|
@@ -167,5 +173,6 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
167
173
|
self.compute_nll_loss,
|
168
174
|
self.compiled,
|
169
175
|
self.use_ref_model,
|
176
|
+
self.average_log_prob,
|
170
177
|
self.chunk_size,
|
171
178
|
)
|
@@ -4,7 +4,7 @@ liger_kernel/utils.py,sha256=178Hn8uD-VauDT6FjqMyXLbKLod8ObIpaTtapHwfEK0,1861
|
|
4
4
|
liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
|
5
5
|
liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
|
6
6
|
liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
|
7
|
-
liger_kernel/chunked_loss/dpo_loss.py,sha256=
|
7
|
+
liger_kernel/chunked_loss/dpo_loss.py,sha256=Xypt4FoTSmAnJE4SWtsCv4aNHK4ToR1LonUQtCTEuHQ,6258
|
8
8
|
liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
|
9
9
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
|
10
10
|
liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
|
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
74
74
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
|
75
75
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
76
76
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
77
|
-
liger_kernel_nightly-0.5.8.
|
78
|
-
liger_kernel_nightly-0.5.8.
|
79
|
-
liger_kernel_nightly-0.5.8.
|
80
|
-
liger_kernel_nightly-0.5.8.
|
81
|
-
liger_kernel_nightly-0.5.8.
|
82
|
-
liger_kernel_nightly-0.5.8.
|
77
|
+
liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
78
|
+
liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/METADATA,sha256=aSh18zXYcQy1fb3OW8Q-Q9_DYczeWXULpNDET3PCbfg,23297
|
79
|
+
liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
80
|
+
liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
81
|
+
liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
82
|
+
liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|