liger-kernel-nightly 0.4.2.dev20241119054729__py3-none-any.whl → 0.4.2.dev20241119174706__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. The information is provided for informational purposes only.
liger_kernel/chunked_loss/fused_linear_preference.py

@@ -32,6 +32,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         alpha=1.0,
         beta=0.1,
         compiled=True,
+        **loss_kwargs,
     ):
         """
         Base class for fused linear layer with preference loss.
@@ -49,6 +50,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the odds ratio loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
         CHUNK_SIZE = chunk_size
@@ -68,6 +70,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             full_target=target,
+            **loss_kwargs,
         )

         def accumulate_chunk(input_chunk, target_chunk):
@@ -94,6 +97,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             loss_acc.add_(chunk_loss)
             return chunk_grad_input

+        if compiled:
+            accumulate_chunk = torch.compile(accumulate_chunk)
+
         len_chosen = target.shape[0] // 2
         _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
         _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
@@ -116,8 +122,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 [chosen_target_chunk, rejected_target_chunk], dim=0
             )

-            if compiled:
-                accumulate_chunk = torch.compile(accumulate_chunk)
             grad_input = accumulate_chunk(input_chunk, target_chunk)

             grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
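Two things change in LigerFusedLinearPreferenceBase.forward here: arbitrary keyword arguments (**loss_kwargs) are now threaded through to the underlying preference loss function, and the torch.compile wrapping of accumulate_chunk is hoisted out of the per-chunk loop so the accumulator is compiled once rather than re-wrapped on every iteration. A minimal standalone sketch of the compile-once pattern (the helper names and toy computation below are illustrative, not the library's actual code):

import torch

def make_accumulator(scale):
    # Stand-in for the library's accumulate_chunk closure; here it just
    # scales a chunk and reduces it to a scalar.
    def step(chunk):
        return (chunk * scale).sum()
    return step

step = make_accumulator(scale=2.0)
step = torch.compile(step)  # compile once, before the loop

total = torch.zeros(())
for chunk in torch.randn(8, 4).chunk(4, dim=0):
    total = total + step(chunk)  # same compiled callable reused for every chunk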
liger_kernel/chunked_loss/orpo_loss.py

@@ -34,7 +34,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         ignore_index=-100,
         beta=0.1,
         compute_nll_loss=True,
-        compiled=True,
+        compiled=False,
     ):
         """
         Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
liger_kernel/chunked_loss/simpo_loss.py (new file)

@@ -0,0 +1,64 @@
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import (
+    LigerFusedLinearPreferenceBase,
+)
+
+
+class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
+
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1, gamma=0.5):
+        """
+        Compute odds-ratio loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the odds ratio loss.
+            gamma (float): The simpo gamma, margin term.
+        """
+        logits = beta * (chosen_logps - rejected_logps) - gamma
+        loss = F.logsigmoid(logits).mean()
+        return loss
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        alpha=1.0,
+        compute_nll_loss=False,
+        compiled=True,
+        gamma=0.5,
+    ):
+        """
+        Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734
+        Handles both the forward and backward pass of the final linear layer with SimPO loss.
+        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
+        """
+
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx,
+            _input,
+            weight,
+            target,
+            bias,
+            loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
+            ignore_index=ignore_index,
+            alpha=alpha,
+            beta=beta,
+            compiled=compiled,
+            gamma=gamma,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Get gradients for _input, weight, bias, and target from the base class
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        # Return these gradients, followed by None for the remaining inputs
+        return *grads, None, None, None, None, None, None
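For reference, preference_loss_fn above scores each pair as logsigmoid(beta * (chosen_logps - rejected_logps) - gamma), averaged over the batch; the sign and any NLL term are handled by the base class. A rough, untested usage sketch of the new autograd Function follows; the positional argument order mirrors the forward signature above, while the tensor shapes and sizes are illustrative assumptions, not from the package:

import torch
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

B, T, H, V = 2, 16, 64, 128                             # illustrative sizes
hidden = torch.randn(2 * B, T, H, requires_grad=True)   # chosen rows stacked above rejected rows
target = torch.randint(0, V, (2 * B, T))                 # token ids for both halves
weight = torch.randn(V, H, requires_grad=True)           # lm_head projection weight

# Positional args mirror forward(): bias, ignore_index, beta, alpha,
# compute_nll_loss, compiled, gamma
loss = LigerFusedLinearSimPOFunction.apply(
    hidden, weight, target, None, -100, 0.1, 1.0, False, True, 0.5
)
loss.backward()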
liger_kernel_nightly-0.4.2.dev20241119054729.dist-info/METADATA → liger_kernel_nightly-0.4.2.dev20241119174706.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.4.2.dev20241119054729
+Version: 0.4.2.dev20241119174706
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
liger_kernel_nightly-0.4.2.dev20241119054729.dist-info/RECORD → liger_kernel_nightly-0.4.2.dev20241119174706.dist-info/RECORD

@@ -2,8 +2,9 @@ liger_kernel/env_report.py,sha256=jye8RvUkmhqaIshdeIpoUABoAu7FPKJUib4FnAfvkpw,11
 liger_kernel/chunked_loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/chunked_loss/cpo_loss.py,sha256=ty3nAlpxGqH6HMvTDzNOVulwvs-j6k26FIgEK0nl9Rc,2059
 liger_kernel/chunked_loss/dpo_loss.py,sha256=_sftycUsxypLiQaCIoqMEwtc425Kxiq97YI6DvFvscc,1943
-liger_kernel/chunked_loss/fused_linear_preference.py,sha256=9bj4u328TeTBGr5DquYU5sHANagXa3ti-6rjTIa-OXQ,8595
-liger_kernel/chunked_loss/orpo_loss.py,sha256=DNifPpzGV_t3dfOPlPy2XKDM6M1Qne0kCbIPztvFY9U,2179
+liger_kernel/chunked_loss/fused_linear_preference.py,sha256=gtsWG3rpTlWpiiom_oMPeS-w-lofBVrguN0KglAXTGk,8727
+liger_kernel/chunked_loss/orpo_loss.py,sha256=QtHPDQwZdU7QFgu9tPg81vQfF_Dm3zQcsmhp9SdKKvA,2180
+liger_kernel/chunked_loss/simpo_loss.py,sha256=lmPopkHcqfglEnXv28FcQQjIkpNg8CEhn0Wt19xcoE4,2223
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/ops/cross_entropy.py,sha256=sfUb7-jIZp0EKXjg1DYy2Wdzw_Mg-mHmGoR5bpdm4tw,15526
 liger_kernel/ops/fused_linear_cross_entropy.py,sha256=ib7M3AjJE164yMfuS9R39k-5qnDgYOXptIT146lqYbg,9964
@@ -50,9 +51,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
 liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
 liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
 liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
-liger_kernel_nightly-0.4.2.dev20241119054729.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.4.2.dev20241119054729.dist-info/METADATA,sha256=z7a8F_CpxQybJzgLVPeWKTiLuv4nWL7gvQe4vIVzn7Q,21556
-liger_kernel_nightly-0.4.2.dev20241119054729.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.4.2.dev20241119054729.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.4.2.dev20241119054729.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.4.2.dev20241119054729.dist-info/RECORD,,
+liger_kernel_nightly-0.4.2.dev20241119174706.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.4.2.dev20241119174706.dist-info/METADATA,sha256=T8u7FqqGYIi-w-4gDqYqgRq5y5vgV7Nk9FdcHYNlQxQ,21556
+liger_kernel_nightly-0.4.2.dev20241119174706.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.4.2.dev20241119174706.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.4.2.dev20241119174706.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.4.2.dev20241119174706.dist-info/RECORD,,