liger-kernel-nightly 0.5.2.dev20241211055800__py3-none-any.whl → 0.5.2.dev20241211213024__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,7 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
29
29
  compute_nll_loss=True,
30
30
  compiled=True,
31
31
  use_ref_model=False,
32
- # TODO: ref input
32
+ ref_input=None,
33
33
  ref_weight=None,
34
34
  ref_bias=None,
35
35
  **loss_kwargs,
@@ -59,6 +59,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
59
59
  compute_nll_loss (bool): Whether to compute NLL loss.
60
60
  compiled (bool): Whether to use torch compile for chunk accumulation.
61
61
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
62
+ ref_input (torch.Tensor): Reference input tensor. Shape: (batch_size, seq_len, hidden_size).
62
63
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
63
64
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
64
65
  loss_kwargs (dict): Other possible arguments that a loss function might need
@@ -92,6 +93,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
92
93
  compute_nll_loss=compute_nll_loss,
93
94
  full_target=target,
94
95
  use_ref_model=use_ref_model,
96
+ ref_input=ref_input,
95
97
  ref_weight=ref_weight,
96
98
  ref_bias=ref_bias,
97
99
  **loss_kwargs,
@@ -301,6 +303,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
301
303
  beta=0.1,
302
304
  compute_nll_loss=True,
303
305
  use_ref_model=False,
306
+ ref_input=None,
304
307
  ref_weight=None,
305
308
  ref_bias=None,
306
309
  **loss_kwargs,
@@ -319,6 +322,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
319
322
  beta (float): Weight for the preference loss.
320
323
  compute_nll_loss (bool): Whether to compute NLL loss.
321
324
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
325
+ ref_input (torch.Tensor): Reference input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
322
326
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
323
327
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
324
328
  loss_kwargs (dict): Additional arguments for the loss function.
@@ -357,7 +361,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
357
361
  ref_rejected_logits,
358
362
  ref_chosen_nll_loss,
359
363
  ) = LigerFusedLinearPreferenceBase.chunk_forward(
360
- input_chunk,
364
+ ref_input,
361
365
  ref_weight,
362
366
  target_chunk,
363
367
  ref_bias,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241211055800
3
+ Version: 0.5.2.dev20241211213024
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -6,7 +6,7 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=Qu1Ul2A12sp6CqIT-atPbHWFb_LLtINEA9m
6
6
  liger_kernel/chunked_loss/dpo_loss.py,sha256=H9_RRhclckHYM2sd75tgbnf8IxC_PU2JCALbgtPQvwc,4222
7
7
  liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
8
8
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
9
- liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vlWfaaIECWvCQhY9PM7zRI0vKThIrydMf6P44bXn1EE,15114
9
+ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=qeRod4MFVttj62uPFhgKAWNNjVrqiEvu5SjZfRnOGzI,15389
10
10
  liger_kernel/chunked_loss/orpo_loss.py,sha256=ZuKGjbkIYzV4UzvupNdq6vyxCp7-BztQkUt8ZnFvKos,3531
11
11
  liger_kernel/chunked_loss/simpo_loss.py,sha256=Wa4LOlDG9PbJkOOkKg8hbKvnKgg7OTBz6-qIkwPK1yw,3275
12
12
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -57,9 +57,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
57
57
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=jko6oq_XQdBSmXubp05E-_YXOyhtB5Bj75dg5YNwOsE,7517
58
58
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
59
59
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
60
- liger_kernel_nightly-0.5.2.dev20241211055800.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
61
- liger_kernel_nightly-0.5.2.dev20241211055800.dist-info/METADATA,sha256=HSfJ-qbwGmNDotIXkz8RNaQ9h8k2kA6qxAXsPvlv494,20721
62
- liger_kernel_nightly-0.5.2.dev20241211055800.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
63
- liger_kernel_nightly-0.5.2.dev20241211055800.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
64
- liger_kernel_nightly-0.5.2.dev20241211055800.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
65
- liger_kernel_nightly-0.5.2.dev20241211055800.dist-info/RECORD,,
60
+ liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
61
+ liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/METADATA,sha256=biU1_vrGRLdmGynYauvf4YfAVHrgN2RtGWe_CNuAD3c,20721
62
+ liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
63
+ liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
64
+ liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
65
+ liger_kernel_nightly-0.5.2.dev20241211213024.dist-info/RECORD,,