liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +25 -9
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +124 -64
  17. liger_kernel/ops/dyt.py +115 -180
  18. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  19. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  20. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  21. liger_kernel/ops/geglu.py +3 -2
  22. liger_kernel/ops/group_norm.py +2 -1
  23. liger_kernel/ops/grpo_loss.py +312 -0
  24. liger_kernel/ops/jsd.py +2 -1
  25. liger_kernel/ops/kl_div.py +13 -6
  26. liger_kernel/ops/layer_norm.py +146 -78
  27. liger_kernel/ops/llama4_rope.py +225 -0
  28. liger_kernel/ops/multi_token_attention.py +207 -0
  29. liger_kernel/ops/poly_norm.py +390 -0
  30. liger_kernel/ops/rms_norm.py +283 -56
  31. liger_kernel/ops/rope.py +1 -1
  32. liger_kernel/ops/softmax.py +201 -0
  33. liger_kernel/ops/sparsemax.py +179 -0
  34. liger_kernel/ops/swiglu.py +1 -1
  35. liger_kernel/ops/tiled_mlp.py +136 -0
  36. liger_kernel/ops/utils.py +2 -0
  37. liger_kernel/transformers/__init__.py +205 -19
  38. liger_kernel/transformers/cross_entropy.py +9 -4
  39. liger_kernel/transformers/dyt.py +6 -4
  40. liger_kernel/transformers/experimental/__init__.py +5 -0
  41. liger_kernel/transformers/experimental/embedding.py +1 -1
  42. liger_kernel/transformers/fsdp.py +55 -0
  43. liger_kernel/transformers/functional.py +122 -20
  44. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  45. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  46. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  47. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  48. liger_kernel/transformers/geglu.py +1 -1
  49. liger_kernel/transformers/group_norm.py +1 -1
  50. liger_kernel/transformers/grpo_loss.py +153 -0
  51. liger_kernel/transformers/jsd.py +1 -1
  52. liger_kernel/transformers/kl_div.py +1 -1
  53. liger_kernel/transformers/layer_norm.py +1 -1
  54. liger_kernel/transformers/llama4_rope.py +93 -0
  55. liger_kernel/transformers/model/falcon_h1.py +122 -0
  56. liger_kernel/transformers/model/gemma.py +50 -25
  57. liger_kernel/transformers/model/gemma2.py +55 -23
  58. liger_kernel/transformers/model/gemma3.py +117 -120
  59. liger_kernel/transformers/model/glm4.py +141 -0
  60. liger_kernel/transformers/model/glm4v.py +163 -0
  61. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  62. liger_kernel/transformers/model/gpt_oss.py +211 -0
  63. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  64. liger_kernel/transformers/model/internvl.py +157 -0
  65. liger_kernel/transformers/model/llama.py +102 -25
  66. liger_kernel/transformers/model/llama4.py +121 -0
  67. liger_kernel/transformers/model/llava.py +111 -136
  68. liger_kernel/transformers/model/loss_utils.py +50 -12
  69. liger_kernel/transformers/model/mistral.py +36 -23
  70. liger_kernel/transformers/model/mixtral.py +45 -25
  71. liger_kernel/transformers/model/mllama.py +39 -22
  72. liger_kernel/transformers/model/olmo2.py +40 -20
  73. liger_kernel/transformers/model/olmo3.py +142 -0
  74. liger_kernel/transformers/model/output_classes.py +147 -0
  75. liger_kernel/transformers/model/paligemma.py +50 -14
  76. liger_kernel/transformers/model/phi3.py +47 -177
  77. liger_kernel/transformers/model/qwen2.py +48 -21
  78. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  79. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  80. liger_kernel/transformers/model/qwen3.py +136 -0
  81. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  82. liger_kernel/transformers/model/qwen3_next.py +146 -0
  83. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  84. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  85. liger_kernel/transformers/model/smollm3.py +199 -0
  86. liger_kernel/transformers/model/smolvlm.py +158 -0
  87. liger_kernel/transformers/monkey_patch.py +1678 -160
  88. liger_kernel/transformers/multi_token_attention.py +64 -0
  89. liger_kernel/transformers/poly_norm.py +42 -0
  90. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  91. liger_kernel/transformers/rms_norm.py +48 -5
  92. liger_kernel/transformers/rope.py +45 -1
  93. liger_kernel/transformers/softmax.py +12 -0
  94. liger_kernel/transformers/sparsemax.py +16 -0
  95. liger_kernel/transformers/swiglu.py +39 -1
  96. liger_kernel/transformers/tiled_mlp.py +133 -0
  97. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  98. liger_kernel/transformers/tvd.py +1 -1
  99. liger_kernel/utils.py +36 -0
  100. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
  101. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  102. liger_kernel/transformers/gema3_rms.py +0 -8
  103. liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
  104. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/__init__.py
@@ -1,3 +1,4 @@
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss  # noqa: F401
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss  # noqa: F401
liger_kernel/chunked_loss/cosine_similarity_loss.py (new file)
@@ -0,0 +1,136 @@
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+     @staticmethod
+     def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+         """
+         Compute the cosine similarity distillation loss.
+         Args:
+             student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+             teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+             beta (float): Coefficient of the generalized cosine similarity loss, in the interval [0, 1]. Default: `1.0`.
+         Returns:
+             torch.Tensor: cosine similarity loss
+         """
+         student_norm = F.normalize(student_logits, p=2, dim=-1)
+         teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+         cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+         loss = beta * (1 - cosine_sim)
+         return loss.sum()
+
+     @classmethod
+     def forward(
+         cls,
+         ctx,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         student_bias: torch.Tensor,
+         teacher_bias: torch.Tensor,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+         chunk_size: int = 1024,
+         return_soft_hard_loss: bool = False,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+         return super().forward(
+             cls=cls,
+             ctx=ctx,
+             student_input=student_input,
+             student_weight=student_weight,
+             teacher_input=teacher_input,
+             teacher_weight=teacher_weight,
+             target=true_labels,
+             student_bias=student_bias,
+             teacher_bias=teacher_bias,
+             chunk_size=chunk_size,
+             weight_hard_loss=weight_hard_loss,
+             weight_soft_loss=weight_soft_loss,
+             beta=beta,
+             ignore_index=ignore_index,
+             temperature=temperature,
+             compiled=compiled,
+             return_soft_hard_loss=return_soft_hard_loss,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_output, *args):
+         grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+         return (
+             *grads,
+             None,  # teacher_bias
+             None,  # weight_hard_loss
+             None,  # weight_soft_loss
+             None,  # beta
+             None,  # ignore_index
+             None,  # temperature
+             None,  # compiled
+             None,  # chunk_size
+             None,  # return_soft_hard_loss
+         )
+
+
+ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+     def __init__(
+         self,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+         chunk_size: int = 1024,
+         return_soft_hard_loss: bool = False,
+     ):
+         super().__init__()
+         assert temperature != 0, "Temperature cannot be 0."
+         self.weight_hard_loss = weight_hard_loss
+         self.weight_soft_loss = weight_soft_loss
+         self.ignore_index = ignore_index
+         self.temperature = temperature
+         self.compiled = compiled
+         self.beta = beta
+         self.chunk_size = chunk_size
+         self.return_soft_hard_loss = return_soft_hard_loss
+
+     def forward(
+         self,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         student_bias: torch.Tensor = None,
+         teacher_bias: torch.Tensor = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+         return LigerFusedLinearCosineSimilarityFunction.apply(
+             student_input,
+             student_weight,
+             teacher_input,
+             teacher_weight,
+             true_labels,
+             student_bias,
+             teacher_bias,
+             self.weight_hard_loss,
+             self.weight_soft_loss,
+             self.beta,
+             self.ignore_index,
+             self.temperature,
+             self.compiled,
+             self.chunk_size,
+             self.return_soft_hard_loss,
+         )
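The new module exercises the standard Liger chunked-loss pattern end to end. Below is a minimal usage sketch based on the signatures added above; the tensor sizes, the toy CPU setup, and the choice of return_soft_hard_loss=True are illustrative assumptions, not part of the diff.

import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

# Illustrative sizes: (batch * seq_len, hidden) inputs and (vocab, hidden) projection weights.
BT, H, V = 8, 64, 128
student_input = torch.randn(BT, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(BT, H)
teacher_weight = torch.randn(V, H)
true_labels = torch.randint(0, V, (BT,))

loss_fn = LigerFusedLinearCosineSimilarityLoss(
    weight_hard_loss=0.5,
    weight_soft_loss=0.5,
    beta=0.5,
    compiled=False,              # skip torch.compile for this toy CPU example
    return_soft_hard_loss=True,  # returns (loss, soft_loss, hard_loss) instead of a single tensor
)
loss, soft_loss, hard_loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()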
liger_kernel/chunked_loss/dpo_loss.py
@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
          ref_chosen_logps=None,
          ref_rejected_logps=None,
          beta=0.1,
+         loss_type="sigmoid",
      ):
          """
          Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
          chosen_rewards = beta * chosen_logratios
          rejected_rewards = beta * rejected_logratios

-         logits_diff = beta * (chosen_logratios - rejected_logratios)
-         loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+         if loss_type == "sigmoid":
+             logits_diff = beta * (chosen_logratios - rejected_logratios)
+             loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "apo_zero":
+             # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+             # Use this loss when you believe the chosen outputs are better than your model's default output
+             losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+             losses_rejected = F.sigmoid(beta * rejected_logratios)
+             losses = losses_chosen + losses_rejected
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "apo_down":
+             # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+             # Use this loss when you believe the chosen outputs are worse than your model's default output.
+             # Decrease chosen likelihood and decrease rejected likelihood more
+             losses_chosen = F.sigmoid(beta * chosen_logratios)
+             losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+             losses = losses_chosen + losses_rejected
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "sppo_hard":
+             # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+             # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+             # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+             # set to 1 for the winner and 0 for the loser.
+             a = chosen_logps - ref_chosen_logps
+             b = rejected_logps - ref_rejected_logps
+             losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "nca_pair":
+             losses = (
+                 -F.logsigmoid(chosen_rewards)
+                 - 0.5 * F.logsigmoid(-chosen_rewards)
+                 - 0.5 * F.logsigmoid(-rejected_rewards)
+             )
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         else:
+             raise ValueError(
+                 f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+             )
+
          return loss, chosen_rewards, rejected_rewards

      @classmethod
@@ -68,7 +111,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
          compute_nll_loss=False,
          compiled=True,
          use_ref_model=True,
+         average_log_prob=False,
          chunk_size=1,
+         loss_type="sigmoid",
      ):
          """
          Fused linear layer with DPO loss.
@@ -85,6 +130,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
              compute_nll_loss (bool): Whether to compute the NLL loss
              compiled (bool): Whether to use torch compile
              use_ref_model (bool): Whether to use a reference model
+             average_log_prob (bool): Whether to average the log probability per non-masked token
              chunk_size (int): Size of chunks for processing.
          Returns:
              torch.Tensor: Computed loss
@@ -104,13 +150,15 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
              ref_input=ref_input,
              ref_weight=ref_weight,
              ref_bias=ref_bias,
+             average_log_prob=average_log_prob,
              chunk_size=chunk_size,
+             loss_type=loss_type,
          )

      @staticmethod
      def backward(ctx, *grad_output):
          grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-         return *grads, None, None, None, None, None, None, None, None, None
+         return *grads, None, None, None, None, None, None, None, None, None, None, None


  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -125,7 +173,9 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
          compute_nll_loss: bool = False,
          compiled: bool = True,
          use_ref_model: bool = True,
+         average_log_prob: bool = False,
          chunk_size: int = 1,
+         loss_type: str = "sigmoid",
      ):
          """
          Args:
@@ -134,6 +184,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
              compute_nll_loss (bool): Whether to compute the NLL loss.
              compiled (bool): Whether to use the torch compiled kernel.
              use_ref_model (bool): Whether to use a reference model for the DPO loss.
+             average_log_prob (bool): Whether to average the log probability per non-masked token.
              chunk_size (int): Size of chunks for processing.
          """
          super().__init__()
@@ -142,7 +193,12 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
          self.compute_nll_loss = compute_nll_loss
          self.compiled = compiled
          self.use_ref_model = use_ref_model
+         self.average_log_prob = average_log_prob
          self.chunk_size = chunk_size
+         self.loss_type = loss_type
+         supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+         if self.loss_type not in supported_loss_types:
+             raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")

      def forward(
          self,
@@ -167,5 +223,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
              self.compute_nll_loss,
              self.compiled,
              self.use_ref_model,
+             self.average_log_prob,
              self.chunk_size,
+             self.loss_type,
          )
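For reference, here is a constructor-level sketch of the new DPO options. It only exercises argument validation; the forward-call signature is unchanged by this diff and is omitted here. The beta keyword is assumed from the existing constructor, and the unsupported value used to trigger the error is made up.

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

dpo_loss = LigerFusedLinearDPOLoss(
    beta=0.1,                  # assumed existing argument
    use_ref_model=True,
    average_log_prob=False,    # new: average log-probs over non-masked tokens
    loss_type="apo_zero",      # new: "sigmoid", "apo_zero", "apo_down", "sppo_hard" or "nca_pair"
)

# Unsupported values are rejected at construction time by the new check:
try:
    LigerFusedLinearDPOLoss(loss_type="hinge")
except ValueError as err:
    print(err)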
liger_kernel/chunked_loss/functional.py
@@ -1,3 +1,4 @@
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
  liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
  liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
  liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+ liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -1,5 +1,7 @@
  from abc import abstractmethod
  from functools import partial
+ from typing import Tuple
+ from typing import Union

  import torch

@@ -157,8 +159,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
          compute_ce_loss=True,
          temperature=1.0,
          compiled=True,
+         return_soft_hard_loss=False,
          **loss_kwargs,
-     ):
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
          """
          Base class for fused linear layer with distillation loss.
          Only need to compute gradients for student model.
@@ -180,6 +183,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              compute_ce_loss (bool): Whether to compute CE loss.
              temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
              compiled (bool): Whether to use torch compile for chunk accumulation.
+             return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
              loss_kwargs (dict): Other possible arguments that a loss function might need
          """
          CHUNK_SIZE = chunk_size
@@ -187,6 +191,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
          grad_inputs = []
          grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
          loss_acc = torch.zeros((), device=student_input.device)
+         soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+         hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None

          loss_func_to_call = partial(
              LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              )
              grad_weight.add_(chunk_grad_weight)
              loss_acc.add_(chunk_loss)
+             if return_soft_hard_loss:
+                 soft_loss_acc.add_(chunk_soft_loss)
+                 hard_loss_acc.add_(chunk_hard_loss)
              return chunk_grad_input

          if compiled:
@@ -268,10 +277,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              grad_weight,
              grad_bias,
          )
+         if return_soft_hard_loss:
+             return loss_acc, soft_loss_acc, hard_loss_acc
          return loss_acc

      @staticmethod
-     def backward(ctx, grad_output):
+     def backward(ctx, grad_output, *args):
          grad_input, grad_weight, grad_bias = ctx.saved_tensors
          if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
              grad_input = grad_input * grad_output
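The base-class change above threads two optional accumulators through the chunk loop and widens backward to ignore the extra gradient outputs. A standalone sketch of that accumulation pattern follows; the names, toy numbers, and the 0.5/0.5 weighting are chosen here for illustration and are not the library's internals.

import torch

return_soft_hard_loss = True
loss_acc = torch.zeros(())
soft_loss_acc = torch.zeros(()) if return_soft_hard_loss else None
hard_loss_acc = torch.zeros(()) if return_soft_hard_loss else None

# Pretend per-chunk (soft, hard) losses produced by a distillation loss function.
for chunk_soft_loss, chunk_hard_loss in [(torch.tensor(0.2), torch.tensor(0.7)), (torch.tensor(0.1), torch.tensor(0.5))]:
    chunk_loss = 0.5 * chunk_soft_loss + 0.5 * chunk_hard_loss
    loss_acc.add_(chunk_loss)
    if return_soft_hard_loss:
        soft_loss_acc.add_(chunk_soft_loss)
        hard_loss_acc.add_(chunk_hard_loss)

print(loss_acc, soft_loss_acc, hard_loss_acc)  # forward() returns all three when the flag is set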
liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -32,6 +32,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
          epsilon_low=0.2,
          epsilon_high=0.2,
          beta=0.04,
+         loss_type="dapo",
+         max_completion_length=None,
+         importance_sampling_level="token",
          temperature=1.0,
          compiled=True,
          use_ref_model=False,
@@ -57,6 +60,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
              epsilon_low: Lower bound for clipping the importance sampling ratio
              epsilon_high: Upper bound for clipping the importance sampling ratio
              beta: Weight for the KL penalty
+             loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
+             max_completion_length: Maximum completion length required for "dr_grpo"
              temperature: Temperature for the logits
              compiled: Whether to use torch compile
              use_ref_model: Whether to use a reference model
@@ -68,6 +73,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
              )
          if ref_per_token_logps is not None and ref_input is not None:
              raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+         if loss_type == "dr_grpo":
+             assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
          # Initialize accumulators
          loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
          grad_weight = torch.zeros_like(weight)  # [V, H]
@@ -84,6 +91,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
              epsilon_low=epsilon_low,
              epsilon_high=epsilon_high,
              beta=beta,
+             loss_type=loss_type,
+             max_completion_length=max_completion_length,
+             importance_sampling_level=importance_sampling_level,
              temperature=temperature,
              use_ref_model=use_ref_model,
              ppo_loss_fn=cls.ppo_loss_fn,
@@ -234,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):

          return loss_acc, tuple(final_metrics)

+     @staticmethod
+     def _compute_dapo_normalizer(attention_mask):
+         """Global active tokens averaged per process."""
+         normalizer = attention_mask.to(torch.float32).sum()
+         world_size = 1
+         if torch.distributed.is_available() and torch.distributed.is_initialized():
+             import torch.distributed as dist
+
+             normalizer = normalizer.clone()
+             dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+             world_size = dist.get_world_size()
+
+         normalizer = normalizer / world_size
+         return torch.clamp(normalizer, min=1.0)
+
      @staticmethod
      def _compute_chunk_loss(
          input_chunk,
@@ -251,6 +276,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
          epsilon_low=0.2,
          epsilon_high=0.2,
          beta=0.04,
+         loss_type="dapo",
+         max_completion_length=None,
+         importance_sampling_level="token",
          temperature=1.0,
          use_ref_model=False,
          ppo_loss_fn=None,
@@ -280,6 +308,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
              epsilon_low=epsilon_low,
              epsilon_high=epsilon_high,
              beta=beta,
+             loss_type=loss_type,
+             max_completion_length=max_completion_length,
+             importance_sampling_level=importance_sampling_level,
          )

          return chunk_loss, chunk_metrics
@@ -303,6 +334,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
      def backward(ctx, grad_output, *grad_metrics):
          """Backward pass for PPO loss."""
          grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
          if grad_output != 1.0:
              grad_input = grad_input * grad_output
              grad_weight = grad_weight * grad_output
@@ -324,6 +356,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
              None,  # grad_epsilon_low
              None,  # grad_epsilon_high
              None,  # grad_beta
+             None,  # grad_loss_type
+             None,  # grad_max_completion_length
+             None,  # grad_importance_sampling_level
              None,  # grad_temperature
              None,  # grad_compiled
              None,  # grad_use_ref_model
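_compute_dapo_normalizer divides the summed token loss by the average number of active tokens per process, clamped to at least 1. A single-process sketch of the same arithmetic (the mask and loss values are illustrative):

import torch

full_attention_mask = torch.tensor([[1, 1, 1, 0],
                                    [1, 1, 0, 0]])
world_size = 1  # no torch.distributed process group in this sketch
normalizer = full_attention_mask.to(torch.float32).sum() / world_size  # 5 active tokens
normalizer = torch.clamp(normalizer, min=1.0)                          # never divide by zero

per_token_loss = torch.rand(2, 4)
loss = (per_token_loss * full_attention_mask).sum() / normalizer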
liger_kernel/chunked_loss/fused_linear_preference.py
@@ -222,7 +222,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
              (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
              (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
              (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
-             strict=False,
          ):
              input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
              ref_input_chunk = (
liger_kernel/chunked_loss/grpo_loss.py
@@ -1,3 +1,5 @@
+ from typing import Optional
+
  import torch

  from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
@@ -27,6 +29,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
          epsilon_low=0.2,
          epsilon_high=0.2,
          beta=0.04,
+         loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
+         max_completion_length=None,  # Required for dr_grpo
+         importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
          **kwargs,
      ):
          """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -46,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):

          # Compute policy gradient loss with importance sampling ratio
          old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-         coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+         log_ratio = per_token_logps - old_per_token_logps
+
+         if importance_sampling_level == "token":
+             log_importance_weights = log_ratio
+         elif importance_sampling_level == "sequence":
+             log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+             log_importance_weights = log_importance_weights.unsqueeze(-1)
+         else:
+             raise ValueError(
+                 f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                 "and 'sequence'."
+             )
+
+         # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+         # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+         coef_1 = torch.exp(log_importance_weights)
          coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
          per_token_loss1 = coef_1 * advantages.unsqueeze(1)
          per_token_loss2 = coef_2 * advantages.unsqueeze(1)
@@ -61,15 +81,42 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
          # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
          # and TRL GRPO implementation
          # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
-         loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+         if loss_type == "grpo":
+             # Average per-sequence loss
+             loss = (
+                 (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+             ).sum() / full_attention_mask.shape[0]
+         elif loss_type == "bnpo":
+             # Batch Normalized Per-token loss (original implementation)
+             loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+         elif loss_type == "dr_grpo":
+             # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+             if max_completion_length is None:
+                 raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+             loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+         elif loss_type == "dapo":
+             loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+             loss = (per_token_loss * attention_mask).sum() / loss_normalizer
+         else:
+             raise ValueError(f"Unknown loss type: {loss_type}")

          # Calculate metrics
          metrics = []
          if beta != 0.0:
              metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-         is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-             (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
-         )
+
+         # Adjust clipping metric calculation based on importance sampling level
+         if importance_sampling_level == "token":
+             is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                 (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+             )
+         else:  # sequence level
+             # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+             is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                 (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+             )
+             is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
          metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
          return loss, metrics

@@ -91,6 +138,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
          beta=0.04,
          epsilon_low=0.2,
          epsilon_high=0.2,
+         loss_type="dapo",
+         max_completion_length=None,
+         importance_sampling_level="token",
          temperature=1.0,
          compiled=True,
          use_ref_model=True,
@@ -110,6 +160,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
              ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
              ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
              beta (float): Weight for the KL penalty
+             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
+             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
              temperature (float): Temperature for the logits
              compiled (bool): Whether to use torch compile
              use_ref_model (bool): Whether to use a reference model
@@ -134,10 +187,13 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
              beta=beta,
              epsilon_low=epsilon_low,
              epsilon_high=epsilon_high,
+             loss_type=loss_type,
+             max_completion_length=max_completion_length,
              temperature=temperature,
              compiled=compiled,
              use_ref_model=use_ref_model,
              chunk_size=chunk_size,
+             importance_sampling_level=importance_sampling_level,
          )

      @staticmethod
@@ -161,6 +217,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
              None,  # grad_beta
              None,  # grad_epsilon_low
              None,  # grad_epsilon_high
+             None,  # grad_loss_type (string, not differentiable)
+             None,  # grad_max_completion_length (int, not differentiable)
+             None,  # grad_importance_sampling_level (string, not differentiable)
              None,  # grad_temperature
              None,  # grad_compiled
              None,  # grad_use_ref_model
@@ -179,6 +238,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
          chunk_size: int = 1,
          epsilon_low: float = 0.2,
          epsilon_high: float = 0.2,
+         loss_type: str = "dapo",
+         max_completion_length: Optional[int] = None,
+         importance_sampling_level: str = "token",
          temperature: float = 1.0,
      ):
          """
@@ -189,6 +251,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
              chunk_size (int): Size of chunks for processing.
              epsilon_low (float): Lower bound for the importance sampling ratio.
              epsilon_high (float): Upper bound for the importance sampling ratio.
+             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
+             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
              temperature (float): Temperature for the logits.
          """
          super().__init__()
@@ -198,6 +263,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
          self.chunk_size = chunk_size
          self.epsilon_low = epsilon_low
          self.epsilon_high = epsilon_high
+         self.loss_type = loss_type
+         self.max_completion_length = max_completion_length
+         self.importance_sampling_level = importance_sampling_level
          self.temperature = temperature

      def forward(
@@ -229,6 +297,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
              self.beta,
              self.epsilon_low,
              self.epsilon_high,
+             self.loss_type,
+             self.max_completion_length,
+             self.importance_sampling_level,
              self.temperature,
              self.compiled,
              self.use_ref_model,
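Finally, a constructor-level sketch of the new GRPO options. Arguments not shown in the hunks above (for example beta and use_ref_model) are assumed to keep their existing meanings, and the concrete values are illustrative.

from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

grpo_loss = LigerFusedLinearGRPOLoss(
    beta=0.04,
    epsilon_low=0.2,
    epsilon_high=0.2,
    loss_type="dr_grpo",                   # "grpo", "bnpo", "dr_grpo" or "dapo" (default "dapo")
    max_completion_length=512,             # required when loss_type == "dr_grpo"
    importance_sampling_level="sequence",  # GSPO-style sequence-level importance ratios
)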