liger-kernel-nightly 0.5.1.dev20241211055647__py3-none-any.whl → 0.5.2.dev20241211213024__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/fused_linear_preference.py +6 -2
- {liger_kernel_nightly-0.5.1.dev20241211055647.dist-info → liger_kernel_nightly-0.5.2.dev20241211213024.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.1.dev20241211055647.dist-info → liger_kernel_nightly-0.5.2.dev20241211213024.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.1.dev20241211055647.dist-info → liger_kernel_nightly-0.5.2.dev20241211213024.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.1.dev20241211055647.dist-info → liger_kernel_nightly-0.5.2.dev20241211213024.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.1.dev20241211055647.dist-info → liger_kernel_nightly-0.5.2.dev20241211213024.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.1.dev20241211055647.dist-info → liger_kernel_nightly-0.5.2.dev20241211213024.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
29
29
|
compute_nll_loss=True,
|
30
30
|
compiled=True,
|
31
31
|
use_ref_model=False,
|
32
|
-
|
32
|
+
ref_input=None,
|
33
33
|
ref_weight=None,
|
34
34
|
ref_bias=None,
|
35
35
|
**loss_kwargs,
|
@@ -59,6 +59,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
59
59
|
compute_nll_loss (bool): Whether to compute NLL loss.
|
60
60
|
compiled (bool): Whether to use torch compile for chunk accumulation.
|
61
61
|
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
62
|
+
ref_input (torch.Tensor): Reference input tensor. Shape: (batch_size, seq_len, hidden_size).
|
62
63
|
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
63
64
|
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
64
65
|
loss_kwargs (dict): Other possible arguments that a loss function might need
|
@@ -92,6 +93,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
92
93
|
compute_nll_loss=compute_nll_loss,
|
93
94
|
full_target=target,
|
94
95
|
use_ref_model=use_ref_model,
|
96
|
+
ref_input=ref_input,
|
95
97
|
ref_weight=ref_weight,
|
96
98
|
ref_bias=ref_bias,
|
97
99
|
**loss_kwargs,
|
@@ -301,6 +303,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
301
303
|
beta=0.1,
|
302
304
|
compute_nll_loss=True,
|
303
305
|
use_ref_model=False,
|
306
|
+
ref_input=None,
|
304
307
|
ref_weight=None,
|
305
308
|
ref_bias=None,
|
306
309
|
**loss_kwargs,
|
@@ -319,6 +322,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
319
322
|
beta (float): Weight for the preference loss.
|
320
323
|
compute_nll_loss (bool): Whether to compute NLL loss.
|
321
324
|
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
325
|
+
ref_input (torch.Tensor): Reference input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
|
322
326
|
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
323
327
|
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
324
328
|
loss_kwargs (dict): Additional arguments for the loss function.
|
@@ -357,7 +361,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
357
361
|
ref_rejected_logits,
|
358
362
|
ref_chosen_nll_loss,
|
359
363
|
) = LigerFusedLinearPreferenceBase.chunk_forward(
|
360
|
-
|
364
|
+
ref_input,
|
361
365
|
ref_weight,
|
362
366
|
target_chunk,
|
363
367
|
ref_bias,
|
@@ -6,7 +6,7 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=Qu1Ul2A12sp6CqIT-atPbHWFb_LLtINEA9m
|
|
6
6
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=H9_RRhclckHYM2sd75tgbnf8IxC_PU2JCALbgtPQvwc,4222
|
7
7
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
8
8
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
|
9
|
-
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=
|
9
|
+
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=qeRod4MFVttj62uPFhgKAWNNjVrqiEvu5SjZfRnOGzI,15389
|
10
10
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=ZuKGjbkIYzV4UzvupNdq6vyxCp7-BztQkUt8ZnFvKos,3531
|
11
11
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=Wa4LOlDG9PbJkOOkKg8hbKvnKgg7OTBz6-qIkwPK1yw,3275
|
12
12
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -57,9 +57,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
|
|
57
57
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=jko6oq_XQdBSmXubp05E-_YXOyhtB5Bj75dg5YNwOsE,7517
|
58
58
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
59
59
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
60
|
-
liger_kernel_nightly-0.5.
|
61
|
-
liger_kernel_nightly-0.5.
|
62
|
-
liger_kernel_nightly-0.5.
|
63
|
-
liger_kernel_nightly-0.5.
|
64
|
-
liger_kernel_nightly-0.5.
|
65
|
-
liger_kernel_nightly-0.5.
|
60
|
+
liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
61
|
+
liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/METADATA,sha256=biU1_vrGRLdmGynYauvf4YfAVHrgN2RtGWe_CNuAD3c,20721
|
62
|
+
liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
63
|
+
liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
64
|
+
liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
65
|
+
liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|