liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.



Files changed (68)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  7. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  8. liger_kernel/chunked_loss/jsd_loss.py +23 -7
  9. liger_kernel/ops/cross_entropy.py +118 -62
  10. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  11. liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
  12. liger_kernel/ops/geglu.py +1 -1
  13. liger_kernel/ops/layer_norm.py +124 -89
  14. liger_kernel/ops/llama4_rope.py +225 -0
  15. liger_kernel/ops/poly_norm.py +386 -0
  16. liger_kernel/ops/rms_norm.py +2 -2
  17. liger_kernel/ops/rope.py +1 -1
  18. liger_kernel/ops/swiglu.py +1 -1
  19. liger_kernel/ops/tiled_mlp.py +136 -0
  20. liger_kernel/transformers/__init__.py +50 -0
  21. liger_kernel/transformers/cross_entropy.py +8 -3
  22. liger_kernel/transformers/experimental/__init__.py +5 -0
  23. liger_kernel/transformers/functional.py +38 -6
  24. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  25. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
  26. liger_kernel/transformers/llama4_rope.py +93 -0
  27. liger_kernel/transformers/model/falcon_h1.py +122 -0
  28. liger_kernel/transformers/model/gemma.py +28 -8
  29. liger_kernel/transformers/model/gemma2.py +31 -8
  30. liger_kernel/transformers/model/gemma3.py +100 -110
  31. liger_kernel/transformers/model/glm4.py +18 -5
  32. liger_kernel/transformers/model/glm4v.py +163 -0
  33. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  34. liger_kernel/transformers/model/internvl.py +157 -0
  35. liger_kernel/transformers/model/llama.py +26 -7
  36. liger_kernel/transformers/model/llama4.py +121 -0
  37. liger_kernel/transformers/model/llava.py +18 -6
  38. liger_kernel/transformers/model/loss_utils.py +34 -3
  39. liger_kernel/transformers/model/mistral.py +17 -10
  40. liger_kernel/transformers/model/mixtral.py +24 -9
  41. liger_kernel/transformers/model/mllama.py +18 -7
  42. liger_kernel/transformers/model/olmo2.py +18 -5
  43. liger_kernel/transformers/model/output_classes.py +147 -0
  44. liger_kernel/transformers/model/paligemma.py +41 -5
  45. liger_kernel/transformers/model/phi3.py +24 -159
  46. liger_kernel/transformers/model/qwen2.py +26 -4
  47. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  48. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  49. liger_kernel/transformers/model/qwen3.py +22 -6
  50. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  51. liger_kernel/transformers/model/qwen3_next.py +146 -0
  52. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  53. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  54. liger_kernel/transformers/model/smollm3.py +199 -0
  55. liger_kernel/transformers/model/smolvlm.py +158 -0
  56. liger_kernel/transformers/monkey_patch.py +1090 -116
  57. liger_kernel/transformers/multi_token_attention.py +1 -1
  58. liger_kernel/transformers/poly_norm.py +42 -0
  59. liger_kernel/transformers/rms_norm.py +7 -0
  60. liger_kernel/transformers/rope.py +43 -0
  61. liger_kernel/transformers/tiled_mlp.py +133 -0
  62. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +26 -24
  63. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  64. liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
  65. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  66. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  67. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +0 -0
  68. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss  # noqa: F401
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss  # noqa: F401
@@ -0,0 +1,136 @@
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+        """
+        Compute Cosine loss (Cosine Similarity Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta: Coefficient beta of generalized Cosine Similarity in the interval [0, 1]. Default: `1.0` (float): .
+        Returns:
+            torch.Tensor: cosine similarity loss
+        """
+        student_norm = F.normalize(student_logits, p=2, dim=-1)
+        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+        loss = beta * (1 - cosine_sim)
+        return loss.sum()
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+            None,  # return_soft_hard_loss
+        )
+
+
+class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+        self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return LigerFusedLinearCosineSimilarityFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            student_bias,
+            teacher_bias,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+            self.chunk_size,
+            self.return_soft_hard_loss,
+        )
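
For orientation, a minimal usage sketch of the new module; everything below (shapes, sizes, random weights, running on CPU) is an illustrative assumption rather than part of the diff:

```python
# Hedged usage sketch for LigerFusedLinearCosineSimilarityLoss; values are illustrative.
import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

B, T, H, V = 2, 16, 64, 128  # batch, sequence length, hidden size, vocab size
student_hidden = torch.randn(B * T, H, requires_grad=True)  # student hidden states
student_lm_head = torch.randn(V, H, requires_grad=True)     # student output projection
teacher_hidden = torch.randn(B * T, H)                      # teacher hidden states
teacher_lm_head = torch.randn(V, H)                         # teacher output projection
labels = torch.randint(0, V, (B * T,))                      # hard targets for the CE term

loss_fn = LigerFusedLinearCosineSimilarityLoss(
    weight_hard_loss=0.5,  # weight of the cross-entropy term against the labels
    weight_soft_loss=0.5,  # weight of the cosine-similarity distillation term
    beta=0.5,              # scales (1 - cosine_sim) inside distillation_loss_fn
)
loss = loss_fn(student_hidden, student_lm_head, teacher_hidden, teacher_lm_head, labels)
loss.backward()
```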
@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         ref_chosen_logps=None,
         ref_rejected_logps=None,
         beta=0.1,
+        loss_type="sigmoid",
     ):
         """
         Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         chosen_rewards = beta * chosen_logratios
         rejected_rewards = beta * rejected_logratios

-        logits_diff = beta * (chosen_logratios - rejected_logratios)
-        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+        if loss_type == "sigmoid":
+            logits_diff = beta * (chosen_logratios - rejected_logratios)
+            loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_zero":
+            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are better than your model's default output
+            losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+            losses_rejected = F.sigmoid(beta * rejected_logratios)
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_down":
+            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are worse than your model's default output.
+            # Decrease chosen likelihood and decrease rejected likelihood more
+            losses_chosen = F.sigmoid(beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "sppo_hard":
+            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+            # set to 1 for the winner and 0 for the loser.
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "nca_pair":
+            losses = (
+                -F.logsigmoid(chosen_rewards)
+                - 0.5 * F.logsigmoid(-chosen_rewards)
+                - 0.5 * F.logsigmoid(-rejected_rewards)
+            )
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        else:
+            raise ValueError(
+                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+            )
+
         return loss, chosen_rewards, rejected_rewards

     @classmethod
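
For reference, writing $r_c$ and $r_r$ for chosen_logratios and rejected_logratios and $\sigma$ for the logistic sigmoid, the per-pair losses implemented above (each summed over the batch and divided by the number of pairs, full_target.shape[0] // 2) are:

$$\text{sigmoid:}\quad -\log\sigma\big(\beta\,(r_c - r_r)\big)$$
$$\text{apo\_zero:}\quad \big(1-\sigma(\beta r_c)\big) + \sigma(\beta r_r)$$
$$\text{apo\_down:}\quad \sigma(\beta r_c) + \big(1-\sigma(\beta\,(r_c - r_r))\big)$$
$$\text{sppo\_hard:}\quad \big(r_c - \tfrac{1}{2\beta}\big)^2 + \big(r_r + \tfrac{1}{2\beta}\big)^2$$
$$\text{nca\_pair:}\quad -\log\sigma(\beta r_c) - \tfrac{1}{2}\log\sigma(-\beta r_c) - \tfrac{1}{2}\log\sigma(-\beta r_r)$$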
@@ -70,6 +113,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         use_ref_model=True,
         average_log_prob=False,
         chunk_size=1,
+        loss_type="sigmoid",
     ):
         """
         Fused linear layer with DPO loss.
@@ -108,12 +152,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_bias=ref_bias,
             average_log_prob=average_log_prob,
             chunk_size=chunk_size,
+            loss_type=loss_type,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None, None


 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -130,6 +175,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         use_ref_model: bool = True,
         average_log_prob: bool = False,
         chunk_size: int = 1,
+        loss_type: str = "sigmoid",
     ):
         """
         Args:
@@ -149,6 +195,10 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.use_ref_model = use_ref_model
         self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
+        self.loss_type = loss_type
+        supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+        if self.loss_type not in supported_loss_types:
+            raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")

     def forward(
         self,
@@ -175,4 +225,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.use_ref_model,
             self.average_log_prob,
             self.chunk_size,
+            self.loss_type,
         )
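
A hedged construction sketch using only keywords that appear in this diff; values are illustrative:

```python
# Hedged sketch: picking one of the newly supported DPO loss variants.
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

dpo_loss = LigerFusedLinearDPOLoss(loss_type="apo_zero")

# Unsupported names are now rejected eagerly, at construction time:
try:
    LigerFusedLinearDPOLoss(loss_type="hinge")
except ValueError as err:
    print(err)  # Unsupported loss_type: hinge. Supported types are: {...}
```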
@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
 liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union

 import torch

@@ -157,8 +159,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,6 +183,7 @@
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
@@ -187,6 +191,8 @@
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None

         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@
             )
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            if return_soft_hard_loss:
+                soft_loss_acc.add_(chunk_soft_loss)
+                hard_loss_acc.add_(chunk_hard_loss)
             return chunk_grad_input

         if compiled:
@@ -268,10 +277,12 @@
             grad_weight,
             grad_bias,
         )
+        if return_soft_hard_loss:
+            return loss_acc, soft_loss_acc, hard_loss_acc
         return loss_acc

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             grad_input = grad_input * grad_output
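
The signature change to backward follows standard torch.autograd.Function semantics: once forward may return (loss, soft_loss, hard_loss), autograd supplies one gradient per output, and the extras can simply be ignored. A minimal, self-contained illustration of that pattern (not liger code):

```python
# Minimal illustration of the (ctx, grad_output, *args) backward pattern; not liger code.
import torch


class MultiOutputLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        main_loss = (x ** 2).sum()
        aux_metric = x.detach().abs().sum()  # extra, metric-style output
        return main_loss, aux_metric

    @staticmethod
    def backward(ctx, grad_output, *args):
        # *args receives the gradients for the auxiliary outputs and is ignored,
        # just like the extra soft/hard loss outputs above.
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


x = torch.randn(4, requires_grad=True)
main_loss, aux_metric = MultiOutputLoss.apply(x)
main_loss.backward()
```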
@@ -34,6 +34,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         beta=0.04,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
@@ -92,6 +93,7 @@
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,
@@ -261,6 +263,7 @@
         beta=0.04,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,
@@ -292,6 +295,7 @@
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
         )

         return chunk_loss, chunk_metrics
@@ -31,6 +31,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
         max_completion_length=None,  # Required for dr_grpo
+        importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -50,7 +51,22 @@

         # Compute policy gradient loss with importance sampling ratio
         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        log_ratio = per_token_logps - old_per_token_logps
+
+        if importance_sampling_level == "token":
+            log_importance_weights = log_ratio
+        elif importance_sampling_level == "sequence":
+            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            raise ValueError(
+                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                "and 'sequence'."
+            )
+
+        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        coef_1 = torch.exp(log_importance_weights)
         coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
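
A standalone restatement of the two importance-sampling levels, useful for seeing the shapes; the tensors below are illustrative, and the kernel performs this inside the chunked loss:

```python
# Sketch of the token- vs sequence-level (GSPO-style) importance weights.
import torch

B, T = 2, 5
per_token_logps = torch.randn(B, T)
old_per_token_logps = torch.randn(B, T)
attention_mask = torch.ones(B, T)

log_ratio = per_token_logps - old_per_token_logps

# "token": one importance weight per token, shape (B, T)
token_coef = torch.exp(log_ratio)

# "sequence": masked mean of the log-ratio per sequence, shape (B, 1)
seq_log_w = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
sequence_coef = torch.exp(seq_log_w).unsqueeze(-1)
```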
@@ -85,9 +101,19 @@
         metrics = []
         if beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
-        )
+
+        # Adjust clipping metric calculation based on importance sampling level
+        if importance_sampling_level == "token":
+            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+            )
+        else:  # sequence level
+            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+            )
+            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics

@@ -111,6 +137,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_high=0.2,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=True,
@@ -132,6 +159,7 @@
             beta (float): Weight for the KL penalty
             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -162,6 +190,7 @@
             compiled=compiled,
             use_ref_model=use_ref_model,
             chunk_size=chunk_size,
+            importance_sampling_level=importance_sampling_level,
         )

     @staticmethod
@@ -187,6 +216,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None,  # grad_epsilon_high
             None,  # grad_loss_type (string, not differentiable)
             None,  # grad_max_completion_length (int, not differentiable)
+            None,  # grad_importance_sampling_level (string, not differentiable)
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
@@ -207,6 +237,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         epsilon_high: float = 0.2,
         loss_type: str = "bnpo",
         max_completion_length: Optional[int] = None,
+        importance_sampling_level: str = "token",
         temperature: float = 1.0,
     ):
         """
@@ -219,6 +250,7 @@
             epsilon_high (float): Upper bound for the importance sampling ratio.
             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
         """
         super().__init__()
@@ -230,6 +262,7 @@
         self.epsilon_high = epsilon_high
         self.loss_type = loss_type
         self.max_completion_length = max_completion_length
+        self.importance_sampling_level = importance_sampling_level
         self.temperature = temperature

     def forward(
@@ -263,6 +296,7 @@
             self.epsilon_high,
             self.loss_type,
             self.max_completion_length,
+            self.importance_sampling_level,
             self.temperature,
             self.compiled,
             self.use_ref_model,
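
A hedged construction sketch; only keywords visible in this diff are used, and the values are illustrative. Sequence-level weighting corresponds to the GSPO-style behavior mentioned in the code comments.

```python
# Hedged sketch: opting into sequence-level (GSPO-style) importance sampling.
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

grpo_loss = LigerFusedLinearGRPOLoss(
    loss_type="bnpo",
    importance_sampling_level="sequence",  # default stays "token"
)
```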
@@ -1,3 +1,8 @@
+import math
+
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F

@@ -25,8 +30,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
         else:
             # Compute probabilities (only required for mean calculation)
-            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
-            log_mean_probs = mean_probs.log()
+            log_mean_probs = torch.logsumexp(
+                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+            )

             student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
             teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
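
The rewrite relies on the identity below (valid for $0 < \beta < 1$, so that $\log\beta$ and $\log(1-\beta)$ are finite), which keeps the mixture entirely in log space:

$$\log m \;=\; \log\big((1-\beta)\,p_s + \beta\,p_t\big) \;=\; \operatorname{logsumexp}\big(\log p_s + \log(1-\beta),\ \log p_t + \log\beta\big)$$

Exponentiating the log-probabilities first, as the old code did, underflows to zero for very negative log-probabilities and turns the subsequent log into -inf; logsumexp avoids that.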
@@ -53,6 +59,7 @@
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -69,8 +76,9 @@
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
         """
         return super().forward(
             cls=cls,
@@ -89,11 +97,12 @@
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )

     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]

         return (
             *grads,
@@ -105,6 +114,7 @@
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )


@@ -122,6 +132,7 @@
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Args:
@@ -132,6 +143,7 @@
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -142,6 +154,7 @@
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss

     def forward(
         self,
@@ -152,7 +165,7 @@
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Compute the JSD distillation loss.

@@ -164,7 +177,9 @@
             true_labels (torch.LongTensor): Target labels tensor

         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                If return_soft_hard_loss is False: Computed combined loss
+                If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
         """
         return LigerFusedLinearJSDFunction.apply(
             student_input,
@@ -181,4 +196,5 @@
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
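
Finally, a hedged end-to-end sketch of the new flag on the JSD loss; the shapes and the direct module import are illustrative assumptions:

```python
# Hedged sketch: unpacking the new (loss, soft_loss, hard_loss) return of the chunked JSD loss.
import torch

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss

B, T, H, V = 2, 8, 32, 64
jsd = LigerFusedLinearJSDLoss(return_soft_hard_loss=True)

loss, soft_loss, hard_loss = jsd(
    torch.randn(B * T, H, requires_grad=True),  # student hidden states
    torch.randn(V, H, requires_grad=True),      # student lm_head weight
    torch.randn(B * T, H),                      # teacher hidden states
    torch.randn(V, H),                          # teacher lm_head weight
    torch.randint(0, V, (B * T,)),              # true labels
)
loss.backward()  # only the combined loss is meant to be backpropagated
```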