liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries; it is provided for informational purposes only.

Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss # noqa:F401
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
@@ -0,0 +1,142 @@
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+     @staticmethod
+     def distillation_loss_fn(
+         student_logits,
+         teacher_logits,
+         target=None,
+         ignore_index=None,
+         beta=1.0,
+     ):
+         """
+         Compute Cosine loss (Cosine Similarity Loss).
+         Args:
+             student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+             teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+             beta: Coefficient beta of generalized Cosine Similarity in the interval [0, 1]. Default: `1.0` (float): .
+         Returns:
+             torch.Tensor: cosine similarity loss
+         """
+         student_norm = F.normalize(student_logits, p=2, dim=-1)
+         teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+         cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+         loss = beta * (1 - cosine_sim)
+         return loss.sum()
+
+     @classmethod
+     def forward(
+         cls,
+         ctx,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         student_bias: torch.Tensor,
+         teacher_bias: torch.Tensor,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+         chunk_size: int = 1024,
+         return_soft_hard_loss: bool = False,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+         return super().forward(
+             cls=cls,
+             ctx=ctx,
+             student_input=student_input,
+             student_weight=student_weight,
+             teacher_input=teacher_input,
+             teacher_weight=teacher_weight,
+             target=true_labels,
+             student_bias=student_bias,
+             teacher_bias=teacher_bias,
+             chunk_size=chunk_size,
+             weight_hard_loss=weight_hard_loss,
+             weight_soft_loss=weight_soft_loss,
+             beta=beta,
+             ignore_index=ignore_index,
+             temperature=temperature,
+             compiled=compiled,
+             return_soft_hard_loss=return_soft_hard_loss,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_output, *args):
+         grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+         return (
+             *grads,
+             None,  # teacher_bias
+             None,  # weight_hard_loss
+             None,  # weight_soft_loss
+             None,  # beta
+             None,  # ignore_index
+             None,  # temperature
+             None,  # compiled
+             None,  # chunk_size
+             None,  # return_soft_hard_loss
+         )
+
+
+ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+     def __init__(
+         self,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+         chunk_size: int = 1024,
+         return_soft_hard_loss: bool = False,
+     ):
+         super().__init__()
+         assert temperature != 0, "Temperature cannot be 0."
+         self.weight_hard_loss = weight_hard_loss
+         self.weight_soft_loss = weight_soft_loss
+         self.ignore_index = ignore_index
+         self.temperature = temperature
+         self.compiled = compiled
+         self.beta = beta
+         self.chunk_size = chunk_size
+         self.return_soft_hard_loss = return_soft_hard_loss
+
+     def forward(
+         self,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         student_bias: torch.Tensor = None,
+         teacher_bias: torch.Tensor = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+         return LigerFusedLinearCosineSimilarityFunction.apply(
+             student_input,
+             student_weight,
+             teacher_input,
+             teacher_weight,
+             true_labels,
+             student_bias,
+             teacher_bias,
+             self.weight_hard_loss,
+             self.weight_soft_loss,
+             self.beta,
+             self.ignore_index,
+             self.temperature,
+             self.compiled,
+             self.chunk_size,
+             self.return_soft_hard_loss,
+         )
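The new LigerFusedLinearCosineSimilarityLoss follows the same calling convention as the other chunked distillation losses: student and teacher hidden states plus their lm-head weights go in, and the base class projects them chunk by chunk before the cosine distance between the resulting logits is taken. A minimal usage sketch; the tensor shapes, device, and dtype below are illustrative assumptions, not values taken from the package:

import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

# Hypothetical sizes: batch, sequence length, hidden size, vocab size.
B, T, H, V = 2, 128, 512, 32000
device, dtype = "cuda", torch.bfloat16

loss_fn = LigerFusedLinearCosineSimilarityLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=1.0)

student_input = torch.randn(B * T, H, device=device, dtype=dtype, requires_grad=True)
student_weight = torch.randn(V, H, device=device, dtype=dtype, requires_grad=True)
teacher_input = torch.randn(B * T, H, device=device, dtype=dtype)
teacher_weight = torch.randn(V, H, device=device, dtype=dtype)
true_labels = torch.randint(0, V, (B * T,), device=device)

loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()  # gradients flow only to the student tensors

With return_soft_hard_loss=True the module returns (loss, soft_loss, hard_loss) instead of a single scalar, which is what the Union return annotation above refers to.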
@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
          ref_chosen_logps=None,
          ref_rejected_logps=None,
          beta=0.1,
+         loss_type="sigmoid",
      ):
          """
          Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
          chosen_rewards = beta * chosen_logratios
          rejected_rewards = beta * rejected_logratios

-         logits_diff = beta * (chosen_logratios - rejected_logratios)
-         loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+         if loss_type == "sigmoid":
+             logits_diff = beta * (chosen_logratios - rejected_logratios)
+             loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "apo_zero":
+             # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+             # Use this loss when you believe the chosen outputs are better than your model's default output
+             losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+             losses_rejected = F.sigmoid(beta * rejected_logratios)
+             losses = losses_chosen + losses_rejected
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "apo_down":
+             # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+             # Use this loss when you believe the chosen outputs are worse than your model's default output.
+             # Decrease chosen likelihood and decrease rejected likelihood more
+             losses_chosen = F.sigmoid(beta * chosen_logratios)
+             losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+             losses = losses_chosen + losses_rejected
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "sppo_hard":
+             # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+             # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+             # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+             # set to 1 for the winner and 0 for the loser.
+             a = chosen_logps - ref_chosen_logps
+             b = rejected_logps - ref_rejected_logps
+             losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "nca_pair":
+             losses = (
+                 -F.logsigmoid(chosen_rewards)
+                 - 0.5 * F.logsigmoid(-chosen_rewards)
+                 - 0.5 * F.logsigmoid(-rejected_rewards)
+             )
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         else:
+             raise ValueError(
+                 f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+             )
+
          return loss, chosen_rewards, rejected_rewards

      @classmethod
@@ -68,7 +111,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
          compute_nll_loss=False,
          compiled=True,
          use_ref_model=True,
+         average_log_prob=False,
          chunk_size=1,
+         loss_type="sigmoid",
      ):
          """
          Fused linear layer with DPO loss.
@@ -85,6 +130,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
              compute_nll_loss (bool): Whether to compute the NLL loss
              compiled (bool): Whether to use torch compile
              use_ref_model (bool): Whether to use a reference model
+             average_log_prob (bool): Whether to average the log probability per non-masked token
              chunk_size (int): Size of chunks for processing.
          Returns:
              torch.Tensor: Computed loss
@@ -104,13 +150,15 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
              ref_input=ref_input,
              ref_weight=ref_weight,
              ref_bias=ref_bias,
+             average_log_prob=average_log_prob,
              chunk_size=chunk_size,
+             loss_type=loss_type,
          )

      @staticmethod
      def backward(ctx, *grad_output):
          grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-         return *grads, None, None, None, None, None, None, None, None, None
+         return *grads, None, None, None, None, None, None, None, None, None, None, None


  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -125,7 +173,9 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
          compute_nll_loss: bool = False,
          compiled: bool = True,
          use_ref_model: bool = True,
+         average_log_prob: bool = False,
          chunk_size: int = 1,
+         loss_type: str = "sigmoid",
      ):
          """
          Args:
@@ -134,6 +184,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
              compute_nll_loss (bool): Whether to compute the NLL loss.
              compiled (bool): Whether to use the torch compiled kernel.
              use_ref_model (bool): Whether to use a reference model for the DPO loss.
+             average_log_prob (bool): Whether to average the log probability per non-masked token.
              chunk_size (int): Size of chunks for processing.
          """
          super().__init__()
@@ -142,7 +193,12 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
          self.compute_nll_loss = compute_nll_loss
          self.compiled = compiled
          self.use_ref_model = use_ref_model
+         self.average_log_prob = average_log_prob
          self.chunk_size = chunk_size
+         self.loss_type = loss_type
+         supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+         if self.loss_type not in supported_loss_types:
+             raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")

      def forward(
          self,
@@ -167,5 +223,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
              self.compute_nll_loss,
              self.compiled,
              self.use_ref_model,
+             self.average_log_prob,
              self.chunk_size,
+             self.loss_type,
          )
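For reference, each new branch is a simple elementwise formula over the chosen/rejected log-probability ratios. A standalone plain-PyTorch sketch (independent of the fused kernel, with made-up log-ratios) mirroring the default "sigmoid" and new "apo_zero" branches above; dividing the summed loss by the number of pairs, as the kernel does, is the same as taking the mean here:

import torch
import torch.nn.functional as F

beta = 0.1
# Made-up per-pair log-probability ratios (policy minus reference).
chosen_logratios = torch.tensor([0.8, 0.3, 1.2])
rejected_logratios = torch.tensor([-0.5, 0.1, -0.9])

# "sigmoid" (standard DPO) branch.
sigmoid_loss = -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).mean()

# "apo_zero" branch, Eqn (7) of https://huggingface.co/papers/2408.06266:
# push chosen likelihood up and rejected likelihood down.
losses_chosen = 1 - torch.sigmoid(beta * chosen_logratios)
losses_rejected = torch.sigmoid(beta * rejected_logratios)
apo_zero_loss = (losses_chosen + losses_rejected).mean()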
@@ -1,3 +1,4 @@
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
  liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
  liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
  liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+ liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
@@ -1,5 +1,7 @@
  from abc import abstractmethod
  from functools import partial
+ from typing import Tuple
+ from typing import Union

  import torch

@@ -11,6 +13,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
      def distillation_loss_fn(
          student_logits,
          teacher_logits,
+         target=None,
+         ignore_index=None,
      ):
          """
          Compute distillation loss.
@@ -130,10 +134,15 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              )
              student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)

-         hard_loss /= full_target.shape[0]
+         num_valid_tokens = (full_target != ignore_index).sum()
+         num_valid_tokens = num_valid_tokens.clamp_min(1)  # to avoid division by zero

-         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
-         soft_loss /= full_target.shape[0]
+         hard_loss /= num_valid_tokens
+
+         soft_loss = distillation_loss_fn(
+             student_logits_chunk, teacher_logits_chunk, target=target_chunk, ignore_index=ignore_index, **loss_kwargs
+         )
+         soft_loss /= num_valid_tokens

          loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
          return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
@@ -157,8 +166,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
          compute_ce_loss=True,
          temperature=1.0,
          compiled=True,
+         return_soft_hard_loss=False,
          **loss_kwargs,
-     ):
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
          """
          Base class for fused linear layer with distillation loss.
          Only need to compute gradients for student model.
@@ -180,6 +190,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              compute_ce_loss (bool): Whether to compute CE loss.
              temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
              compiled (bool): Whether to use torch compile for chunk accumulation.
+             return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
              loss_kwargs (dict): Other possible arguments that a loss function might need
          """
          CHUNK_SIZE = chunk_size
@@ -187,6 +198,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
          grad_inputs = []
          grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
          loss_acc = torch.zeros((), device=student_input.device)
+         soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+         hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None

          loss_func_to_call = partial(
              LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +260,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              )
              grad_weight.add_(chunk_grad_weight)
              loss_acc.add_(chunk_loss)
+             if return_soft_hard_loss:
+                 soft_loss_acc.add_(chunk_soft_loss)
+                 hard_loss_acc.add_(chunk_hard_loss)
              return chunk_grad_input

          if compiled:
@@ -268,10 +284,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              grad_weight,
              grad_bias,
          )
+         if return_soft_hard_loss:
+             return loss_acc, soft_loss_acc, hard_loss_acc
          return loss_acc

      @staticmethod
-     def backward(ctx, grad_output):
+     def backward(ctx, grad_output, *args):
          grad_input, grad_weight, grad_bias = ctx.saved_tensors
          if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
              grad_input = grad_input * grad_output
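The normalization change above is worth calling out: both the hard (CE) and soft (distillation) terms are now divided by the number of non-ignored target tokens rather than by full_target.shape[0]. A tiny standalone illustration of that normalizer, with made-up targets:

import torch

ignore_index = -100
# A flattened target chunk where two positions are masked out.
full_target = torch.tensor([5, 9, -100, 3, -100, 7])

num_valid_tokens = (full_target != ignore_index).sum().clamp_min(1)  # tensor(4), never zero
# hard_loss and soft_loss are each divided by this count (4) instead of full_target.shape[0] (6).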
@@ -32,11 +32,15 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
          epsilon_low=0.2,
          epsilon_high=0.2,
          beta=0.04,
+         loss_type="dapo",
+         max_completion_length=None,
+         importance_sampling_level="token",
          temperature=1.0,
          compiled=True,
          use_ref_model=False,
          chunk_size=1,
      ):
+         # TODO: check torch compile matmul
          """Chunked forward pass for PPO loss computation.

          Args:
@@ -56,6 +60,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
              epsilon_low: Lower bound for clipping the importance sampling ratio
              epsilon_high: Upper bound for clipping the importance sampling ratio
              beta: Weight for the KL penalty
+             loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
+             max_completion_length: Maximum completion length required for "dr_grpo"
              temperature: Temperature for the logits
              compiled: Whether to use torch compile
              use_ref_model: Whether to use a reference model
@@ -67,6 +73,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
          )
          if ref_per_token_logps is not None and ref_input is not None:
              raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+         if loss_type == "dr_grpo":
+             assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
          # Initialize accumulators
          loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
          grad_weight = torch.zeros_like(weight)  # [V, H]
@@ -83,6 +91,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
              epsilon_low=epsilon_low,
              epsilon_high=epsilon_high,
              beta=beta,
+             loss_type=loss_type,
+             max_completion_length=max_completion_length,
+             importance_sampling_level=importance_sampling_level,
              temperature=temperature,
              use_ref_model=use_ref_model,
              ppo_loss_fn=cls.ppo_loss_fn,
@@ -233,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):

          return loss_acc, tuple(final_metrics)

+     @staticmethod
+     def _compute_dapo_normalizer(attention_mask):
+         """Global active tokens averaged per process."""
+         normalizer = attention_mask.to(torch.float32).sum()
+         world_size = 1
+         if torch.distributed.is_available() and torch.distributed.is_initialized():
+             import torch.distributed as dist
+
+             normalizer = normalizer.clone()
+             dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+             world_size = dist.get_world_size()
+
+         normalizer = normalizer / world_size
+         return torch.clamp(normalizer, min=1.0)
+
      @staticmethod
      def _compute_chunk_loss(
          input_chunk,
@@ -250,6 +276,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
          epsilon_low=0.2,
          epsilon_high=0.2,
          beta=0.04,
+         loss_type="dapo",
+         max_completion_length=None,
+         importance_sampling_level="token",
          temperature=1.0,
          use_ref_model=False,
          ppo_loss_fn=None,
@@ -279,6 +308,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
              epsilon_low=epsilon_low,
              epsilon_high=epsilon_high,
              beta=beta,
+             loss_type=loss_type,
+             max_completion_length=max_completion_length,
+             importance_sampling_level=importance_sampling_level,
          )

          return chunk_loss, chunk_metrics
@@ -302,6 +334,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
      def backward(ctx, grad_output, *grad_metrics):
          """Backward pass for PPO loss."""
          grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
          if grad_output != 1.0:
              grad_input = grad_input * grad_output
              grad_weight = grad_weight * grad_output
@@ -323,6 +356,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
              None,  # grad_epsilon_low
              None,  # grad_epsilon_high
              None,  # grad_beta
+             None,  # grad_loss_type
+             None,  # grad_max_completion_length
+             None,  # grad_importance_sampling_level
              None,  # grad_temperature
              None,  # grad_compiled
              None,  # grad_use_ref_model
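The new _compute_dapo_normalizer turns the completion mask into a global active-token count averaged over the data-parallel world size, clamped to at least one. A standalone sketch of the single-process path, with a made-up mask:

import torch

# Two completions with 3 and 2 active tokens respectively.
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

normalizer = attention_mask.to(torch.float32).sum()  # tensor(5.)
# With torch.distributed initialized, this sum is all-reduced and divided by the world size.
normalizer = torch.clamp(normalizer, min=1.0)

# The "dapo" branch of the GRPO loss then computes:
#   loss = (per_token_loss * attention_mask).sum() / normalizer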
@@ -222,7 +222,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
              (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
              (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
              (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
-             strict=False,
          ):
              input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
              ref_input_chunk = (
@@ -1,3 +1,5 @@
+ from typing import Optional
+
  import torch

  from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
@@ -27,6 +29,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
          epsilon_low=0.2,
          epsilon_high=0.2,
          beta=0.04,
+         loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
+         max_completion_length=None,  # Required for dr_grpo
+         importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
          **kwargs,
      ):
          """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -46,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):

          # Compute policy gradient loss with importance sampling ratio
          old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-         coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+         log_ratio = per_token_logps - old_per_token_logps
+
+         if importance_sampling_level == "token":
+             log_importance_weights = log_ratio
+         elif importance_sampling_level == "sequence":
+             log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+             log_importance_weights = log_importance_weights.unsqueeze(-1)
+         else:
+             raise ValueError(
+                 f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                 "and 'sequence'."
+             )
+
+         # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+         # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+         coef_1 = torch.exp(log_importance_weights)
          coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
          per_token_loss1 = coef_1 * advantages.unsqueeze(1)
          per_token_loss2 = coef_2 * advantages.unsqueeze(1)
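The new importance_sampling_level="sequence" path (the GSPO-style variant) averages the masked per-token log-ratios over each sequence before exponentiating, so every token in a sequence shares a single importance weight of shape (B, 1) instead of a per-token weight of shape (B, T). A plain-PyTorch sketch of that reduction with made-up values:

import torch

# Made-up per-token log-ratios and completion mask, shape (B, T) = (2, 4).
log_ratio = torch.tensor([[0.2, -0.1, 0.05, 0.0], [0.3, 0.1, 0.0, 0.0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float32)

# "token" level: one importance weight per token, shape (B, T).
token_weights = torch.exp(log_ratio)

# "sequence" level: mask-averaged log-ratio per sequence, kept as shape (B, 1).
seq_log_w = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
sequence_weights = torch.exp(seq_log_w).unsqueeze(-1)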
@@ -61,15 +81,42 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
          # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
          # and TRL GRPO implementation
          # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
-         loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+         if loss_type == "grpo":
+             # Average per-sequence loss
+             loss = (
+                 (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+             ).sum() / full_attention_mask.shape[0]
+         elif loss_type == "bnpo":
+             # Batch Normalized Per-token loss (original implementation)
+             loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+         elif loss_type == "dr_grpo":
+             # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+             if max_completion_length is None:
+                 raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+             loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+         elif loss_type == "dapo":
+             loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+             loss = (per_token_loss * attention_mask).sum() / loss_normalizer
+         else:
+             raise ValueError(f"Unknown loss type: {loss_type}")

          # Calculate metrics
          metrics = []
          if beta != 0.0:
              metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-         is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-             (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
-         )
+
+         # Adjust clipping metric calculation based on importance sampling level
+         if importance_sampling_level == "token":
+             is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                 (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+             )
+         else:  # sequence level
+             # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+             is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                 (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+             )
+             is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
          metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
          return loss, metrics

@@ -91,6 +138,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
          beta=0.04,
          epsilon_low=0.2,
          epsilon_high=0.2,
+         loss_type="dapo",
+         max_completion_length=None,
+         importance_sampling_level="token",
          temperature=1.0,
          compiled=True,
          use_ref_model=True,
@@ -110,6 +160,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
              ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
              ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
              beta (float): Weight for the KL penalty
+             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
+             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
              temperature (float): Temperature for the logits
              compiled (bool): Whether to use torch compile
              use_ref_model (bool): Whether to use a reference model
@@ -134,10 +187,13 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
              beta=beta,
              epsilon_low=epsilon_low,
              epsilon_high=epsilon_high,
+             loss_type=loss_type,
+             max_completion_length=max_completion_length,
              temperature=temperature,
              compiled=compiled,
              use_ref_model=use_ref_model,
              chunk_size=chunk_size,
+             importance_sampling_level=importance_sampling_level,
          )

      @staticmethod
@@ -161,6 +217,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
              None,  # grad_beta
              None,  # grad_epsilon_low
              None,  # grad_epsilon_high
+             None,  # grad_loss_type (string, not differentiable)
+             None,  # grad_max_completion_length (int, not differentiable)
+             None,  # grad_importance_sampling_level (string, not differentiable)
              None,  # grad_temperature
              None,  # grad_compiled
              None,  # grad_use_ref_model
@@ -179,6 +238,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
          chunk_size: int = 1,
          epsilon_low: float = 0.2,
          epsilon_high: float = 0.2,
+         loss_type: str = "dapo",
+         max_completion_length: Optional[int] = None,
+         importance_sampling_level: str = "token",
          temperature: float = 1.0,
      ):
          """
@@ -189,6 +251,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
              chunk_size (int): Size of chunks for processing.
              epsilon_low (float): Lower bound for the importance sampling ratio.
              epsilon_high (float): Upper bound for the importance sampling ratio.
+             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
+             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
              temperature (float): Temperature for the logits.
          """
          super().__init__()
@@ -198,6 +263,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
          self.chunk_size = chunk_size
          self.epsilon_low = epsilon_low
          self.epsilon_high = epsilon_high
+         self.loss_type = loss_type
+         self.max_completion_length = max_completion_length
+         self.importance_sampling_level = importance_sampling_level
          self.temperature = temperature

      def forward(
@@ -229,6 +297,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
              self.beta,
              self.epsilon_low,
              self.epsilon_high,
+             self.loss_type,
+             self.max_completion_length,
+             self.importance_sampling_level,
              self.temperature,
              self.compiled,
              self.use_ref_model,
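Putting the new knobs together, LigerFusedLinearGRPOLoss can now be configured for any of the four normalization schemes, with "dr_grpo" additionally requiring max_completion_length (asserted in the PPO base forward). A construction-only sketch using just the parameters visible in this diff:

from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

# Dr. GRPO: normalize by batch_size * max_completion_length.
dr_grpo_loss = LigerFusedLinearGRPOLoss(
    chunk_size=1,
    epsilon_low=0.2,
    epsilon_high=0.2,
    loss_type="dr_grpo",
    max_completion_length=256,  # required for this loss_type
    importance_sampling_level="token",
    temperature=1.0,
)

# GSPO-style sequence-level importance sampling with the default DAPO normalizer.
gspo_loss = LigerFusedLinearGRPOLoss(
    loss_type="dapo",
    importance_sampling_level="sequence",
)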