liger-kernel-nightly 0.4.2.dev20241117192137__py3-none-any.whl → 0.4.2.dev20241119054456__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,61 @@
1
+ import torch.nn.functional as F
2
+
3
+ from liger_kernel.chunked_loss.fused_linear_preference import (
4
+ LigerFusedLinearPreferenceBase,
5
+ )
6
+
7
+
8
+ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
9
+
10
+ @staticmethod
11
+ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
12
+ """
13
+ Compute the CPO preference term: the mean log-sigmoid of beta * (chosen_logps - rejected_logps).
14
+ Args:
15
+ chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
16
+ rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
17
+ beta (float): Scaling factor applied to the difference between chosen and rejected log probabilities.
18
+ """
19
+ logits = beta * (chosen_logps - rejected_logps)
20
+ loss = F.logsigmoid(logits).mean()
21
+ return loss
22
+
23
+ @staticmethod
24
+ def forward(
25
+ ctx,
26
+ _input,
27
+ weight,
28
+ target,
29
+ bias=None,
30
+ ignore_index=-100,
31
+ beta=0.1,
32
+ alpha=1.0,
33
+ compute_nll_loss=True,
34
+ compiled=True,
35
+ ):
36
+ """
37
+ Fused linear layer with CPO (Contrastive Preference Optimization) loss.
38
+ Handles both the forward and backward pass of the final linear layer with CPO loss.
39
+ Inspired by LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989), which fuses the final linear layer and the CE loss.
40
+ """
41
+
42
+ return LigerFusedLinearPreferenceBase.forward(
43
+ ctx,
44
+ _input,
45
+ weight,
46
+ target,
47
+ bias,
48
+ loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
49
+ compute_nll_loss=compute_nll_loss,
50
+ ignore_index=ignore_index,
51
+ alpha=alpha,
52
+ beta=beta,
53
+ compiled=compiled,
54
+ )
55
+
56
+ @staticmethod
57
+ def backward(ctx, grad_output):
58
+ # Take the first four gradients from the base class (matching forward's _input, weight, target, bias)
59
+ grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
60
+ # Return these gradients, followed by None for the remaining inputs
61
+ return *grads, None, None, None, None, None
@@ -29,6 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
29
29
  chunk_size=1,
30
30
  compute_nll_loss=True,
31
31
  ignore_index=-100,
32
+ alpha=1.0,
32
33
  beta=0.1,
33
34
  compiled=True,
34
35
  ):
@@ -45,6 +46,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
45
46
  chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
46
47
  compute_nll_loss (bool): Whether to compute NLL loss.
47
48
  ignore_index (int): Index to ignore for loss computation.
49
+ alpha (float): Weight for the NLL loss.
48
50
  beta (float): Weight for the odds ratio loss.
49
51
  compiled (bool): Whether to use torch compile for chunk accumulation.
50
52
  """
@@ -62,6 +64,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
62
64
  LigerFusedLinearPreferenceBase._compute_loss,
63
65
  preference_loss_fn=loss_fn,
64
66
  ignore_index=ignore_index,
67
+ alpha=alpha,
65
68
  beta=beta,
66
69
  compute_nll_loss=compute_nll_loss,
67
70
  full_target=target,
@@ -149,6 +152,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
149
152
  preference_loss_fn=None,
150
153
  full_target=None,
151
154
  ignore_index=-100,
155
+ alpha=1.0,
152
156
  beta=0.1,
153
157
  compute_nll_loss=True,
154
158
  **loss_kwargs,
@@ -163,6 +167,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
163
167
  bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
164
168
  full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
165
169
  ignore_index (int): Index to ignore for loss computation.
170
+ alpha (float): Weight for the NLL loss.
166
171
  beta (float): Weight for the odds ratio loss.
167
172
  loss_kwargs (dict): Additional arguments for the loss function.
168
173
  """
@@ -202,5 +207,5 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
202
207
  )
