liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +3 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
  9. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  10. liger_kernel/chunked_loss/grpo_loss.py +160 -0
  11. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  12. liger_kernel/chunked_loss/kto_loss.py +172 -0
  13. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  14. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  15. liger_kernel/env_report.py +5 -12
  16. liger_kernel/ops/cross_entropy.py +102 -51
  17. liger_kernel/ops/experimental/embedding.py +1 -3
  18. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  19. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  20. liger_kernel/ops/fused_linear_jsd.py +14 -32
  21. liger_kernel/ops/geglu.py +6 -17
  22. liger_kernel/ops/group_norm.py +11 -28
  23. liger_kernel/ops/jsd.py +5 -9
  24. liger_kernel/ops/kl_div.py +8 -11
  25. liger_kernel/ops/layer_norm.py +23 -12
  26. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  27. liger_kernel/ops/rms_norm.py +14 -32
  28. liger_kernel/ops/rope.py +31 -33
  29. liger_kernel/ops/swiglu.py +4 -8
  30. liger_kernel/ops/tvd.py +207 -0
  31. liger_kernel/ops/utils.py +3 -2
  32. liger_kernel/transformers/__init__.py +19 -24
  33. liger_kernel/transformers/auto_model.py +6 -13
  34. liger_kernel/transformers/cross_entropy.py +7 -9
  35. liger_kernel/transformers/experimental/embedding.py +1 -3
  36. liger_kernel/transformers/functional.py +28 -7
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
  38. liger_kernel/transformers/geglu.py +1 -4
  39. liger_kernel/transformers/group_norm.py +9 -15
  40. liger_kernel/transformers/jsd.py +1 -3
  41. liger_kernel/transformers/kl_div.py +1 -3
  42. liger_kernel/transformers/layer_norm.py +3 -9
  43. liger_kernel/transformers/model/gemma.py +18 -40
  44. liger_kernel/transformers/model/gemma2.py +19 -41
  45. liger_kernel/transformers/model/llama.py +22 -48
  46. liger_kernel/transformers/model/mistral.py +14 -26
  47. liger_kernel/transformers/model/mixtral.py +24 -54
  48. liger_kernel/transformers/model/mllama.py +16 -36
  49. liger_kernel/transformers/model/olmo2.py +124 -0
  50. liger_kernel/transformers/model/phi3.py +18 -40
  51. liger_kernel/transformers/model/qwen2.py +18 -40
  52. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  53. liger_kernel/transformers/monkey_patch.py +214 -144
  54. liger_kernel/transformers/rms_norm.py +4 -4
  55. liger_kernel/transformers/rope.py +2 -2
  56. liger_kernel/transformers/swiglu.py +2 -8
  57. liger_kernel/transformers/trainer/__init__.py +1 -3
  58. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  59. liger_kernel/transformers/tvd.py +13 -0
  60. liger_kernel/triton/__init__.py +1 -3
  61. liger_kernel/triton/monkey_patch.py +1 -3
  62. liger_kernel/utils.py +49 -0
  63. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
  64. liger_kernel-0.5.4.dist-info/RECORD +74 -0
  65. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
  66. liger_kernel-0.5.2.dist-info/RECORD +0 -65
  67. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
  68. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
  69. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/kto_loss.py
@@ -0,0 +1,172 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase
+
+
+ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
+     @staticmethod
+     def preference_loss_fn(
+         average_log_prob_chunk,
+         preference_labels_chunk,
+         full_target,
+         ref_average_log_prob_chunk=None,
+         beta=0.1,
+         kl=None,
+     ):
+         """
+         Implements the Kahneman-Tversky Optimization (KTO) loss function.
+         Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization"
+         https://arxiv.org/abs/2402.01306
+
+         KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory)
+         from behavioral economics, which models how humans make decisions under uncertainty.
+         The loss function is asymmetric, treating gains and losses differently, similar to
+         human decision-making patterns.
+
+         Formula:
+         When y is chosen:
+         L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y))
+         When y is rejected:
+         L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)]))
+
+         Where:
+         - σ: Sigmoid function
+         - β: Temperature parameter controlling the strength of the preference signal
+         - π(x): Policy (current model)
+         - π₀(x): Reference policy (reference model)
+         - KL(π||π₀)_y: KL divergence estimated using the rejected response y
+
+         The loss encourages the model to:
+         1. Assign higher probability to chosen responses
+         2. Assign lower probability to rejected responses
+         3. Maintain reasonable distance from the reference model
+
+         Args:
+             average_log_prob_chunk: Log probabilities for the chunk (batch_size,)
+             preference_labels_chunk: Preference labels for the chunk (batch_size,)
+             full_target: Non-chunked full target tensor
+             ref_average_log_prob_chunk: Reference log probs for the chunk (batch_size,)
+             beta: Weight for the KTO loss
+             kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
+         Returns:
+             - loss: The KTO loss value
+         """
+         if ref_average_log_prob_chunk is not None:
+             logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
+         else:
+             logratios_chunk = average_log_prob_chunk
+
+         multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
+         if kl is not None:
+             losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
+         else:
+             losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
+
+         return losses.sum() / (full_target.shape[0])
+
+     @staticmethod
+     def forward(
+         ctx,
+         _input,
+         weight,
+         target,
+         preference_labels,
+         bias=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+         kl=None,
+         ignore_index=-100,
+         beta=0.1,
+         compiled=True,
+         use_ref_model=True,
+     ):
+         return LigerFusedLinearUnpairedPreferenceBase.forward(
+             ctx=ctx,
+             _input=_input,
+             weight=weight,
+             target=target,
+             preference_labels=preference_labels,
+             bias=bias,
+             loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
+             ignore_index=ignore_index,
+             beta=beta,
+             compiled=compiled,
+             use_ref_model=use_ref_model,
+             ref_input=ref_input,
+             ref_weight=ref_weight,
+             ref_bias=ref_bias,
+             kl=kl,
+         )
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5]
+         return (
+             *grads,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+         )
+
+
+ class LigerFusedLinearKTOLoss(torch.nn.Module):
+     """
+     Fused linear layer with Kahneman-Tversky Optimization (KTO) loss.
+     """
+
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         beta: float = 0.1,
+         compiled: bool = True,
+         use_ref_model: bool = False,
+     ):
+         """
+         Args:
+             ignore_index (int): Index to ignore in the loss calculation
+             beta (float): Temperature parameter for the KTO loss
+             compiled (bool): Whether to use compiled operations
+             use_ref_model (bool): Whether to use a reference model for the KTO loss.
+         """
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.beta = beta
+         self.compiled = compiled
+         self.use_ref_model = use_ref_model
+
+     def forward(
+         self,
+         _input,
+         lin_weight,
+         target,
+         bias=None,
+         preference_labels=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+         kl=None,
+     ):
+         return LigerFusedLinearKTOFunction.apply(
+             _input,
+             lin_weight,
+             target,
+             preference_labels,
+             bias,
+             ref_input,
+             ref_weight,
+             ref_bias,
+             kl,
+             self.ignore_index,
+             self.beta,
+             self.compiled,
+             self.use_ref_model,
+         )
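As a sanity check on the formula above, here is a minimal sketch that calls the new preference_loss_fn directly with toy tensors. The numeric values and batch size are illustrative assumptions, not taken from the diff; only the function signature comes from kto_loss.py above.

```
import torch

from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction

# Toy per-sequence average log-probs under the policy and reference models,
# and boolean labels marking chosen (True) vs. rejected (False) sequences.
# All values here are illustrative assumptions.
policy_logps = torch.tensor([-1.2, -2.5, -0.8, -3.0])
ref_logps = torch.tensor([-1.5, -2.0, -1.0, -2.8])
labels = torch.tensor([True, False, True, False])
full_target = torch.zeros(4, 16)  # only .shape[0] is used, for normalization

loss = LigerFusedLinearKTOFunction.preference_loss_fn(
    policy_logps,
    labels,
    full_target,
    ref_average_log_prob_chunk=ref_logps,
    beta=0.1,
)
# With kl=None this evaluates 1 - sigmoid(beta * multiplier * logratio),
# summed over the batch and divided by the batch size.
print(loss)
```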
liger_kernel/chunked_loss/orpo_loss.py
@@ -1,13 +1,10 @@
  import torch
  import torch.nn.functional as F
 
- from liger_kernel.chunked_loss.fused_linear_preference import (
-     LigerFusedLinearPreferenceBase,
- )
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
  class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
-
      @staticmethod
      def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
          """
@@ -32,11 +29,10 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
              beta (float): Weight for the odds ratio loss.
          """
          log_odds = (chosen_logps - rejected_logps) - (
-             torch.log1p(-torch.exp(chosen_logps))
-             - torch.log1p(-torch.exp(rejected_logps))
+             torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
          )
          ratio = F.logsigmoid(log_odds)
-         loss = beta * ratio.sum() / (full_target.shape[0] // 2)
+         loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
 
          chosen_rewards = beta * chosen_logps
          rejected_rewards = beta * rejected_logps
@@ -56,6 +52,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
          ignore_index=-100,
          beta=0.1,
          compute_nll_loss=True,
+         nll_target=None,
          compiled=True,
      ):
          return LigerFusedLinearPreferenceBase.forward(
@@ -68,13 +65,14 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
              ignore_index=ignore_index,
              beta=beta,
              compute_nll_loss=compute_nll_loss,
+             nll_target=nll_target,
              compiled=compiled,
          )
 
      @staticmethod
      def backward(ctx, *grad_output):
          grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-         return *grads, None, None, None, None
+         return *grads, None, None, None, None, None
 
 
  class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -100,7 +98,7 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
          self.compute_nll_loss = compute_nll_loss
          self.compiled = compiled
 
-     def forward(self, lin_weight, _input, target, bias=None):
+     def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
          return LigerFusedLinearORPOFunction.apply(
              _input,
              lin_weight,
@@ -109,5 +107,6 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
              self.ignore_index,
              self.beta,
              self.compute_nll_loss,
+             nll_target,
              self.compiled,
          )
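The sign flip on the odds-ratio term is the substantive change in this file: F.logsigmoid is always non-positive, so the old `beta * ratio.sum()` decreased as preferences got worse; negating it yields a proper minimization objective. A standalone sketch of the corrected term, with toy log-probabilities that are assumptions for illustration:

```
import torch
import torch.nn.functional as F

# Per-sequence average log-probs for chosen/rejected completions (toy values).
chosen_logps = torch.tensor([-0.5, -0.7])
rejected_logps = torch.tensor([-1.5, -2.0])
beta = 0.1

# Log-odds of chosen vs. rejected, exactly as in the diff above.
log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
ratio = F.logsigmoid(log_odds)  # always <= 0

# Negated so that better-separated preferences give a *smaller* loss;
# full_target stacks chosen and rejected rows, hence the // 2.
full_batch = 4
loss = -beta * ratio.sum() / (full_batch // 2)
print(loss)  # positive, shrinking as chosen pulls ahead of rejected
```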
liger_kernel/chunked_loss/simpo_loss.py
@@ -1,16 +1,18 @@
  import torch
  import torch.nn.functional as F
 
- from liger_kernel.chunked_loss.fused_linear_preference import (
-     LigerFusedLinearPreferenceBase,
- )
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
  class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
-
      @staticmethod
      def preference_loss_fn(
-         chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+         chosen_logps,
+         rejected_logps,
+         full_target,
+         beta=0.1,
+         gamma=0.5,
+         label_smoothing=0.0,
      ):
          """
          Paper: https://arxiv.org/pdf/2405.14734
@@ -33,10 +35,17 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
              full_target: Non-chunked full target tensor
              beta (float): beta weight
              gamma (float): gamma margin term
+             label_smoothing (float): Label smoothing factor; the loss reduces to the equation above as label_smoothing -> 0.
          """
          logits = beta * (chosen_logps - rejected_logps) - gamma
-         loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
-         return loss
+         loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+             full_target.shape[0] // 2
+         )
+
+         chosen_rewards = beta * chosen_logps
+         rejected_rewards = beta * rejected_logps
+
+         return loss, chosen_rewards, rejected_rewards
 
      @staticmethod
      def forward(
@@ -48,6 +57,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
          ignore_index=-100,
          beta=0.1,
          alpha=1.0,
+         label_smoothing=0.0,
          compute_nll_loss=False,
          compiled=True,
          gamma=0.5,
@@ -63,6 +73,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
              ignore_index=ignore_index,
              alpha=alpha,
              beta=beta,
+             label_smoothing=label_smoothing,
              compiled=compiled,
              gamma=gamma,
          )
@@ -70,7 +81,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
      @staticmethod
      def backward(ctx, *grad_output):
          grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-         return *grads, None, None, None, None, None, None
+         return *grads, None, None, None, None, None, None, None
 
 
  class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -83,6 +94,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
          ignore_index: int = -100,
          beta: float = 0.1,
          alpha: float = 1.0,
+         label_smoothing: float = 0.0,
          compute_nll_loss: bool = True,
          compiled: bool = True,
          gamma: float = 0.5,
@@ -96,6 +108,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
          self.ignore_index = ignore_index
          self.beta = beta
          self.alpha = alpha
+         self.label_smoothing = label_smoothing
          self.compute_nll_loss = compute_nll_loss
          self.compiled = compiled
          self.gamma = gamma
@@ -109,6 +122,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
              self.ignore_index,
              self.beta,
              self.alpha,
+             self.label_smoothing,
              self.compute_nll_loss,
              self.compiled,
              self.gamma,
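The reworked SimPO objective folds in label smoothing: it mixes -logsigmoid(logits) with -logsigmoid(-logits) and collapses to the original loss as label_smoothing approaches 0. A quick standalone check of that limit, using toy values that are assumptions for illustration:

```
import torch
import torch.nn.functional as F

chosen_logps = torch.tensor([-0.4, -0.9])   # toy values
rejected_logps = torch.tensor([-1.1, -1.6])
beta, gamma = 0.1, 0.5
pairs = 2  # full_target.shape[0] // 2 in the diff

logits = beta * (chosen_logps - rejected_logps) - gamma

def simpo_loss(label_smoothing):
    # Smoothed objective from the diff above.
    return (
        -F.logsigmoid(logits) * (1 - label_smoothing)
        - F.logsigmoid(-logits) * label_smoothing
    ).sum() / pairs

# With no smoothing this matches the plain -logsigmoid objective exactly.
assert torch.allclose(simpo_loss(0.0), (-F.logsigmoid(logits)).sum() / pairs)
print(simpo_loss(0.0), simpo_loss(0.1))
```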
liger_kernel/env_report.py
@@ -1,12 +1,13 @@
  import platform
  import sys
+
  from importlib.metadata import version
 
 
  def print_env_report():
      """
 
-     Prints a report of the environment. Useful for debugging and reproducibility.
+     Prints a report of the environment. Useful for debugging and reproducibility.
      Usage:
      ```
      python -m liger_kernel.env_report
@@ -27,15 +28,9 @@ def print_env_report():
          import torch
 
          print(f"PyTorch version: {torch.__version__}")
-         cuda_version = (
-             torch.version.cuda if torch.cuda.is_available() else "Not available"
-         )
+         cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
          print(f"CUDA version: {cuda_version}")
-         hip_version = (
-             torch.version.hip
-             if torch.cuda.is_available() and torch.version.hip
-             else "Not available"
-         )
+         hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
          print(f"HIP(ROCm) version: {hip_version}")
 
      except ImportError:
@@ -58,9 +53,7 @@ def print_env_report():
          print("Transformers: Not installed")
 
      try:
-         xpu_version = (
-             torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
-         )
+         xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
          print(f"XPU version: {xpu_version}")
      except ImportError:
          print("XPU version: Unable to query")