liger-kernel-nightly 0.5.2.dev20250108102127__py3-none-any.whl → 0.5.2.dev20250109023714__py3-none-any.whl
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -12
- liger_kernel/chunked_loss/orpo_loss.py +5 -2
- liger_kernel/transformers/trainer/orpo_trainer.py +16 -4
- {liger_kernel_nightly-0.5.2.dev20250108102127.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20250108102127.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/RECORD +9 -9
- {liger_kernel_nightly-0.5.2.dev20250108102127.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108102127.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/top_level.txt +0 -0
```diff
@@ -27,6 +27,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         alpha=1.0,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
         use_ref_model=False,
         ref_input=None,
```
```diff
@@ -58,6 +59,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the preference loss.
             compute_nll_loss (bool): Whether to compute NLL loss.
+            nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
```
```diff
@@ -96,11 +98,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            full_nll_target=nll_target,
             average_log_prob=average_log_prob,
             **loss_kwargs,
         )
 
-        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
+        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
             """
             Fused forward and backward pass for a chunk of input and target.
             """
```
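The `fused_fwd_bwd` closure threads the new `chosen_nll_target_chunk` through `torch.func.grad_and_value`, which computes gradients and the loss (plus auxiliary outputs) in one pass. A minimal sketch of that API with illustrative names, not the library's code:

```python
import torch

def compute_loss(w, x):
    logits = x @ w
    return logits.sum(), logits.detach()  # (output, aux) because has_aux=True

w = torch.randn(3, 3)
x = torch.randn(2, 3)
# Gradients for argnums come back alongside the (output, aux) pair.
grad_w, (loss, logits) = torch.func.grad_and_value(compute_loss, argnums=0, has_aux=True)(w, x)
print(grad_w.shape, loss.item(), logits.shape)
```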
```diff
@@ -111,13 +114,18 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                     target_chunk,
                     bias,
                     ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                 )
             else:
                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
-                    input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk
+                    input_chunk,
+                    weight,
+                    target_chunk,
+                    ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                 )
 
-        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
+        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
             if bias is not None:
                 (
                     (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
```
```diff
@@ -132,7 +140,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                             *aux_outputs,
                         ),
                     ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
                 (
```
```diff
@@ -148,7 +156,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                             *aux_outputs,
                         ),
                     ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
 
             # Accumulate gradients
             grad_weight.add_(chunk_grad_weight)
```
```diff
@@ -191,6 +199,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
         _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
 
+        if nll_target is not None:
+            _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
+
         if use_ref_model:
             _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
             _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
```
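`nll_target` is split with the same `torch.chunk` pattern as the other per-batch tensors, taking only the chosen half so each accumulation step sees one chunk. A standalone sketch with illustrative shapes:

```python
import torch

batch = torch.arange(24).reshape(6, 4)  # e.g. 3 chosen rows followed by 3 rejected rows
len_chosen, chunks = 3, 2
# torch.chunk may return fewer/uneven pieces when the size is not divisible.
chosen_chunks = torch.chunk(batch[:len_chosen], chunks=chunks, dim=0)
print([c.shape for c in chosen_chunks])  # [torch.Size([2, 4]), torch.Size([1, 4])]
```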
```diff
@@ -202,6 +213,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             rejected_target_chunk,
             ref_chosen_input_chunk,
             ref_rejected_input_chunk,
+            chosen_nll_target_chunk,
         ) in zip(
             _chosen_input_chunks,
             _rejected_input_chunks,
```
```diff
@@ -209,6 +221,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             _rejected_target_chunks,
             (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
             (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
+            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
             strict=False,
         ):
             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
```
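The `[None] * len(...)` fallback keeps a single `zip` driving every optional stream, so the loop body handles a missing `nll_target` without branching. The same pattern in isolation:

```python
chosen = ["c0", "c1", "c2"]
have_nll_targets = False  # illustrative flag standing in for `nll_target is not None`
nll_chunks = ["n0", "n1", "n2"] if have_nll_targets else [None] * len(chosen)
for c, n in zip(chosen, nll_chunks, strict=False):
    print(c, n)  # n is None whenever the optional stream is absent
```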
```diff
@@ -222,9 +235,10 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 torch._dynamo.mark_dynamic(target_chunk, 1)
                 torch._dynamo.mark_dynamic(target, 1)
                 torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+                torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None
 
             # accumulate loss, gradients, and metrics
-            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
+            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
 
         # combine grad_chosen_inputs and grad_rejected_inputs
         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
```
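`torch._dynamo.mark_dynamic` declares a dimension symbolic so `torch.compile` traces one graph that serves varying chunk lengths instead of recompiling per shape. A small CPU sketch with illustrative shapes:

```python
import torch

@torch.compile
def row_sums(x):
    return x.sum(dim=1)

x = torch.randn(4, 7)
# Mark dim 1 (the sequence dimension) as dynamic before the first call so
# the graph is traced with a symbolic size rather than a hard-coded 7.
torch._dynamo.mark_dynamic(x, 1)
print(row_sums(x).shape)
```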
```diff
@@ -258,7 +272,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         grad_weight = grad_weight * grad_output[0][0]
         grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
 
-        return grad_input, grad_weight, None, grad_bias, None, None, None
+        return grad_input, grad_weight, None, grad_bias, None, None, None, None
 
     @staticmethod
     def chunk_forward(
```
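The extra trailing `None` is required because `torch.autograd.Function.backward` must return exactly one gradient per `forward` input, and the new `nll_target` argument is non-differentiable. The rule in miniature:

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):  # two inputs -> backward returns two values
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.factor, None  # None fills the slot for `factor`

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # each element's gradient is 2.0
```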
```diff
@@ -268,6 +282,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         bias=None,
         ignore_index=-100,
         compute_nll_loss=True,
+        chosen_nll_target_chunk=None,
         average_log_prob=True,
     ):
         len_chosen_chunk = target_chunk.shape[0] // 2
```
```diff
@@ -278,9 +293,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
         chosen_nll_loss = 0.0
         if compute_nll_loss:
+            nll_labels = (
+                chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
+            )
             chosen_nll_loss = F.nll_loss(
                 log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                target_chunk[:len_chosen_chunk].view(-1),
+                nll_labels.view(-1),
                 reduction="sum",
                 ignore_index=ignore_index,
             )
```
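`F.nll_loss` with `reduction="sum"` adds up per-token losses over the flattened log-probabilities and skips `ignore_index` positions; normalization is deferred to `_compute_loss`. A standalone sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

log_probs = F.log_softmax(torch.randn(2, 3, 5), dim=-1)  # (batch, seq, vocab)
labels = torch.tensor([[1, -100, 4], [0, 2, -100]])       # -100 positions are skipped
loss_sum = F.nll_loss(log_probs.view(-1, 5), labels.view(-1), reduction="sum", ignore_index=-100)
print(loss_sum / (labels != -100).sum())  # caller-side mean over real tokens
```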
```diff
@@ -324,6 +342,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
+        full_nll_target=None,
+        chosen_nll_target_chunk=None,
         average_log_prob=True,
         **loss_kwargs,
     ):
```
```diff
@@ -343,6 +363,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
+            chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used.
             average_log_prob (bool): Whether to average log probabilities or the sum.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
```
```diff
@@ -359,9 +381,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             bias=bias,
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
+            chosen_nll_target_chunk=chosen_nll_target_chunk,
             average_log_prob=average_log_prob,
         )
-        chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        if full_nll_target is not None:
+            chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
+        else:
+            chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+
         chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
         rejected_logits_mean = rejected_logits.sum() / (
             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
```
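The new branch changes only the denominator: with `full_nll_target` supplied, the summed NLL is averaged over its own non-ignored tokens (e.g. prompt plus completion) rather than the preference target's. Illustrative tensors showing how the two counts diverge:

```python
import torch

ignore_index = -100
full_target = torch.tensor([[-100, -100, 5, 9], [-100, 2, 7, 1]])  # prompt masked out
full_nll_target = torch.tensor([[3, 4, 5, 9], [6, 2, 7, 1]])       # prompt kept
chosen = slice(0, full_target.shape[0] // 2)  # chosen half of the concatenated batch
print((full_target[chosen] != ignore_index).sum())      # tensor(2)
print((full_nll_target[chosen] != ignore_index).sum())  # tensor(4)
```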
```diff
@@ -372,9 +399,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             (
                 ref_chosen_logps,
                 ref_rejected_logps,
-                ref_chosen_logits,
-                ref_rejected_logits,
-                ref_chosen_nll_loss,
+                _,
+                _,
+                _,
             ) = LigerFusedLinearPreferenceBase.chunk_forward(
                 ref_input_chunk,
                 ref_weight,
```
```diff
@@ -382,6 +409,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 ref_bias,
                 ignore_index=ignore_index,
                 compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                chosen_nll_target_chunk=None,
                 average_log_prob=average_log_prob,
             )
             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
```
```diff
@@ -52,6 +52,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         ignore_index=-100,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
     ):
         return LigerFusedLinearPreferenceBase.forward(
```
```diff
@@ -64,13 +65,14 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
+            nll_target=nll_target,
             compiled=compiled,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None
+        return *grads, None, None, None, None, None
 
 
 class LigerFusedLinearORPOLoss(torch.nn.Module):
```
```diff
@@ -96,7 +98,7 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
```
```diff
@@ -105,5 +107,6 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
             self.ignore_index,
             self.beta,
             self.compute_nll_loss,
+            nll_target,
             self.compiled,
         )
```
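With the new keyword, a caller can score the NLL term against different labels than the preference loss. A hedged usage sketch based on the signature above; shapes are illustrative and the contents of `aux_outputs` are not spelled out here:

```python
import torch
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

B, T, H, V = 4, 16, 64, 128  # 2 chosen + 2 rejected sequences
loss_fn = LigerFusedLinearORPOLoss(ignore_index=-100, beta=0.1)
lm_head_weight = torch.randn(V, H, requires_grad=True)
hidden = torch.randn(B, T, H)
target = torch.randint(0, V, (B, T))      # preference labels (prompt may be -100)
nll_target = torch.randint(0, V, (B, T))  # separate labels for the NLL term
loss, aux_outputs = loss_fn(lm_head_weight, hidden, target, None, nll_target)
loss.backward()
```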
```diff
@@ -93,6 +93,13 @@ class LigerORPOTrainer(ORPOTrainer):
         if self.aux_loss_enabled:
             model_kwargs["output_router_logits"] = True
 
+        if self.is_encoder_decoder:
+            labels = concatenated_batch["concatenated_labels"].clone()
+        else:
+            labels = concatenated_batch["concatenated_input_ids"].clone()
+            attention_mask = concatenated_batch["concatenated_attention_mask"]
+            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
+
         if isinstance(model, FullyShardedDataParallel):
             outputs = _FSDPForwardRedirection()(
                 model,
```
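For decoder-only models the trainer now derives NLL labels from `concatenated_input_ids`, padding out positions the attention mask excludes, so the NLL term can also score prompt tokens. The masking step in isolation, with illustrative ids:

```python
import torch

label_pad_token_id = -100
input_ids = torch.tensor([[101, 7, 8, 0], [101, 5, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
labels = torch.where(attention_mask == 1, input_ids.clone(), label_pad_token_id)
print(labels)  # padded positions become -100 and are ignored by the loss
```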
```diff
@@ -114,15 +121,20 @@ class LigerORPOTrainer(ORPOTrainer):
 
         orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
 
-        def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
-            return orpo_loss_fn(lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias)
+        def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target):
+            return orpo_loss_fn(
+                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target
+            )
 
         orpo_loss, aux_outputs = _FSDPForwardRedirection()(
             model,
             orpo_partial,
             model.lm_head,
-            outputs.last_hidden_state,
-            concatenated_batch["concatenated_labels"]
+            outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
+            concatenated_batch["concatenated_labels"][:, 1:]
+            if not self.is_encoder_decoder
+            else concatenated_batch["concatenated_labels"],
+            labels[:, 1:] if not self.is_encoder_decoder else labels,
         )
         # if aux_loss_enabled, add the aux_loss to the orpo_loss
         if self.aux_loss_enabled:
```
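The `[:, :-1]` and `[:, 1:]` slices implement the standard next-token shift for decoder-only models: the hidden state at position t predicts the token at t+1, while encoder-decoder outputs already align with their labels, hence the conditionals. A minimal shape check:

```python
import torch

hidden = torch.randn(2, 5, 8)          # (batch, seq, hidden)
labels = torch.randint(0, 10, (2, 5))  # (batch, seq)
shifted_hidden = hidden[:, :-1]  # states that predict tokens 1..4
shifted_labels = labels[:, 1:]   # the tokens they should predict
print(shifted_hidden.shape, shifted_labels.shape)  # torch.Size([2, 4, 8]) torch.Size([2, 4])
```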
```diff
@@ -7,8 +7,8 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz
 liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
 liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
-liger_kernel/chunked_loss/fused_linear_preference.py,sha256=
-liger_kernel/chunked_loss/orpo_loss.py,sha256=
+liger_kernel/chunked_loss/fused_linear_preference.py,sha256=idK9V9NivoVITqVpiG0fEGUHSvinYWkn9-EYXZjR-KQ,18356
+liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
 liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/ops/cross_entropy.py,sha256=SRzAF9Ek84pBVFy3wqQZs7AhRoorKRIgQ-Td_rtl1Kk,18606
```
```diff
@@ -55,12 +55,12 @@ liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UK
 liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
 liger_kernel/transformers/model/qwen2_vl.py,sha256=rZg3nU3YgF6wkB1UJ0a9IACSIlVOSCyLltyqw951MQQ,8609
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
-liger_kernel/transformers/trainer/orpo_trainer.py,sha256=
+liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/METADATA,sha256=Qixggh8iZsja-nfhW4tPgwipEP_pCro3KxqZkBeyZ7s,21055
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/RECORD,,
```