liger-kernel-nightly 0.5.2.dev20250108073340__py3-none-any.whl → 0.5.2.dev20250109023714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -12
- liger_kernel/chunked_loss/orpo_loss.py +5 -2
- liger_kernel/ops/cross_entropy.py +8 -24
- liger_kernel/ops/fused_linear_cross_entropy.py +4 -4
- liger_kernel/transformers/cross_entropy.py +0 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +16 -4
- {liger_kernel_nightly-0.5.2.dev20250108073340.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20250108073340.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/RECORD +12 -12
- {liger_kernel_nightly-0.5.2.dev20250108073340.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108073340.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108073340.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108073340.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_preference.py

@@ -27,6 +27,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         alpha=1.0,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
         use_ref_model=False,
         ref_input=None,
@@ -58,6 +59,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the preference loss.
             compute_nll_loss (bool): Whether to compute NLL loss.
+            nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
@@ -96,11 +98,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            full_nll_target=nll_target,
             average_log_prob=average_log_prob,
             **loss_kwargs,
         )

-        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
+        def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
             """
             Fused forward and backward pass for a chunk of input and target.
             """
@@ -111,13 +114,18 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                     target_chunk,
                     bias,
                     ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                 )
             else:
                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
-                    input_chunk,
+                    input_chunk,
+                    weight,
+                    target_chunk,
+                    ref_input_chunk=ref_input_chunk,
+                    chosen_nll_target_chunk=chosen_nll_target_chunk,
                 )

-        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
+        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
             if bias is not None:
                 (
                     (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
@@ -132,7 +140,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                             *aux_outputs,
                         ),
                     ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
                 (
@@ -148,7 +156,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                             *aux_outputs,
                         ),
                     ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
+                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

             # Accumulate gradients
             grad_weight.add_(chunk_grad_weight)
@@ -191,6 +199,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
         _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

+        if nll_target is not None:
+            _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
+
         if use_ref_model:
             _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
             _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
@@ -202,6 +213,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             rejected_target_chunk,
             ref_chosen_input_chunk,
             ref_rejected_input_chunk,
+            chosen_nll_target_chunk,
         ) in zip(
             _chosen_input_chunks,
             _rejected_input_chunks,
@@ -209,6 +221,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             _rejected_target_chunks,
             (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
             (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
+            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
             strict=False,
         ):
             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
@@ -222,9 +235,10 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             torch._dynamo.mark_dynamic(target_chunk, 1)
             torch._dynamo.mark_dynamic(target, 1)
             torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+            torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None

             # accumulate loss, gradients, and metrics
-            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
+            accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

         # combine grad_chosen_inputs and grad_rejected_inputs
         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -258,7 +272,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         grad_weight = grad_weight * grad_output[0][0]
         grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None

-        return grad_input, grad_weight, None, grad_bias, None, None, None
+        return grad_input, grad_weight, None, grad_bias, None, None, None, None

     @staticmethod
     def chunk_forward(
@@ -268,6 +282,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         bias=None,
         ignore_index=-100,
         compute_nll_loss=True,
+        chosen_nll_target_chunk=None,
         average_log_prob=True,
     ):
         len_chosen_chunk = target_chunk.shape[0] // 2
@@ -278,9 +293,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):

         chosen_nll_loss = 0.0
         if compute_nll_loss:
+            nll_labels = (
+                chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
+            )
             chosen_nll_loss = F.nll_loss(
                 log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                target_chunk[:len_chosen_chunk].view(-1),
+                nll_labels.view(-1),
                 reduction="sum",
                 ignore_index=ignore_index,
             )
@@ -324,6 +342,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
+        full_nll_target=None,
+        chosen_nll_target_chunk=None,
         average_log_prob=True,
         **loss_kwargs,
     ):
@@ -343,6 +363,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
+            chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used.
             average_log_prob (bool): Whether to average log probabilities or the sum.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
@@ -359,9 +381,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             bias=bias,
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
+            chosen_nll_target_chunk=chosen_nll_target_chunk,
             average_log_prob=average_log_prob,
         )
-
+        if full_nll_target is not None:
+            chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
+        else:
+            chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+
         chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
         rejected_logits_mean = rejected_logits.sum() / (
             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
@@ -372,9 +399,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         (
             ref_chosen_logps,
             ref_rejected_logps,
-
-
-
+            _,
+            _,
+            _,
         ) = LigerFusedLinearPreferenceBase.chunk_forward(
             ref_input_chunk,
             ref_weight,
@@ -382,6 +409,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             ref_bias,
             ignore_index=ignore_index,
             compute_nll_loss=False,  # We don't need NLL loss for the reference model
+            chosen_nll_target_chunk=None,
             average_log_prob=average_log_prob,
         )
         loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
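Taken together, these hunks thread an optional nll_target through the chunked preference loss so that the NLL term can be computed against labels that differ from the preference targets. Below is a minimal sketch of the new fallback inside chunk_forward; the shapes, ignore_index, and random tensors are invented for illustration and nothing here is taken from the package itself:

import torch
import torch.nn.functional as F

# Illustrative shapes: a chunk stacks the chosen half on top of the rejected half.
chunk, seq_len, vocab = 2, 8, 16
log_probs_chunk = torch.randn(2 * chunk, seq_len, vocab).log_softmax(dim=-1)
target_chunk = torch.randint(0, vocab, (2 * chunk, seq_len))
chosen_nll_target_chunk = torch.randint(0, vocab, (chunk, seq_len))  # may be None

len_chosen_chunk = target_chunk.shape[0] // 2
# The patched fallback: use the dedicated NLL labels when provided, otherwise
# the chosen half of the preference target, exactly as before this release.
nll_labels = chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
chosen_nll_loss = F.nll_loss(
    log_probs_chunk[:len_chosen_chunk].reshape(-1, vocab),
    nll_labels.reshape(-1),
    reduction="sum",
    ignore_index=-100,
)

The normalization in _compute_loss switches accordingly: when full_nll_target is supplied, the summed NLL is divided by the count of its non-ignored tokens rather than by that of full_target.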
liger_kernel/chunked_loss/orpo_loss.py

@@ -52,6 +52,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         ignore_index=-100,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
     ):
         return LigerFusedLinearPreferenceBase.forward(
@@ -64,13 +65,14 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
+            nll_target=nll_target,
             compiled=compiled,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None
+        return *grads, None, None, None, None, None


 class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -96,7 +98,7 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled

-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
@@ -105,5 +107,6 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
             self.ignore_index,
             self.beta,
             self.compute_nll_loss,
+            nll_target,
             self.compiled,
         )
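At the module level the new argument surfaces as an optional keyword on LigerFusedLinearORPOLoss.forward. A usage sketch follows; the sizes are invented, and the import path plus the (loss, aux_outputs) return shape are assumptions based on this diff and on the trainer change further down:

import torch
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss  # assumed export path

B, T, H, V = 4, 16, 64, 128  # batch stacks chosen examples first, then rejected
lin_weight = torch.randn(V, H, requires_grad=True)
hidden = torch.randn(B, T, H, requires_grad=True)
target = torch.randint(0, V, (B, T))      # preference targets
nll_target = torch.randint(0, V, (B, T))  # separate labels for the NLL term (new)

loss_fn = LigerFusedLinearORPOLoss(ignore_index=-100, beta=0.1)
loss, aux_outputs = loss_fn(lin_weight, hidden, target, nll_target=nll_target)
loss.backward()

Note that nll_target is passed positionally to LigerFusedLinearORPOFunction.apply just before self.compiled, which is why backward now returns one extra None.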
liger_kernel/ops/cross_entropy.py

@@ -20,9 +20,6 @@ if compare_version("triton", operator.ge, "3.0.0"):
 else:
     from triton.language.math import tanh

-_TRUE: tl.constexpr = tl.constexpr(1)
-_FALSE: tl.constexpr = tl.constexpr(0)
-

 @triton.jit
 def liger_cross_entropy_kernel(
@@ -95,7 +92,7 @@ def liger_cross_entropy_kernel(
         return

     loss_ptr += program_id * loss_stride
-    if RETURN_Z_LOSS
+    if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride

     if HAS_WEIGHT:
@@ -254,7 +251,7 @@ def liger_cross_entropy_kernel(
     loss += z_loss

     tl.store(loss_ptr, loss)
-    if RETURN_Z_LOSS
+    if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)


@@ -264,12 +261,6 @@ def liger_cross_entropy_kernel(
 MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning


-_bool_to_return_z_loss = {
-    True: _TRUE.value,
-    False: _FALSE.value,
-}
-
-
 def cross_entropy_forward(
     _input,
     target,
@@ -281,11 +272,7 @@ def cross_entropy_forward(
     softcap,
     return_z_loss,
 ):
-
-        assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
-        return_z_loss = _bool_to_return_z_loss[return_z_loss]
-    else:
-        assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"

     BT, V = _input.shape
     n_rows = BT
@@ -294,10 +281,7 @@ def cross_entropy_forward(

     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
-    if return_z_loss
-        z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
-    else:
-        z_loss_1d = None  # set None when return_z_loss == False
+    z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -326,7 +310,7 @@ def cross_entropy_forward(
         X_stride=_input.stride(-2),
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
-        weight_ptr=weight
+        weight_ptr=weight,  # dummy if None
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
@@ -338,7 +322,7 @@ def cross_entropy_forward(
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
-        softcap=softcap
+        softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
@@ -350,10 +334,10 @@ def cross_entropy_forward(

     if reduction == "none":
         loss = loss_1d
-        z_loss = z_loss_1d if return_z_loss
+        z_loss = z_loss_1d if return_z_loss else None
     else:
         loss = torch.sum(loss_1d)
-        z_loss = torch.sum(z_loss_1d) if return_z_loss
+        z_loss = torch.sum(z_loss_1d) if return_z_loss else None

     return loss, z_loss, _input

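The removed _TRUE/_FALSE constants and the _bool_to_return_z_loss lookup existed only to feed the kernel's RETURN_Z_LOSS flag as an integer. The new code passes a plain Python bool, which Triton accepts for a tl.constexpr parameter and specializes on at compile time. A toy kernel, not from the package, illustrating the pattern:

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, y_ptr, n, DOUBLE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if DOUBLE:  # a plain Python bool works as a constexpr flag
        x = x * 2
    tl.store(y_ptr + offs, x, mask=mask)

x = torch.arange(8, device="cuda", dtype=torch.float32)  # assumes a CUDA device
y = torch.empty_like(x)
scale_kernel[(1,)](x, y, x.numel(), DOUBLE=True, BLOCK=8)

Because the branch guarded by the flag is eliminated during specialization, passing RETURN_Z_LOSS=False together with z_loss_ptr=None (as the fused-linear launch below now does) should never dereference the dummy pointer.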
liger_kernel/ops/fused_linear_cross_entropy.py

@@ -92,9 +92,9 @@ def fused_linear_cross_entropy_forward(
            X_stride=logits_chunk.stride(-2),
            Y_ptr=target_chunk,
            Y_stride=target_chunk.stride(-1),  # always 1
-            weight_ptr=ce_weight
+            weight_ptr=ce_weight,
            loss_ptr=loss_1d_slice,
-            z_loss_ptr=
+            z_loss_ptr=None,
            loss_stride=loss_1d_slice.stride(-1),  # always 1
            n_cols=V,
            n_non_ignore=total_n_non_ignore,
@@ -104,8 +104,8 @@ def fused_linear_cross_entropy_forward(
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
-            softcap=softcap
-            RETURN_Z_LOSS=
+            softcap=softcap,
+            RETURN_Z_LOSS=False,
            HAS_WEIGHT=True if ce_weight is not None else False,
            HAS_SOFTCAPPING=True if softcap is not None else False,
            BLOCK_SIZE=BLOCK_SIZE,
liger_kernel/transformers/cross_entropy.py

@@ -20,9 +20,6 @@ class LigerCrossEntropyLoss(torch.nn.Module):
         assert (label_smoothing >= 0) and (
             label_smoothing <= 1
         ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
-        assert (label_smoothing >= 0) and (
-            label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
         assert reduction in {
             "mean",
             "sum",
liger_kernel/transformers/trainer/orpo_trainer.py

@@ -93,6 +93,13 @@ class LigerORPOTrainer(ORPOTrainer):
         if self.aux_loss_enabled:
             model_kwargs["output_router_logits"] = True

+        if self.is_encoder_decoder:
+            labels = concatenated_batch["concatenated_labels"].clone()
+        else:
+            labels = concatenated_batch["concatenated_input_ids"].clone()
+            attention_mask = concatenated_batch["concatenated_attention_mask"]
+            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
+
         if isinstance(model, FullyShardedDataParallel):
             outputs = _FSDPForwardRedirection()(
                 model,
@@ -114,15 +121,20 @@ class LigerORPOTrainer(ORPOTrainer):

         orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)

-        def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
-            return orpo_loss_fn(
+        def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target):
+            return orpo_loss_fn(
+                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target
+            )

         orpo_loss, aux_outputs = _FSDPForwardRedirection()(
             model,
             orpo_partial,
             model.lm_head,
-            outputs.last_hidden_state,
-            concatenated_batch["concatenated_labels"]
+            outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
+            concatenated_batch["concatenated_labels"][:, 1:]
+            if not self.is_encoder_decoder
+            else concatenated_batch["concatenated_labels"],
+            labels[:, 1:] if not self.is_encoder_decoder else labels,
         )
         # if aux_loss_enabled, add the aux_loss to the orpo_loss
         if self.aux_loss_enabled:
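For decoder-only models the trainer now derives the NLL labels from the concatenated input ids, masks out padding with the attention mask, and shifts by one position so the labels line up with the truncated hidden states. A small sketch of that label construction with toy values, purely illustrative:

import torch

label_pad_token_id = -100
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])       # stands in for concatenated_input_ids
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # stands in for concatenated_attention_mask

labels = input_ids.clone()
labels = torch.where(attention_mask == 1, labels, label_pad_token_id)
# Decoder-only shift: hidden states drop the last position, labels drop the
# first, so position t predicts token t + 1.
nll_target = labels[:, 1:]
# nll_target is tensor([[   6,    7, -100],
#                       [   9, -100, -100]])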
{liger_kernel_nightly-0.5.2.dev20250108073340.dist-info → liger_kernel_nightly-0.5.2.dev20250109023714.dist-info}/RECORD

@@ -7,12 +7,12 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz
 liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
 liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
-liger_kernel/chunked_loss/fused_linear_preference.py,sha256=
-liger_kernel/chunked_loss/orpo_loss.py,sha256=
+liger_kernel/chunked_loss/fused_linear_preference.py,sha256=idK9V9NivoVITqVpiG0fEGUHSvinYWkn9-EYXZjR-KQ,18356
+liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
 liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
+liger_kernel/ops/cross_entropy.py,sha256=SRzAF9Ek84pBVFy3wqQZs7AhRoorKRIgQ-Td_rtl1Kk,18606
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=hezFRwbcPc-HNGZUFqUn5AYUqUpboPpFh4MNqEW4WgU,10108
 liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
@@ -28,7 +28,7 @@ liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectfl
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
 liger_kernel/transformers/__init__.py,sha256=QPmYkL6hosBPpPqCUGqvIvAtD9XzLgvZqZxUyYMZeVk,2008
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
-liger_kernel/transformers/cross_entropy.py,sha256=
+liger_kernel/transformers/cross_entropy.py,sha256=LtiHlj_tK2YFpilwvbG_NEVzbf82zKRpWCZMjaFUd4M,1681
 liger_kernel/transformers/functional.py,sha256=B1wkHWLx-YNhxvXBEXB4Ch1yEwF3mjwTPCeXA5aCV_c,4490
 liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=LAN8-pjUI2Erz_MnfMer-0ZmxJ0JlKxGzdZGJY-N65g,1569
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
@@ -55,12 +55,12 @@ liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UK
 liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
 liger_kernel/transformers/model/qwen2_vl.py,sha256=rZg3nU3YgF6wkB1UJ0a9IACSIlVOSCyLltyqw951MQQ,8609
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
-liger_kernel/transformers/trainer/orpo_trainer.py,sha256=
+liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/METADATA,sha256=Qixggh8iZsja-nfhW4tPgwipEP_pCro3KxqZkBeyZ7s,21055
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20250109023714.dist-info/RECORD,,
Files without changes: LICENSE, NOTICE, WHEEL, top_level.txt.