liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +44 -13
  9. liger_kernel/ops/__init__.py +141 -0
  10. liger_kernel/ops/backends/README.md +151 -0
  11. liger_kernel/ops/backends/__init__.py +13 -0
  12. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  13. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  15. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  16. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  17. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  18. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  19. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  20. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  21. liger_kernel/ops/backends/registry.py +61 -0
  22. liger_kernel/ops/cross_entropy.py +130 -64
  23. liger_kernel/ops/dyt.py +5 -4
  24. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  25. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  26. liger_kernel/ops/geglu.py +6 -4
  27. liger_kernel/ops/group_norm.py +7 -7
  28. liger_kernel/ops/grpo_loss.py +3 -1
  29. liger_kernel/ops/kl_div.py +8 -11
  30. liger_kernel/ops/layer_norm.py +135 -80
  31. liger_kernel/ops/llama4_rope.py +225 -0
  32. liger_kernel/ops/poly_norm.py +390 -0
  33. liger_kernel/ops/rms_norm.py +148 -71
  34. liger_kernel/ops/rope.py +1 -1
  35. liger_kernel/ops/swiglu.py +1 -1
  36. liger_kernel/ops/tiled_mlp.py +136 -0
  37. liger_kernel/ops/utils.py +14 -0
  38. liger_kernel/transformers/__init__.py +65 -0
  39. liger_kernel/transformers/auto_model.py +21 -0
  40. liger_kernel/transformers/cross_entropy.py +9 -4
  41. liger_kernel/transformers/dyt.py +1 -1
  42. liger_kernel/transformers/experimental/__init__.py +5 -0
  43. liger_kernel/transformers/experimental/embedding.py +1 -1
  44. liger_kernel/transformers/functional.py +56 -24
  45. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  46. liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
  47. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  48. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  49. liger_kernel/transformers/geglu.py +1 -1
  50. liger_kernel/transformers/group_norm.py +1 -1
  51. liger_kernel/transformers/grpo_loss.py +57 -2
  52. liger_kernel/transformers/jsd.py +1 -1
  53. liger_kernel/transformers/kl_div.py +1 -1
  54. liger_kernel/transformers/layer_norm.py +1 -1
  55. liger_kernel/transformers/llama4_rope.py +93 -0
  56. liger_kernel/transformers/model/exaone4.py +136 -0
  57. liger_kernel/transformers/model/falcon_h1.py +122 -0
  58. liger_kernel/transformers/model/gemma.py +28 -8
  59. liger_kernel/transformers/model/gemma2.py +34 -11
  60. liger_kernel/transformers/model/gemma3.py +102 -112
  61. liger_kernel/transformers/model/glm4.py +18 -5
  62. liger_kernel/transformers/model/glm4v.py +163 -0
  63. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  64. liger_kernel/transformers/model/gpt_oss.py +211 -0
  65. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  66. liger_kernel/transformers/model/internvl.py +157 -0
  67. liger_kernel/transformers/model/llama.py +26 -7
  68. liger_kernel/transformers/model/llama4.py +121 -0
  69. liger_kernel/transformers/model/llava.py +18 -6
  70. liger_kernel/transformers/model/loss_utils.py +34 -3
  71. liger_kernel/transformers/model/mistral.py +17 -10
  72. liger_kernel/transformers/model/mixtral.py +24 -9
  73. liger_kernel/transformers/model/mllama.py +18 -7
  74. liger_kernel/transformers/model/olmo2.py +18 -5
  75. liger_kernel/transformers/model/olmo3.py +142 -0
  76. liger_kernel/transformers/model/output_classes.py +147 -0
  77. liger_kernel/transformers/model/paligemma.py +42 -5
  78. liger_kernel/transformers/model/phi3.py +24 -159
  79. liger_kernel/transformers/model/qwen2.py +26 -4
  80. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  81. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  82. liger_kernel/transformers/model/qwen3.py +22 -6
  83. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  84. liger_kernel/transformers/model/qwen3_next.py +146 -0
  85. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  86. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  87. liger_kernel/transformers/model/smollm3.py +199 -0
  88. liger_kernel/transformers/model/smolvlm.py +158 -0
  89. liger_kernel/transformers/monkey_patch.py +1423 -100
  90. liger_kernel/transformers/multi_token_attention.py +2 -2
  91. liger_kernel/transformers/poly_norm.py +42 -0
  92. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  93. liger_kernel/transformers/rms_norm.py +15 -5
  94. liger_kernel/transformers/rope.py +45 -1
  95. liger_kernel/transformers/softmax.py +1 -1
  96. liger_kernel/transformers/sparsemax.py +1 -1
  97. liger_kernel/transformers/swiglu.py +18 -1
  98. liger_kernel/transformers/tiled_mlp.py +125 -0
  99. liger_kernel/transformers/tvd.py +1 -1
  100. liger_kernel/utils.py +52 -0
  101. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
  102. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  103. liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
  104. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0

liger_kernel/chunked_loss/__init__.py

@@ -1,3 +1,4 @@
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss # noqa:F401
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401

liger_kernel/chunked_loss/cosine_similarity_loss.py

@@ -0,0 +1,142 @@
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+     @staticmethod
+     def distillation_loss_fn(
+         student_logits,
+         teacher_logits,
+         target=None,
+         ignore_index=None,
+         beta=1.0,
+     ):
+         """
+         Compute Cosine loss (Cosine Similarity Loss).
+         Args:
+             student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+             teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+             beta: Coefficient beta of generalized Cosine Similarity in the interval [0, 1]. Default: `1.0` (float): .
+         Returns:
+             torch.Tensor: cosine similarity loss
+         """
+         student_norm = F.normalize(student_logits, p=2, dim=-1)
+         teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+         cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+         loss = beta * (1 - cosine_sim)
+         return loss.sum()
+
+     @classmethod
+     def forward(
+         cls,
+         ctx,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         student_bias: torch.Tensor,
+         teacher_bias: torch.Tensor,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+         chunk_size: int = 1024,
+         return_soft_hard_loss: bool = False,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+         return super().forward(
+             cls=cls,
+             ctx=ctx,
+             student_input=student_input,
+             student_weight=student_weight,
+             teacher_input=teacher_input,
+             teacher_weight=teacher_weight,
+             target=true_labels,
+             student_bias=student_bias,
+             teacher_bias=teacher_bias,
+             chunk_size=chunk_size,
+             weight_hard_loss=weight_hard_loss,
+             weight_soft_loss=weight_soft_loss,
+             beta=beta,
+             ignore_index=ignore_index,
+             temperature=temperature,
+             compiled=compiled,
+             return_soft_hard_loss=return_soft_hard_loss,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_output, *args):
+         grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+         return (
+             *grads,
+             None, # teacher_bias
+             None, # weight_hard_loss
+             None, # weight_soft_loss
+             None, # beta
+             None, # ignore_index
+             None, # temperature
+             None, # compiled
+             None, # chunk_size
+             None, # return_soft_hard_loss
+         )
+
+
+ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+     def __init__(
+         self,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+         chunk_size: int = 1024,
+         return_soft_hard_loss: bool = False,
+     ):
+         super().__init__()
+         assert temperature != 0, "Temperature cannot be 0."
+         self.weight_hard_loss = weight_hard_loss
+         self.weight_soft_loss = weight_soft_loss
+         self.ignore_index = ignore_index
+         self.temperature = temperature
+         self.compiled = compiled
+         self.beta = beta
+         self.chunk_size = chunk_size
+         self.return_soft_hard_loss = return_soft_hard_loss
+
+     def forward(
+         self,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         student_bias: torch.Tensor = None,
+         teacher_bias: torch.Tensor = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+         return LigerFusedLinearCosineSimilarityFunction.apply(
+             student_input,
+             student_weight,
+             teacher_input,
+             teacher_weight,
+             true_labels,
+             student_bias,
+             teacher_bias,
+             self.weight_hard_loss,
+             self.weight_soft_loss,
+             self.beta,
+             self.ignore_index,
+             self.temperature,
+             self.compiled,
+             self.chunk_size,
+             self.return_soft_hard_loss,
+         )
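
The new LigerFusedLinearCosineSimilarityLoss module is re-exported from liger_kernel.chunked_loss (see the __init__.py hunk above) and takes student/teacher hidden states plus the corresponding lm_head weights, so the full logits are never materialized at once. A minimal usage sketch; the tensor shapes and sizes below are illustrative assumptions, not taken from this diff:

import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

B_T, H, V = 64, 256, 1000  # flattened batch*seq, hidden size, vocab size (all assumed)
student_input = torch.randn(B_T, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(B_T, H)
teacher_weight = torch.randn(V, H)
true_labels = torch.randint(0, V, (B_T,))

loss_fn = LigerFusedLinearCosineSimilarityLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)
loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()  # gradients are produced for the student tensors only; teacher inputs get None grads
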

liger_kernel/chunked_loss/dpo_loss.py

@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
          ref_chosen_logps=None,
          ref_rejected_logps=None,
          beta=0.1,
+         loss_type="sigmoid",
      ):
          """
          Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
          chosen_rewards = beta * chosen_logratios
          rejected_rewards = beta * rejected_logratios

-         logits_diff = beta * (chosen_logratios - rejected_logratios)
-         loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+         if loss_type == "sigmoid":
+             logits_diff = beta * (chosen_logratios - rejected_logratios)
+             loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "apo_zero":
+             # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+             # Use this loss when you believe the chosen outputs are better than your model's default output
+             losses_chosen = 1 - F.sigmoid(beta * chosen_logratios) # Increase chosen likelihood
+             losses_rejected = F.sigmoid(beta * rejected_logratios)
+             losses = losses_chosen + losses_rejected
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "apo_down":
+             # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+             # Use this loss when you believe the chosen outputs are worse than your model's default output.
+             # Decrease chosen likelihood and decrease rejected likelihood more
+             losses_chosen = F.sigmoid(beta * chosen_logratios)
+             losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+             losses = losses_chosen + losses_rejected
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "sppo_hard":
+             # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+             # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+             # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+             # set to 1 for the winner and 0 for the loser.
+             a = chosen_logps - ref_chosen_logps
+             b = rejected_logps - ref_rejected_logps
+             losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "nca_pair":
+             losses = (
+                 -F.logsigmoid(chosen_rewards)
+                 - 0.5 * F.logsigmoid(-chosen_rewards)
+                 - 0.5 * F.logsigmoid(-rejected_rewards)
+             )
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         else:
+             raise ValueError(
+                 f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+             )
+
          return loss, chosen_rewards, rejected_rewards

      @classmethod
@@ -70,6 +113,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
          use_ref_model=True,
          average_log_prob=False,
          chunk_size=1,
+         loss_type="sigmoid",
      ):
          """
          Fused linear layer with DPO loss.
@@ -108,12 +152,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
              ref_bias=ref_bias,
              average_log_prob=average_log_prob,
              chunk_size=chunk_size,
+             loss_type=loss_type,
          )

      @staticmethod
      def backward(ctx, *grad_output):
          grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-         return *grads, None, None, None, None, None, None, None, None, None, None
+         return *grads, None, None, None, None, None, None, None, None, None, None, None


  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -130,6 +175,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
          use_ref_model: bool = True,
          average_log_prob: bool = False,
          chunk_size: int = 1,
+         loss_type: str = "sigmoid",
      ):
          """
          Args:
@@ -149,6 +195,10 @@
          self.use_ref_model = use_ref_model
          self.average_log_prob = average_log_prob
          self.chunk_size = chunk_size
+         self.loss_type = loss_type
+         supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+         if self.loss_type not in supported_loss_types:
+             raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")

      def forward(
          self,
@@ -175,4 +225,5 @@
              self.use_ref_model,
              self.average_log_prob,
              self.chunk_size,
+             self.loss_type,
          )
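
At the module level, the new loss_type argument is validated in LigerFusedLinearDPOLoss.__init__ (see the hunk above) and forwarded down to the underlying preference loss computation. A hypothetical construction sketch; only keyword arguments visible in this diff are used and their values are illustrative:

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

# "sigmoid" remains the default; "apo_zero", "apo_down", "sppo_hard" and "nca_pair" are the new options.
dpo_loss = LigerFusedLinearDPOLoss(use_ref_model=True, loss_type="apo_zero")

# An unsupported value raises ValueError at construction time rather than inside the fused loss call:
# LigerFusedLinearDPOLoss(loss_type="hinge")  -> ValueError
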

liger_kernel/chunked_loss/functional.py

@@ -1,3 +1,4 @@
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
  liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
  liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
  liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+ liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply

liger_kernel/chunked_loss/fused_linear_distillation.py

@@ -1,5 +1,7 @@
  from abc import abstractmethod
  from functools import partial
+ from typing import Tuple
+ from typing import Union

  import torch

@@ -11,6 +13,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
      def distillation_loss_fn(
          student_logits,
          teacher_logits,
+         target=None,
+         ignore_index=None,
      ):
          """
          Compute distillation loss.
@@ -130,10 +134,15 @@
              )
              student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)

-         hard_loss /= full_target.shape[0]
+         num_valid_tokens = (full_target != ignore_index).sum()
+         num_valid_tokens = num_valid_tokens.clamp_min(1) # to avoid division by zero

-         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
-         soft_loss /= full_target.shape[0]
+         hard_loss /= num_valid_tokens
+
+         soft_loss = distillation_loss_fn(
+             student_logits_chunk, teacher_logits_chunk, target=target_chunk, ignore_index=ignore_index, **loss_kwargs
+         )
+         soft_loss /= num_valid_tokens

          loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
          return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
@@ -157,8 +166,9 @@
          compute_ce_loss=True,
          temperature=1.0,
          compiled=True,
+         return_soft_hard_loss=False,
          **loss_kwargs,
-     ):
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
          """
          Base class for fused linear layer with distillation loss.
          Only need to compute gradients for student model.
@@ -180,6 +190,7 @@
              compute_ce_loss (bool): Whether to compute CE loss.
              temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
              compiled (bool): Whether to use torch compile for chunk accumulation.
+             return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
              loss_kwargs (dict): Other possible arguments that a loss function might need
          """
          CHUNK_SIZE = chunk_size
@@ -187,6 +198,8 @@
          grad_inputs = []
          grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
          loss_acc = torch.zeros((), device=student_input.device)
+         soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+         hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None

          loss_func_to_call = partial(
              LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +260,9 @@
              )
              grad_weight.add_(chunk_grad_weight)
              loss_acc.add_(chunk_loss)
+             if return_soft_hard_loss:
+                 soft_loss_acc.add_(chunk_soft_loss)
+                 hard_loss_acc.add_(chunk_hard_loss)
              return chunk_grad_input

          if compiled:
@@ -268,10 +284,12 @@
              grad_weight,
              grad_bias,
          )
+         if return_soft_hard_loss:
+             return loss_acc, soft_loss_acc, hard_loss_acc
          return loss_acc

      @staticmethod
-     def backward(ctx, grad_output):
+     def backward(ctx, grad_output, *args):
          grad_input, grad_weight, grad_bias = ctx.saved_tensors
          if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
              grad_input = grad_input * grad_output
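
Two behavioural points from the hunks above are worth calling out: the hard and soft losses are now normalized by the number of target positions not equal to ignore_index (clamped to at least 1) instead of by full_target.shape[0], and forward can optionally return the soft and hard components alongside the combined loss. A hypothetical sketch of the new flag through the cosine distillation module, with assumed shapes as in the earlier sketch:

import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

student_input = torch.randn(64, 256, requires_grad=True)
student_weight = torch.randn(1000, 256, requires_grad=True)
teacher_input = torch.randn(64, 256)
teacher_weight = torch.randn(1000, 256)
true_labels = torch.randint(0, 1000, (64,))

loss_fn = LigerFusedLinearCosineSimilarityLoss(return_soft_hard_loss=True)
loss, soft_loss, hard_loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
# loss should equal weight_hard_loss * hard_loss + weight_soft_loss * soft_loss (up to floating-point error),
# with each component divided by the count of valid (non-ignored) labels rather than by the batch size.
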

liger_kernel/chunked_loss/fused_linear_ppo.py

@@ -32,8 +32,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
          epsilon_low=0.2,
          epsilon_high=0.2,
          beta=0.04,
-         loss_type="bnpo",
+         loss_type="dapo",
          max_completion_length=None,
+         importance_sampling_level="token",
          temperature=1.0,
          compiled=True,
          use_ref_model=False,
@@ -59,7 +60,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
              epsilon_low: Lower bound for clipping the importance sampling ratio
              epsilon_high: Upper bound for clipping the importance sampling ratio
              beta: Weight for the KL penalty
-             loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+             loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
              max_completion_length: Maximum completion length required for "dr_grpo"
              temperature: Temperature for the logits
              compiled: Whether to use torch compile
@@ -92,6 +93,7 @@
              beta=beta,
              loss_type=loss_type,
              max_completion_length=max_completion_length,
+             importance_sampling_level=importance_sampling_level,
              temperature=temperature,
              use_ref_model=use_ref_model,
              ppo_loss_fn=cls.ppo_loss_fn,
@@ -242,6 +244,21 @@

          return loss_acc, tuple(final_metrics)

+     @staticmethod
+     def _compute_dapo_normalizer(attention_mask):
+         """Global active tokens averaged per process."""
+         normalizer = attention_mask.to(torch.float32).sum()
+         world_size = 1
+         if torch.distributed.is_available() and torch.distributed.is_initialized():
+             import torch.distributed as dist
+
+             normalizer = normalizer.clone()
+             dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+             world_size = dist.get_world_size()
+
+         normalizer = normalizer / world_size
+         return torch.clamp(normalizer, min=1.0)
+
      @staticmethod
      def _compute_chunk_loss(
          input_chunk,
@@ -259,8 +276,9 @@
          epsilon_low=0.2,
          epsilon_high=0.2,
          beta=0.04,
-         loss_type="bnpo",
+         loss_type="dapo",
          max_completion_length=None,
+         importance_sampling_level="token",
          temperature=1.0,
          use_ref_model=False,
          ppo_loss_fn=None,
@@ -292,6 +310,7 @@
              beta=beta,
              loss_type=loss_type,
              max_completion_length=max_completion_length,
+             importance_sampling_level=importance_sampling_level,
          )

          return chunk_loss, chunk_metrics
@@ -337,10 +356,11 @@
              None, # grad_epsilon_low
              None, # grad_epsilon_high
              None, # grad_beta
+             None, # grad_loss_type
+             None, # grad_max_completion_length
+             None, # grad_importance_sampling_level
              None, # grad_temperature
              None, # grad_compiled
              None, # grad_use_ref_model
              None, # grad_chunk_size
-             None, # grad_loss_type
-             None, # grad_max_completion_length
          )
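
The new _compute_dapo_normalizer helper sums the active tokens in the global attention mask, all-reduces that count across processes when torch.distributed is initialized, divides by the world size, and clamps the result to at least 1. A single-process illustration of the same arithmetic, written standalone rather than by calling the private helper:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
# Without a distributed process group, the normalizer is just the active-token count, clamped to >= 1.
normalizer = torch.clamp(attention_mask.to(torch.float32).sum(), min=1.0)
print(normalizer)  # tensor(5.)
# The "dapo" loss type (see the grpo_loss.py hunks below) divides the masked per-token loss sum by this value.
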

liger_kernel/chunked_loss/grpo_loss.py

@@ -29,8 +29,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
          epsilon_low=0.2,
          epsilon_high=0.2,
          beta=0.04,
-         loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
+         loss_type="dapo", # ["grpo", "bnpo", "dr_grpo", "dapo"]
          max_completion_length=None, # Required for dr_grpo
+         importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
          **kwargs,
      ):
          """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -50,7 +51,22 @@

          # Compute policy gradient loss with importance sampling ratio
          old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-         coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+         log_ratio = per_token_logps - old_per_token_logps
+
+         if importance_sampling_level == "token":
+             log_importance_weights = log_ratio
+         elif importance_sampling_level == "sequence":
+             log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+             log_importance_weights = log_importance_weights.unsqueeze(-1)
+         else:
+             raise ValueError(
+                 f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                 "and 'sequence'."
+             )
+
+         # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+         # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+         coef_1 = torch.exp(log_importance_weights)
          coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
          per_token_loss1 = coef_1 * advantages.unsqueeze(1)
          per_token_loss2 = coef_2 * advantages.unsqueeze(1)
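
With importance_sampling_level="sequence" (the GSPO-style option), the per-token log-ratios are averaged over each sequence's valid tokens and the resulting weight is kept as shape (B, 1) so the later broadcasting against per-token advantages still works; with "token" (the default) the weights stay per-token with shape (B, T). A standalone illustration of the sequence-level branch added above, using small assumed shapes:

import torch

per_token_logps = torch.randn(2, 5)
old_per_token_logps = torch.randn(2, 5)
attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 0, 0, 0]], dtype=torch.float32)

log_ratio = per_token_logps - old_per_token_logps
# Mean log-ratio over each sequence's valid tokens, kept as (B, 1) for broadcasting.
seq_log_w = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
coef_1 = torch.exp(seq_log_w.unsqueeze(-1))  # (B, 1) importance ratio per sequence
print(coef_1.shape)  # torch.Size([2, 1])
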
@@ -78,6 +94,9 @@
              if max_completion_length is None:
                  raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
              loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+         elif loss_type == "dapo":
+             loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+             loss = (per_token_loss * attention_mask).sum() / loss_normalizer
          else:
              raise ValueError(f"Unknown loss type: {loss_type}")

@@ -85,9 +104,19 @@
          metrics = []
          if beta != 0.0:
              metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-         is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-             (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
-         )
+
+         # Adjust clipping metric calculation based on importance sampling level
+         if importance_sampling_level == "token":
+             is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                 (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+             )
+         else: # sequence level
+             # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+             is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                 (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+             )
+             is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
          metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
          return loss, metrics

@@ -109,8 +138,9 @@
          beta=0.04,
          epsilon_low=0.2,
          epsilon_high=0.2,
-         loss_type="bnpo",
+         loss_type="dapo",
          max_completion_length=None,
+         importance_sampling_level="token",
          temperature=1.0,
          compiled=True,
          use_ref_model=True,
@@ -130,8 +160,9 @@
              ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
              ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
              beta (float): Weight for the KL penalty
-             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
              max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
              temperature (float): Temperature for the logits
              compiled (bool): Whether to use torch compile
              use_ref_model (bool): Whether to use a reference model
@@ -162,6 +193,7 @@
              compiled=compiled,
              use_ref_model=use_ref_model,
              chunk_size=chunk_size,
+             importance_sampling_level=importance_sampling_level,
          )

      @staticmethod
@@ -187,6 +219,7 @@
              None, # grad_epsilon_high
              None, # grad_loss_type (string, not differentiable)
              None, # grad_max_completion_length (int, not differentiable)
+             None, # grad_importance_sampling_level (string, not differentiable)
              None, # grad_temperature
              None, # grad_compiled
              None, # grad_use_ref_model
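
At the module level, both new knobs are plumbed through LigerFusedLinearGRPOLoss, whose loss_type default changes from "bnpo" to "dapo". A hypothetical construction sketch; only arguments visible in this diff are shown and the values are illustrative:

from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

grpo_loss = LigerFusedLinearGRPOLoss(
    loss_type="dapo",                      # new default; "grpo", "bnpo" and "dr_grpo" remain available
    importance_sampling_level="sequence",  # "token" (default, per-token ratios) or "sequence" (GSPO-style)
)
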
@@ -205,8 +238,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
          chunk_size: int = 1,
          epsilon_low: float = 0.2,
          epsilon_high: float = 0.2,
-         loss_type: str = "bnpo",
+         loss_type: str = "dapo",
          max_completion_length: Optional[int] = None,
+         importance_sampling_level: str = "token",
          temperature: float = 1.0,
      ):
          """
@@ -217,8 +251,9 @@
              chunk_size (int): Size of chunks for processing.
              epsilon_low (float): Lower bound for the importance sampling ratio.
              epsilon_high (float): Upper bound for the importance sampling ratio.
-             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
              max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
              temperature (float): Temperature for the logits.
          """
          super().__init__()
@@ -230,6 +265,7 @@
          self.epsilon_high = epsilon_high
          self.loss_type = loss_type
          self.max_completion_length = max_completion_length
+         self.importance_sampling_level = importance_sampling_level
          self.temperature = temperature

      def forward(
@@ -263,6 +299,7 @@
              self.epsilon_high,
              self.loss_type,
              self.max_completion_length,
+             self.importance_sampling_level,
              self.temperature,
              self.compiled,
              self.use_ref_model,