liger-kernel-nightly 0.4.2.dev20241209224333__tar.gz → 0.4.2.dev20241209234352__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. {liger_kernel_nightly-0.4.2.dev20241209224333/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241209234352}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/cpo_loss.py +16 -10
  4. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/dpo_loss.py +20 -12
  5. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/orpo_loss.py +15 -9
  6. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/simpo_loss.py +17 -11
  7. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  8. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/LICENSE +0 -0
  9. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/NOTICE +0 -0
  10. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/README.md +0 -0
  11. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/setup.cfg +0 -0
  12. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/__init__.py +0 -0
  13. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  14. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/functional.py +0 -0
  15. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  16. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  17. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/env_report.py +0 -0
  18. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/__init__.py +0 -0
  19. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/cross_entropy.py +0 -0
  20. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  21. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  22. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  23. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  24. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/geglu.py +0 -0
  25. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/group_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/jsd.py +0 -0
  27. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/kl_div.py +0 -0
  28. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/layer_norm.py +0 -0
  29. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  30. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/rms_norm.py +0 -0
  31. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/rope.py +0 -0
  32. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/swiglu.py +0 -0
  33. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/ops/utils.py +0 -0
  34. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/__init__.py +0 -0
  35. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/auto_model.py +0 -0
  36. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  38. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/functional.py +0 -0
  39. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  40. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  41. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/geglu.py +0 -0
  42. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/group_norm.py +0 -0
  43. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/jsd.py +0 -0
  44. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/kl_div.py +0 -0
  45. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/layer_norm.py +0 -0
  46. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/__init__.py +0 -0
  47. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/gemma.py +0 -0
  48. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  49. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/llama.py +0 -0
  50. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/mistral.py +0 -0
  51. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  52. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/mllama.py +0 -0
  53. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/phi3.py +0 -0
  54. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  55. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  56. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  57. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/orpo_trainer.py +0 -0
  58. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  59. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/rms_norm.py +0 -0
  60. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/rope.py +0 -0
  61. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/swiglu.py +0 -0
  62. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  63. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/triton/__init__.py +0 -0
  64. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/triton/monkey_patch.py +0 -0
  65. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel/utils.py +0 -0
  66. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  67. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  68. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  69. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241209234352}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241209224333
3
+ Version: 0.4.2.dev20241209234352
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.2.dev20241209224333"
7
+ version = "0.4.2.dev20241209234352"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -11,11 +11,25 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
11
11
  @staticmethod
12
12
  def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
13
13
  """
14
- Compute odds-ratio loss.
14
+ Paper: https://arxiv.org/pdf/2401.08417
15
+
16
+ Formula:
17
+ L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
18
+
19
+ Where:
20
+ - π_θ(y|x): Policy (model) probability
21
+ - y_w: Chosen sequence
22
+ - y_l: Rejected sequence
23
+ - σ: Sigmoid function
24
+ - β: Temperature parameter
25
+ - E: Expected value over the dataset D
26
+ - D: Dataset of preferences
27
+
15
28
  Args:
16
29
  chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
17
30
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
18
- beta (float): Weight for the odds ratio loss.
31
+ full_target (torch.Tensor): Non-chunked full target tensor
32
+ beta (float): Weight for the CPO loss
19
33
  """
20
34
  logits = beta * (chosen_logps - rejected_logps)
21
35
  loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
@@ -34,12 +48,6 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
34
48
  compute_nll_loss=True,
35
49
  compiled=True,
36
50
  ):
37
- """
38
- Fused linear layer with CPO (Odds-Ratio Preference Optimization) loss.
39
- Handles both the forward and backward pass of the final linear layer with CPO loss.
40
- Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
41
- """
42
-
43
51
  return LigerFusedLinearPreferenceBase.forward(
44
52
  ctx,
45
53
  _input,
@@ -56,9 +64,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
56
64
 
57
65
  @staticmethod
58
66
  def backward(ctx, *grad_output):
59
- # Get gradients for _input, weight, bias, and target from the base class
60
67
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
61
- # Return these gradients, followed by None for the remaining inputs
62
68
  return *grads, None, None, None, None, None
63
69
 
64
70
 
@@ -18,14 +18,28 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
18
18
  beta=0.1,
19
19
  ):
20
20
  """
21
- Compute DPO loss (Direct Preference Optimization).
21
+ Paper: https://arxiv.org/pdf/2305.18290
22
+
23
+ Formula:
24
+ L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]
25
+
26
+ Where:
27
+ - π(y|x): Policy (model) probability
28
+ - π_ref(y|x): Reference model probability
29
+ - y_w: Chosen sequence
30
+ - y_l: Rejected sequence
31
+ - β: Weight for the direct preference loss
32
+ - E: Expected value over the dataset
33
+
22
34
  Args:
23
- chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
24
- rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
25
- ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,).
26
- ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,).
27
- beta (float): Weight for the direct preference loss.
35
+ chosen_logps: Log probabilities of chosen tokens (batch_size,)
36
+ rejected_logps: Log probabilities of rejected tokens (batch_size,)
37
+ full_target: Non-chunked full target tensor
38
+ ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
39
+ ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
40
+ beta: Weight for the direct preference loss
28
41
  """