203
208
  alignment_loss = alignment_loss / (full_target.shape[0] // 2)
204
209
 
205
- loss = chosen_nll_loss - alignment_loss
210
+ loss = alpha * chosen_nll_loss - alignment_loss
206
211
  return loss, (alignment_loss, chosen_logps, rejected_logps)
@@ -610,9 +610,7 @@ def apply_liger_kernel_to_qwen2(
610
610
  logger.warning(TRANSFORMER_DEPRECATION_WARNING)
611
611
  modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
612
612
 
613
- # import pdb; pdb.set_trace()
614
613
  if fused_linear_cross_entropy:
615
-
616
614
  if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
617
615
  modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
618
616
  else: # if version < 4.46.1
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241117192137
3
+ Version: 0.4.2.dev20241119054456
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -1,7 +1,8 @@
1
1
  liger_kernel/env_report.py,sha256=jye8RvUkmhqaIshdeIpoUABoAu7FPKJUib4FnAfvkpw,1132
2
2
  liger_kernel/chunked_loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ liger_kernel/chunked_loss/cpo_loss.py,sha256=ty3nAlpxGqH6HMvTDzNOVulwvs-j6k26FIgEK0nl9Rc,2059
3
4
  liger_kernel/chunked_loss/dpo_loss.py,sha256=_sftycUsxypLiQaCIoqMEwtc425Kxiq97YI6DvFvscc,1943
4
- liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ayx-dmAx1TW9sThHJ_wUU1MqpZeJ4-SooGh0ZgVFlOA,8420
5
+ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=9bj4u328TeTBGr5DquYU5sHANagXa3ti-6rjTIa-OXQ,8595
5
6
  liger_kernel/chunked_loss/orpo_loss.py,sha256=DNifPpzGV_t3dfOPlPy2XKDM6M1Qne0kCbIPztvFY9U,2179
6
7
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
8
  liger_kernel/ops/cross_entropy.py,sha256=sfUb7-jIZp0EKXjg1DYy2Wdzw_Mg-mHmGoR5bpdm4tw,15526
@@ -29,7 +30,7 @@ liger_kernel/transformers/group_norm.py,sha256=FJ9R7mS9G1wO-GRIQ6QKSmIhnZ6nQ6GIk
29
30
  liger_kernel/transformers/jsd.py,sha256=W-5CypO2mx4-bUWOxq1KScfCdoXlLoYbtt5xBnRzMs4,3056
30
31
  liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
31
32
  liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
32
- liger_kernel/transformers/monkey_patch.py,sha256=L1IuGmFMWYgf-u3OXCg43BUxbZKTpd7ATjjDjYoFkEM,38268
33
+ liger_kernel/transformers/monkey_patch.py,sha256=Qk8jTO1AO6-knod7w8LtZKVIvm5gapsHInBMCjy6zR8,38233
33
34
  liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
34
35
  liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
35
36
  liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
@@ -47,9 +48,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
47
48
  liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
48
49
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
49
50
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
50
- liger_kernel_nightly-0.4.2.dev20241117192137.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
51
- liger_kernel_nightly-0.4.2.dev20241117192137.dist-info/METADATA,sha256=YPECo7OOvylo0MxXQ_usD84R8GX8J7Sy7rQPdHhVuqc,21556
52
- liger_kernel_nightly-0.4.2.dev20241117192137.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
53
- liger_kernel_nightly-0.4.2.dev20241117192137.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
54
- liger_kernel_nightly-0.4.2.dev20241117192137.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
55
- liger_kernel_nightly-0.4.2.dev20241117192137.dist-info/RECORD,,
51
+ liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
52
+ liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/METADATA,sha256=FDmZnvTxvl1UbpHLw6hwcuMTHMGHdTi_1GS9N7OhZoQ,21556
53
+ liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
54
+ liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
55
+ liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
56
+ liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/RECORD,,