liger-kernel 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +2 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +4 -6
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +11 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
- liger_kernel-0.5.3.dist-info/RECORD +69 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.2.dist-info/RECORD +0 -65
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_preference.py

@@ -2,11 +2,11 @@ from abc import abstractmethod
 from functools import partial

 import torch
+
 from torch.nn import functional as F


 class LigerFusedLinearPreferenceBase(torch.autograd.Function):
-
     @abstractmethod
     def preference_loss_fn(*args, **kwargs):
         """
@@ -27,11 +27,13 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         alpha=1.0,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
         use_ref_model=False,
-
+        ref_input=None,
         ref_weight=None,
         ref_bias=None,
+        average_log_prob=True,
         **loss_kwargs,
     ):
         """
@@ -57,10 +59,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the preference loss.
             compute_nll_loss (bool): Whether to compute NLL loss.
+            nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
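The new `nll_target` lets the auxiliary NLL term be computed against labels that differ from the preference `target`; the summed NLL is later normalized by the number of non-ignored tokens (see the `_compute_loss` hunk below). A minimal standalone sketch of that computation with made-up shapes, not taken from the package:

    import torch
    import torch.nn.functional as F

    T, V = 6, 11                                      # toy sequence length and vocab size
    log_probs = F.log_softmax(torch.randn(T, V), dim=-1)
    labels = torch.tensor([5, 2, -100, 7, -100, 1])   # -100 is the ignore_index

    nll_sum = F.nll_loss(log_probs, labels, reduction="sum", ignore_index=-100)
    nll = nll_sum / (labels != -100).sum()            # normalize by non-ignored tokens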
@@ -94,55 +98,70 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            full_nll_target=nll_target,
+            average_log_prob=average_log_prob,
             **loss_kwargs,
         )

-        def fused_fwd_bwd(input_chunk, target_chunk):
+        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
             """
             Fused forward and backward pass for a chunk of input and target.
             """
             if bias is not None:
-                return torch.func.grad_and_value(
-                    compute_loss, argnums=(0, 1, 3), has_aux=True
-                )(input_chunk, weight, target_chunk, bias)
+                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)(
+                    input_chunk,
+                    weight,
+                    target_chunk,
+                    bias,
+                    ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
+                )
             else:
-                return torch.func.grad_and_value(
-                    compute_loss, argnums=(0, 1), has_aux=True
-                )(input_chunk, weight, target_chunk)
+                return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
+                    input_chunk,
+                    weight,
+                    target_chunk,
+                    ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
+                )

-        def accumulate_chunk(input_chunk, target_chunk):
+        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
             if bias is not None:
-                (
-
+                (
+                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                     (
-
-
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_chosen_logps,
+                            chunk_rejected_logps,
+                            chunk_chosen_logits_mean,
+                            chunk_rejected_logits_mean,
+                            chunk_nll_loss,
+                            *aux_outputs,
+                        ),
                     ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
-                (
-
+                (
+                    (chunk_grad_input, chunk_grad_weight),
                     (
-
-
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_chosen_logps,
+                            chunk_rejected_logps,
+                            chunk_chosen_logits_mean,
+                            chunk_rejected_logits_mean,
+                            chunk_nll_loss,
+                            *aux_outputs,
+                        ),
+                        ),
                     ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

             # Accumulate gradients
             grad_weight.add_(chunk_grad_weight)
             grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
-            grad_rejected_inputs.append(
-                chunk_grad_input[chosen_target_chunk.shape[0] :]
-            )
+            grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :])

             # Accumulate loss
             loss_acc.add_(chunk_loss)
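Both branches of `fused_fwd_bwd` now build the `torch.func.grad_and_value` transform in one expression and forward the extra chunk tensors as keyword arguments. A minimal sketch of how `grad_and_value` with a tuple `argnums` and `has_aux=True` behaves, using a toy loss instead of the package's `compute_loss`:

    import torch

    def toy_loss(x, w, y):
        logits = x @ w.t()
        loss = torch.nn.functional.cross_entropy(logits, y)
        return loss, logits.detach()                  # (output, aux) because has_aux=True

    x, w = torch.randn(4, 8), torch.randn(3, 8)
    y = torch.randint(0, 3, (4,))

    # gradients are returned for the positions named in argnums; the value is the (loss, aux) pair
    (grad_x, grad_w), (loss, logits) = torch.func.grad_and_value(toy_loss, argnums=(0, 1), has_aux=True)(x, w, y)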
@@ -159,9 +178,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             if len(aggregated_aux_outputs) == 0:
                 for aux in aux_outputs:
                     if aux.ndim == 0:
-                        aggregated_aux_outputs.append(
-                            torch.zeros((), device=aux.device)
-                        )
+                        aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
                     else:
                         aggregated_aux_outputs.append([])

@@ -182,29 +199,46 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
         _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

+        if nll_target is not None:
+            _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
+
+        if use_ref_model:
+            _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
+            _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
+
         for (
             chosen_input_chunk,
             rejected_input_chunk,
             chosen_target_chunk,
             rejected_target_chunk,
+            ref_chosen_input_chunk,
+            ref_rejected_input_chunk,
+            chosen_nll_target_chunk,
         ) in zip(
             _chosen_input_chunks,
             _rejected_input_chunks,
             _chosen_target_chunks,
             _rejected_target_chunks,
+            (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
+            (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
+            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
+            strict=False,
         ):
             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
-            target_chunk = torch.cat(
-                [chosen_target_chunk, rejected_target_chunk], dim=0
+            ref_input_chunk = (
+                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None
             )
+            target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0)

             # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
             torch._dynamo.mark_dynamic(input_chunk, 1)
             torch._dynamo.mark_dynamic(target_chunk, 1)
             torch._dynamo.mark_dynamic(target, 1)
+            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+            torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None

             # accumulate loss, gradients, and metrics
-            accumulate_chunk(input_chunk, target_chunk)
+            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

         # combine grad_chosen_inputs and grad_rejected_inputs
         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
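The loop relies on `_input` and `target` being stacked with all chosen rows first and all rejected rows second, so each half is chunked separately and re-concatenated per chunk, while the optional reference-model and NLL-target streams are padded with `None`. A small sketch of that stacking and chunking pattern with toy shapes (names are illustrative, not from the package):

    import torch

    B, T, H = 4, 5, 8                        # toy sizes: B chosen rows followed by B rejected rows
    _input = torch.randn(2 * B, T, H)
    len_chosen = B
    chunks = 2

    chosen_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
    rejected_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
    optional_stream = [None] * len(chosen_chunks)     # placeholder when a stream is unused

    for chosen, rejected, extra in zip(chosen_chunks, rejected_chunks, optional_stream):
        chunk = torch.cat([chosen, rejected], dim=0)  # each chunk keeps the chosen-then-rejected layout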
@@ -233,14 +267,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grad_output):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
-        if torch.ne(
-            grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
-        ):
+        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
             grad_input = grad_input * grad_output[0][0]
             grad_weight = grad_weight * grad_output[0][0]
             grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None

-        return grad_input, grad_weight, None, grad_bias, None, None, None
+        return grad_input, grad_weight, None, grad_bias, None, None, None, None

     @staticmethod
     def chunk_forward(
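`backward` still only rescales gradients that were already accumulated during `forward`; the visible change is one more trailing `None` in the returned gradient tuple. A toy `autograd.Function` (not from the package) showing why the precomputed gradient is multiplied by the incoming `grad_output`:

    import torch

    class PrecomputedSquare(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(2 * x)              # gradient of (x * x).sum(), computed eagerly
            return (x * x).sum()

        @staticmethod
        def backward(ctx, grad_output):
            (grad_x,) = ctx.saved_tensors
            if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
                grad_x = grad_x * grad_output         # honor upstream scaling, e.g. 0.5 * loss
            return grad_x

    x = torch.randn(3, requires_grad=True)
    (0.5 * PrecomputedSquare.apply(x)).backward()
    assert torch.allclose(x.grad, x)                  # 0.5 * 2x == x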
@@ -250,6 +282,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         bias=None,
         ignore_index=-100,
         compute_nll_loss=True,
+        chosen_nll_target_chunk=None,
+        average_log_prob=True,
     ):
         len_chosen_chunk = target_chunk.shape[0] // 2
         logits_chunk = input_chunk @ weight.t()
@@ -259,9 +293,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):

         chosen_nll_loss = 0.0
         if compute_nll_loss:
+            nll_labels = (
+                chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
+            )
             chosen_nll_loss = F.nll_loss(
                 log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                target_chunk[:len_chosen_chunk].view(-1),
+                nll_labels.view(-1),
                 reduction="sum",
                 ignore_index=ignore_index,
             )
@@ -269,13 +306,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         loss_mask = target_chunk != ignore_index
         label_chunk = torch.where(loss_mask, target_chunk, 0)

-        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-            -1
-        )
-        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
+        if average_log_prob:
+            log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        else:
+            log_prob = (per_token_logps * loss_mask).sum(-1)

-        chosen_logps = average_log_prob[:len_chosen_chunk]
-        rejected_logps = average_log_prob[len_chosen_chunk:]
+        chosen_logps = log_prob[:len_chosen_chunk]
+        rejected_logps = log_prob[len_chosen_chunk:]

         chosen_logits = logits_chunk[:len_chosen_chunk]
         rejected_logits = logits_chunk[len_chosen_chunk:]
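The `average_log_prob` switch only changes the final reduction over completion tokens: summing gives the sequence log-probability used by DPO-style losses, while dividing by the token count gives the length-normalized form used by SimPO-style losses. A standalone sketch of the gather/mask/reduce steps with toy tensors (not lifted from the package):

    import torch
    import torch.nn.functional as F

    B, T, V = 2, 5, 11
    log_probs = F.log_softmax(torch.randn(B, T, V), dim=-1)
    target = torch.tensor([[4, 7, 1, -100, -100], [2, 2, 9, 3, -100]])

    loss_mask = target != -100
    labels = torch.where(loss_mask, target, 0)                      # make gather indices valid
    per_token_logps = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

    summed = (per_token_logps * loss_mask).sum(-1)                  # sequence log-probability
    averaged = summed / loss_mask.sum(-1)                           # length-normalized version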
@@ -301,8 +339,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         beta=0.1,
         compute_nll_loss=True,
         use_ref_model=False,
+        ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
+        full_nll_target=None,
+        chosen_nll_target_chunk=None,
+        average_log_prob=True,
         **loss_kwargs,
     ):
         """
@@ -321,6 +363,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
+            chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used.
+            average_log_prob (bool): Whether to average log probabilities or the sum.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
         (
@@ -336,14 +381,15 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             bias=bias,
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
+            chosen_nll_target_chunk=chosen_nll_target_chunk,
+            average_log_prob=average_log_prob,
         )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-        )
-        chosen_logits_mean = chosen_logits.sum() / (
-            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-        )
+        if full_nll_target is not None:
+            chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
+        else:
+            chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+
+        chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
         rejected_logits_mean = rejected_logits.sum() / (
             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
         )
@@ -353,16 +399,18 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 (
                     ref_chosen_logps,
                     ref_rejected_logps,
-
-
-
+                    _,
+                    _,
+                    _,
                 ) = LigerFusedLinearPreferenceBase.chunk_forward(
-                    input_chunk,
+                    ref_input_chunk,
                     ref_weight,
                     target_chunk,
                     ref_bias,
                     ignore_index=ignore_index,
                     compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                    chosen_nll_target_chunk=None,
+                    average_log_prob=average_log_prob,
                 )
             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
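When a reference model is used, its log-probabilities are now computed from the separate `ref_input_chunk` and kept outside the autograd graph. A toy sketch of that pattern (hypothetical tensors, not package code):

    import torch

    T, H, V = 5, 8, 11
    policy_w = torch.randn(V, H, requires_grad=True)
    ref_w = torch.randn(V, H)                         # frozen reference weights
    input_chunk, ref_input_chunk = torch.randn(2, T, H), torch.randn(2, T, H)

    policy_logits = input_chunk @ policy_w.t()        # tracked by autograd
    with torch.no_grad():                             # reference pass contributes no gradients
        ref_logits = ref_input_chunk @ ref_w.t()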
@@ -375,7 +423,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         else:
             preference_loss, aux_outputs = preference_loss_outputs, []

-        loss = alpha * chosen_nll_loss
+        loss = alpha * chosen_nll_loss + preference_loss
         return_vars = (
             chosen_logps,
             rejected_logps,
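With this change the chunk loss is `alpha * chosen_nll_loss + preference_loss`, where `preference_loss` comes from the subclass's `preference_loss_fn` applied to the chosen/rejected log-probabilities. A DPO-flavoured toy version of that last step (illustration only; the packaged `dpo_loss.py` is not shown in this excerpt):

    import torch
    import torch.nn.functional as F

    alpha, beta = 1.0, 0.1
    chosen_logps = torch.tensor([-12.3, -9.8])
    rejected_logps = torch.tensor([-14.1, -9.5])
    chosen_nll_loss = torch.tensor(2.7)

    preference_loss = -F.logsigmoid(beta * (chosen_logps - rejected_logps)).mean()
    loss = alpha * chosen_nll_loss + preference_loss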
liger_kernel/chunked_loss/fused_linear_unpaired_preference.py (new file)

@@ -0,0 +1,246 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+
+from torch.nn import functional as F
+
+
+class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
+    @abstractmethod
+    def preference_loss_fn(*args, **kwargs):
+        """
+        To be extended by subclasses.
+        """
+        raise NotImplementedError("Preference loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        preference_labels,
+        bias=None,
+        loss_fn=None,
+        chunk_size=1,
+        ignore_index=-100,
+        compiled=True,
+        use_ref_model=False,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Base class for fused linear layer with unpaired preference loss like KTO
+        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
+
+        The mental model is:
+
+        forward()
+        ├── Loop over chunks
+        └── compute_loss()
+            ├── chunk_forward()  # Compute logits and log probs
+            └── prefer_loss()  # Calculate preference loss
+
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the preference loss.
+            compiled (bool): Whether to use torch compile for chunk accumulation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples.
+                Shape: (batch_size,).
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Other possible arguments that a loss function might need
+        """
+        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
+        CHUNK_SIZE = chunk_size
+
+        # Gradients to be accumulated
+        grad_inputs = []
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+        # Loss to be accumulated
+        loss_acc = torch.zeros((), device=_input.device)
+
+        compute_loss = partial(
+            LigerFusedLinearUnpairedPreferenceBase._compute_loss,
+            preference_loss_fn=loss_fn,
+            full_target=target,
+            ignore_index=ignore_index,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            **loss_kwargs,
+        )
+
+        def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk):
+            """
+            Fused forward and backward pass for a chunk of input and target.
+            """
+            argnums = (0, 1, 4) if bias is not None else (0, 1)
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=False)(
+                input_chunk,
+                weight,
+                target_chunk,
+                preference_labels_chunk,
+                bias,
+                ref_input_chunk=ref_input_chunk,
+            )
+
+        def accumulate_chunk(
+            input_chunk,
+            target_chunk,
+            preference_labels_chunk=None,
+            ref_input_chunk=None,
+        ):
+            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd(
+                input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk
+            )
+            if bias is not None:
+                grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient
+
+            # Accumulate gradients
+            grad_weight.add_(chunk_grad_weight)
+            grad_inputs.append(chunk_grad_input)
+
+            # Accumulate loss
+            loss_acc.add_(chunk_loss)
+
+        if compiled:
+            fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+        # When not paired, use labels to separate chosen and rejected
+        assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss"
+
+        chunks = max(1, _input.shape[0] // CHUNK_SIZE)
+        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+        _target_chunks = torch.chunk(target, chunks=chunks, dim=0)
+        _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0)
+
+        if use_ref_model:
+            _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0)
+
+        for (
+            input_chunk,
+            target_chunk,
+            ref_input_chunk,
+            preference_labels_chunk,
+        ) in zip(
+            _input_chunks,
+            _target_chunks,
+            (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)),
+            _preference_labels_chunks,
+        ):
+            # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(target_chunk, 1)
+            torch._dynamo.mark_dynamic(target, 1)
+            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+            torch._dynamo.mark_dynamic(preference_labels_chunk, 1)
+
+            # accumulate loss, gradients, and metrics
+            accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
+
+        ctx.save_for_backward(
+            torch.cat(grad_inputs, dim=0),
+            grad_weight,
+            grad_bias,
+        )
+        return loss_acc
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
+            grad_input = grad_input * grad_output[0][0]
+            grad_weight = grad_weight * grad_output[0][0]
+            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
+
+        return grad_input, grad_weight, None, None, grad_bias
+
+    @staticmethod
+    def chunk_forward(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        ignore_index=-100,
+    ):
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        loss_mask_chunk = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
+
+        per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
+        average_log_prob_chunk = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
+
+        return average_log_prob_chunk
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        preference_labels_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        use_ref_model=False,
+        ref_input_chunk=None,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+        )
+
+        if use_ref_model:
+            with torch.no_grad():
+                ref_average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+                    ref_input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                )
+            loss_kwargs["ref_average_log_prob_chunk"] = ref_average_log_prob_chunk
+
+        preference_loss_chunk = preference_loss_fn(
+            average_log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+        )
+
+        return preference_loss_chunk
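`LigerFusedLinearUnpairedPreferenceBase._compute_loss` calls `preference_loss_fn(average_log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs)`, injecting `ref_average_log_prob_chunk` into `loss_kwargs` when a reference model is used. A toy, KTO-flavoured loss function with that call signature (illustrative only, not the packaged `kto_loss.py`):

    import torch
    import torch.nn.functional as F

    def toy_unpaired_loss(
        average_log_prob_chunk,                 # (chunk_size,) policy log-probs averaged over completion tokens
        preference_labels_chunk,                # (chunk_size,) bool, True = chosen, False = rejected
        full_target,                            # (batch_size, seq_len), used here only for normalization
        ref_average_log_prob_chunk=None,
        beta=0.1,
    ):
        ref = ref_average_log_prob_chunk if ref_average_log_prob_chunk is not None else 0.0
        log_ratio = beta * (average_log_prob_chunk - ref)
        sign = preference_labels_chunk.float() * 2.0 - 1.0      # +1 for chosen, -1 for rejected
        losses = -F.logsigmoid(sign * log_ratio)
        return losses.sum() / full_target.shape[0]              # per-chunk sums add up to a batch mean

Since the base class simply sums `preference_loss_chunk` across chunks into `loss_acc`, normalizing by the full batch size inside the loss keeps the accumulated value equal to a batch-level mean.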