42
+
29
43
  if ref_chosen_logps is None:
30
44
  ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
31
45
  if ref_rejected_logps is None:
@@ -53,10 +67,6 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
53
67
  compiled=True,
54
68
  use_ref_model=True,
55
69
  ):
56
- """
57
- Fused linear layer with DPO (Direct Preference Optimization) loss.
58
- Handles both the forward and backward pass of the final linear layer with DPO loss.
59
- """
60
70
  return LigerFusedLinearPreferenceBase.forward(
61
71
  ctx=ctx,
62
72
  _input=_input,
@@ -75,9 +85,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
75
85
 
76
86
  @staticmethod
77
87
  def backward(ctx, *grad_output):
78
- # Get gradients for _input, weight, bias, and target from the base class
79
88
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
80
- # Return these gradients, followed by None for the remaining inputs
81
89
  return *grads, None, None, None, None, None, None, None
82
90
 
83
91
 
@@ -11,10 +11,24 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
11
11
  @staticmethod
12
12
  def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
13
13
  """
14
- Compute odds-ratio loss.
14
+ Paper: https://arxiv.org/pdf/2403.07691
15
+
16
+ Formula:
17
+ Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
18
+ where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
19
+
20
+ Where:
21
+ - P_θ(y|x): Policy (model) probability
22
+ - y_w: Chosen sequence
23
+ - y_l: Rejected sequence
24
+ - σ: Sigmoid function
25
+ - β: Weight for the odds ratio loss
26
+ - odds_θ: Odds function for the policy
27
+
15
28
  Args:
16
29
  chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
17
30
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
31
+ full_target (torch.Tensor): Non-chunked full target tensor
18
32
  beta (float): Weight for the odds ratio loss.
19
33
  """
20
34
  log_odds = (chosen_logps - rejected_logps) - (
@@ -44,12 +58,6 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
44
58
  compute_nll_loss=True,
45
59
  compiled=True,
46
60
  ):
47
- """
48
- Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
49
- Handles both the forward and backward pass of the final linear layer with ORPO loss.
50
- Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
51
- """
52
-
53
61
  return LigerFusedLinearPreferenceBase.forward(
54
62
  ctx=ctx,
55
63
  _input=_input,
@@ -65,9 +73,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
65
73
 
66
74
  @staticmethod
67
75
  def backward(ctx, *grad_output):
68
- # Get gradients for _input, weight, bias, and target from the base class
69
76
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
70
- # Return these gradients, followed by None for the remaining inputs
71
77
  return *grads, None, None, None, None
72
78
 
73
79
 
@@ -13,12 +13,26 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
13
13
  chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
14
14
  ):
15
15
  """
16
- Compute odds-ratio loss.
16
+ Paper: https://arxiv.org/pdf/2405.14734
17
+
18
+ Formula:
19
+ L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
20
+
21
+ Where:
22
+ - π_θ(y|x): Policy (model) probability
23
+ - y_w: Chosen sequence
24
+ - y_l: Rejected sequence
25
+ - |y_w|, |y_l|: Sequence lengths
26
+ - σ: Sigmoid function
27
+ - β: beta weight
28
+ - γ: gamma margin term
29
+
17
30
  Args:
18
31
  chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
19
32
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
20
- beta (float): Weight for the odds ratio loss.
21
- gamma (float): The simpo gamma, margin term.
33
+ full_target: Non-chunked full target tensor
34
+ beta (float): beta weight
35
+ gamma (float): gamma margin term
22
36
  """
23
37
  logits = beta * (chosen_logps - rejected_logps) - gamma
24
38
  loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
@@ -38,12 +52,6 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
38
52
  compiled=True,
39
53
  gamma=0.5,
40
54
  ):
41
- """
42
- Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734
43
- Handles both the forward and backward pass of the final linear layer with SimPO loss.
44
- Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
45
- """
46
-
47
55
  return LigerFusedLinearPreferenceBase.forward(
48
56
  ctx,
49
57
  _input,
@@ -61,9 +69,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
61
69
 
62
70
  @staticmethod
63
71
  def backward(ctx, *grad_output):
64
- # Get gradients for _input, weight, bias, and target from the base class
65
72
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
66
- # Return these gradients, followed by None for the remaining inputs
67
73
  return *grads, None, None, None, None, None, None
68
74
 
69
75
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241209224333
3
+ Version: 0.4.2.dev20241209234352
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation