liger-kernel-nightly 0.5.2.dev20241212030605__py3-none-any.whl → 0.5.2.dev20241212033924__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +12 -2
- liger_kernel/chunked_loss/fused_linear_preference.py +42 -9
- {liger_kernel_nightly-0.5.2.dev20241212030605.dist-info → liger_kernel_nightly-0.5.2.dev20241212033924.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20241212030605.dist-info → liger_kernel_nightly-0.5.2.dev20241212033924.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.5.2.dev20241212030605.dist-info → liger_kernel_nightly-0.5.2.dev20241212033924.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241212030605.dist-info → liger_kernel_nightly-0.5.2.dev20241212033924.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241212030605.dist-info → liger_kernel_nightly-0.5.2.dev20241212033924.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241212030605.dist-info → liger_kernel_nightly-0.5.2.dev20241212033924.dist-info}/top_level.txt +0 -0
@@ -59,6 +59,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
59
59
|
weight,
|
60
60
|
target,
|
61
61
|
bias=None,
|
62
|
+
ref_input=None,
|
62
63
|
ref_weight=None,
|
63
64
|
ref_bias=None,
|
64
65
|
ignore_index=-100,
|
@@ -79,6 +80,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
79
80
|
compute_nll_loss=compute_nll_loss,
|
80
81
|
compiled=compiled,
|
81
82
|
use_ref_model=use_ref_model,
|
83
|
+
ref_input=ref_input,
|
82
84
|
ref_weight=ref_weight,
|
83
85
|
ref_bias=ref_bias,
|
84
86
|
)
|
@@ -86,7 +88,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
86
88
|
@staticmethod
|
87
89
|
def backward(ctx, *grad_output):
|
88
90
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
89
|
-
return *grads, None, None, None, None, None, None, None
|
91
|
+
return *grads, None, None, None, None, None, None, None, None
|
90
92
|
|
91
93
|
|
92
94
|
class LigerFusedLinearDPOLoss(torch.nn.Module):
|
@@ -118,13 +120,21 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
|
|
118
120
|
self.use_ref_model = use_ref_model
|
119
121
|
|
120
122
|
def forward(
|
121
|
-
self,
|
123
|
+
self,
|
124
|
+
lin_weight,
|
125
|
+
_input,
|
126
|
+
target,
|
127
|
+
bias=None,
|
128
|
+
ref_input=None,
|
129
|
+
ref_weight=None,
|
130
|
+
ref_bias=None,
|
122
131
|
):
|
123
132
|
return LigerFusedLinearDPOFunction.apply(
|
124
133
|
_input,
|
125
134
|
lin_weight,
|
126
135
|
target,
|
127
136
|
bias,
|
137
|
+
ref_input,
|
128
138
|
ref_weight,
|
129
139
|
ref_bias,
|
130
140
|
self.ignore_index,
|
@@ -29,7 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
29
29
|
compute_nll_loss=True,
|
30
30
|
compiled=True,
|
31
31
|
use_ref_model=False,
|
32
|
-
|
32
|
+
ref_input=None,
|
33
33
|
ref_weight=None,
|
34
34
|
ref_bias=None,
|
35
35
|
**loss_kwargs,
|
@@ -97,20 +97,26 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
97
97
|
**loss_kwargs,
|
98
98
|
)
|
99
99
|
|
100
|
-
def fused_fwd_bwd(input_chunk, target_chunk):
|
100
|
+
def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
|
101
101
|
"""
|
102
102
|
Fused forward and backward pass for a chunk of input and target.
|
103
103
|
"""
|
104
104
|
if bias is not None:
|
105
105
|
return torch.func.grad_and_value(
|
106
106
|
compute_loss, argnums=(0, 1, 3), has_aux=True
|
107
|
-
)(
|
107
|
+
)(
|
108
|
+
input_chunk,
|
109
|
+
weight,
|
110
|
+
target_chunk,
|
111
|
+
bias,
|
112
|
+
ref_input_chunk=ref_input_chunk,
|
113
|
+
)
|
108
114
|
else:
|
109
115
|
return torch.func.grad_and_value(
|
110
116
|
compute_loss, argnums=(0, 1), has_aux=True
|
111
|
-
)(input_chunk, weight, target_chunk)
|
117
|
+
)(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk)
|
112
118
|
|
113
|
-
def accumulate_chunk(input_chunk, target_chunk):
|
119
|
+
def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
|
114
120
|
if bias is not None:
|
115
121
|
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
|
116
122
|
chunk_loss,
|
@@ -122,7 +128,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
122
128
|
chunk_nll_loss,
|
123
129
|
*aux_outputs,
|
124
130
|
),
|
125
|
-
) = fused_fwd_bwd(input_chunk, target_chunk)
|
131
|
+
) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
|
126
132
|
grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
|
127
133
|
else:
|
128
134
|
(chunk_grad_input, chunk_grad_weight), (
|
@@ -135,7 +141,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
135
141
|
chunk_nll_loss,
|
136
142
|
*aux_outputs,
|
137
143
|
),
|
138
|
-
) = fused_fwd_bwd(input_chunk, target_chunk)
|
144
|
+
) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
|
139
145
|
|
140
146
|
# Accumulate gradients
|
141
147
|
grad_weight.add_(chunk_grad_weight)
|
@@ -182,18 +188,43 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
182
188
|
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
|
183
189
|
_rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
|
184
190
|
|
191
|
+
if use_ref_model:
|
192
|
+
_ref_chosen_input_chunks = torch.chunk(
|
193
|
+
ref_input[:len_chosen], chunks=chunks, dim=0
|
194
|
+
)
|
195
|
+
_ref_rejected_input_chunks = torch.chunk(
|
196
|
+
ref_input[len_chosen:], chunks=chunks, dim=0
|
197
|
+
)
|
198
|
+
|
185
199
|
for (
|
186
200
|
chosen_input_chunk,
|
187
201
|
rejected_input_chunk,
|
188
202
|
chosen_target_chunk,
|
189
203
|
rejected_target_chunk,
|
204
|
+
ref_chosen_input_chunk,
|
205
|
+
ref_rejected_input_chunk,
|
190
206
|
) in zip(
|
191
207
|
_chosen_input_chunks,
|
192
208
|
_rejected_input_chunks,
|
193
209
|
_chosen_target_chunks,
|
194
210
|
_rejected_target_chunks,
|
211
|
+
(
|
212
|
+
_ref_chosen_input_chunks
|
213
|
+
if use_ref_model
|
214
|
+
else [None] * len(_chosen_input_chunks)
|
215
|
+
),
|
216
|
+
(
|
217
|
+
_ref_rejected_input_chunks
|
218
|
+
if use_ref_model
|
219
|
+
else [None] * len(_rejected_input_chunks)
|
220
|
+
),
|
195
221
|
):
|
196
222
|
input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
|
223
|
+
ref_input_chunk = (
|
224
|
+
torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0)
|
225
|
+
if use_ref_model
|
226
|
+
else None
|
227
|
+
)
|
197
228
|
target_chunk = torch.cat(
|
198
229
|
[chosen_target_chunk, rejected_target_chunk], dim=0
|
199
230
|
)
|
@@ -202,9 +233,10 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
202
233
|
torch._dynamo.mark_dynamic(input_chunk, 1)
|
203
234
|
torch._dynamo.mark_dynamic(target_chunk, 1)
|
204
235
|
torch._dynamo.mark_dynamic(target, 1)
|
236
|
+
torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
|
205
237
|
|
206
238
|
# accumulate loss, gradients, and metrics
|
207
|
-
accumulate_chunk(input_chunk, target_chunk)
|
239
|
+
accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
|
208
240
|
|
209
241
|
# combine grad_chosen_inputs and grad_rejected_inputs
|
210
242
|
grad_inputs = grad_chosen_inputs + grad_rejected_inputs
|
@@ -301,6 +333,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
301
333
|
beta=0.1,
|
302
334
|
compute_nll_loss=True,
|
303
335
|
use_ref_model=False,
|
336
|
+
ref_input_chunk=None,
|
304
337
|
ref_weight=None,
|
305
338
|
ref_bias=None,
|
306
339
|
**loss_kwargs,
|
@@ -357,7 +390,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
357
390
|
ref_rejected_logits,
|
358
391
|
ref_chosen_nll_loss,
|
359
392
|
) = LigerFusedLinearPreferenceBase.chunk_forward(
|
360
|
-
|
393
|
+
ref_input_chunk,
|
361
394
|
ref_weight,
|
362
395
|
target_chunk,
|
363
396
|
ref_bias,
|
@@ -4,10 +4,10 @@ liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
|
|
4
4
|
liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
|
5
5
|
liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
|
6
6
|
liger_kernel/chunked_loss/cpo_loss.py,sha256=Qu1Ul2A12sp6CqIT-atPbHWFb_LLtINEA9mOpIRx_0g,3097
|
7
|
-
liger_kernel/chunked_loss/dpo_loss.py,sha256=
|
7
|
+
liger_kernel/chunked_loss/dpo_loss.py,sha256=9S67SzKkLyoBmHGx8bkmthSNHlCT2ikBy9CCdb7wGj0,4381
|
8
8
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
9
9
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
|
10
|
-
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=
|
10
|
+
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=AsovMdfsOjgWVxtDhZ_rXqpahMsKTg8YueXnZcHt1XQ,16376
|
11
11
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=ZuKGjbkIYzV4UzvupNdq6vyxCp7-BztQkUt8ZnFvKos,3531
|
12
12
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=Wa4LOlDG9PbJkOOkKg8hbKvnKgg7OTBz6-qIkwPK1yw,3275
|
13
13
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
|
|
58
58
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=jko6oq_XQdBSmXubp05E-_YXOyhtB5Bj75dg5YNwOsE,7517
|
59
59
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
60
60
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
61
|
-
liger_kernel_nightly-0.5.2.
|
62
|
-
liger_kernel_nightly-0.5.2.
|
63
|
-
liger_kernel_nightly-0.5.2.
|
64
|
-
liger_kernel_nightly-0.5.2.
|
65
|
-
liger_kernel_nightly-0.5.2.
|
66
|
-
liger_kernel_nightly-0.5.2.
|
61
|
+
liger_kernel_nightly-0.5.2.dev20241212033924.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
+
liger_kernel_nightly-0.5.2.dev20241212033924.dist-info/METADATA,sha256=ayx2_ON0TY-xC2ba0fpG3x5Vgx5b_SQCIRx-qw455u8,20260
|
63
|
+
liger_kernel_nightly-0.5.2.dev20241212033924.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
+
liger_kernel_nightly-0.5.2.dev20241212033924.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
+
liger_kernel_nightly-0.5.2.dev20241212033924.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
+
liger_kernel_nightly-0.5.2.dev20241212033924.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|