liger-kernel-nightly 0.5.2.dev20250108072837__py3-none-any.whl → 0.5.2.dev20250108073340__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cpo_loss.py +1 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +14 -3
- {liger_kernel_nightly-0.5.2.dev20250108072837.dist-info → liger_kernel_nightly-0.5.2.dev20250108073340.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20250108072837.dist-info → liger_kernel_nightly-0.5.2.dev20250108073340.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.5.2.dev20250108072837.dist-info → liger_kernel_nightly-0.5.2.dev20250108073340.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108072837.dist-info → liger_kernel_nightly-0.5.2.dev20250108073340.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108072837.dist-info → liger_kernel_nightly-0.5.2.dev20250108073340.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20250108072837.dist-info → liger_kernel_nightly-0.5.2.dev20250108073340.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
32
32
|
ref_input=None,
|
33
33
|
ref_weight=None,
|
34
34
|
ref_bias=None,
|
35
|
+
average_log_prob=True,
|
35
36
|
**loss_kwargs,
|
36
37
|
):
|
37
38
|
"""
|
@@ -61,6 +62,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
61
62
|
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
62
63
|
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
63
64
|
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
65
|
+
average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
|
64
66
|
loss_kwargs (dict): Other possible arguments that a loss function might need
|
65
67
|
"""
|
66
68
|
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
|
@@ -94,6 +96,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
94
96
|
use_ref_model=use_ref_model,
|
95
97
|
ref_weight=ref_weight,
|
96
98
|
ref_bias=ref_bias,
|
99
|
+
average_log_prob=average_log_prob,
|
97
100
|
**loss_kwargs,
|
98
101
|
)
|
99
102
|
|
@@ -265,6 +268,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
265
268
|
bias=None,
|
266
269
|
ignore_index=-100,
|
267
270
|
compute_nll_loss=True,
|
271
|
+
average_log_prob=True,
|
268
272
|
):
|
269
273
|
len_chosen_chunk = target_chunk.shape[0] // 2
|
270
274
|
logits_chunk = input_chunk @ weight.t()
|
@@ -285,10 +289,13 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
285
289
|
label_chunk = torch.where(loss_mask, target_chunk, 0)
|
286
290
|
|
287
291
|
per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
|
288
|
-
average_log_prob
|
292
|
+
if average_log_prob:
|
293
|
+
log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
294
|
+
else:
|
295
|
+
log_prob = (per_token_logps * loss_mask).sum(-1)
|
289
296
|
|
290
|
-
chosen_logps =
|
291
|
-
rejected_logps =
|
297
|
+
chosen_logps = log_prob[:len_chosen_chunk]
|
298
|
+
rejected_logps = log_prob[len_chosen_chunk:]
|
292
299
|
|
293
300
|
chosen_logits = logits_chunk[:len_chosen_chunk]
|
294
301
|
rejected_logits = logits_chunk[len_chosen_chunk:]
|
@@ -317,6 +324,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
317
324
|
ref_input_chunk=None,
|
318
325
|
ref_weight=None,
|
319
326
|
ref_bias=None,
|
327
|
+
average_log_prob=True,
|
320
328
|
**loss_kwargs,
|
321
329
|
):
|
322
330
|
"""
|
@@ -335,6 +343,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
335
343
|
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
336
344
|
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
337
345
|
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
346
|
+
average_log_prob (bool): Whether to average log probabilities or the sum.
|
338
347
|
loss_kwargs (dict): Additional arguments for the loss function.
|
339
348
|
"""
|
340
349
|
(
|
@@ -350,6 +359,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
350
359
|
bias=bias,
|
351
360
|
ignore_index=ignore_index,
|
352
361
|
compute_nll_loss=compute_nll_loss,
|
362
|
+
average_log_prob=average_log_prob,
|
353
363
|
)
|
354
364
|
chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
|
355
365
|
chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
|
@@ -372,6 +382,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
372
382
|
ref_bias,
|
373
383
|
ignore_index=ignore_index,
|
374
384
|
compute_nll_loss=False, # We don't need NLL loss for the reference model
|
385
|
+
average_log_prob=average_log_prob,
|
375
386
|
)
|
376
387
|
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
|
377
388
|
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
|
@@ -3,11 +3,11 @@ liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,17
|
|
3
3
|
liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
|
4
4
|
liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
|
5
5
|
liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
|
6
|
-
liger_kernel/chunked_loss/cpo_loss.py,sha256=
|
6
|
+
liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
|
7
7
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
|
8
8
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
9
9
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
|
10
|
-
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=
|
10
|
+
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=eQCZmQ3xOL3jpZ7RhOfx_pqR9sNEX6RHx8DtIgyXEHc,16656
|
11
11
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
|
12
12
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
|
13
13
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
58
58
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
|
59
59
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
60
60
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
61
|
-
liger_kernel_nightly-0.5.2.
|
62
|
-
liger_kernel_nightly-0.5.2.
|
63
|
-
liger_kernel_nightly-0.5.2.
|
64
|
-
liger_kernel_nightly-0.5.2.
|
65
|
-
liger_kernel_nightly-0.5.2.
|
66
|
-
liger_kernel_nightly-0.5.2.
|
61
|
+
liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
+
liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/METADATA,sha256=m2Zrd4xffCEa6qCxyFCCH6l1WJuk7V6eZ28Pt2_dtHc,21055
|
63
|
+
liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
+
liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
+
liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
+
liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|