liger-kernel 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (61)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  5. liger_kernel/chunked_loss/jsd_loss.py +23 -7
  6. liger_kernel/ops/cross_entropy.py +118 -62
  7. liger_kernel/ops/fused_linear_cross_entropy.py +97 -13
  8. liger_kernel/ops/grpo_loss.py +3 -1
  9. liger_kernel/ops/layer_norm.py +86 -69
  10. liger_kernel/ops/poly_norm.py +386 -0
  11. liger_kernel/ops/tiled_mlp.py +136 -0
  12. liger_kernel/transformers/__init__.py +36 -0
  13. liger_kernel/transformers/cross_entropy.py +8 -3
  14. liger_kernel/transformers/functional.py +31 -6
  15. liger_kernel/transformers/fused_linear_cross_entropy.py +13 -4
  16. liger_kernel/transformers/grpo_loss.py +56 -1
  17. liger_kernel/transformers/model/falcon_h1.py +122 -0
  18. liger_kernel/transformers/model/gemma.py +19 -7
  19. liger_kernel/transformers/model/gemma2.py +22 -7
  20. liger_kernel/transformers/model/gemma3.py +52 -14
  21. liger_kernel/transformers/model/glm4.py +18 -5
  22. liger_kernel/transformers/model/glm4v.py +19 -6
  23. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  24. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  25. liger_kernel/transformers/model/internvl.py +157 -0
  26. liger_kernel/transformers/model/llama.py +16 -6
  27. liger_kernel/transformers/model/llama4.py +18 -5
  28. liger_kernel/transformers/model/llava.py +18 -6
  29. liger_kernel/transformers/model/loss_utils.py +32 -3
  30. liger_kernel/transformers/model/mistral.py +17 -7
  31. liger_kernel/transformers/model/mixtral.py +24 -9
  32. liger_kernel/transformers/model/mllama.py +14 -5
  33. liger_kernel/transformers/model/olmo2.py +18 -5
  34. liger_kernel/transformers/model/olmo3.py +142 -0
  35. liger_kernel/transformers/model/output_classes.py +147 -0
  36. liger_kernel/transformers/model/paligemma.py +41 -5
  37. liger_kernel/transformers/model/phi3.py +16 -8
  38. liger_kernel/transformers/model/qwen2.py +18 -4
  39. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  40. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  41. liger_kernel/transformers/model/qwen3.py +22 -6
  42. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  43. liger_kernel/transformers/model/qwen3_next.py +146 -0
  44. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  45. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  46. liger_kernel/transformers/model/smollm3.py +17 -7
  47. liger_kernel/transformers/model/smolvlm.py +158 -0
  48. liger_kernel/transformers/monkey_patch.py +830 -3
  49. liger_kernel/transformers/multi_token_attention.py +1 -1
  50. liger_kernel/transformers/poly_norm.py +42 -0
  51. liger_kernel/transformers/rms_norm.py +7 -0
  52. liger_kernel/transformers/rope.py +43 -0
  53. liger_kernel/transformers/swiglu.py +17 -0
  54. liger_kernel/transformers/tiled_mlp.py +133 -0
  55. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/METADATA +16 -10
  56. liger_kernel-0.6.4.dist-info/RECORD +118 -0
  57. liger_kernel-0.6.2.dist-info/RECORD +0 -104
  58. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/WHEEL +0 -0
  59. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/licenses/LICENSE +0 -0
  60. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/licenses/NOTICE +0 -0
  61. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/cosine_similarity_loss.py
@@ -1,3 +1,6 @@
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
@@ -41,7 +44,8 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
-    ):
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return super().forward(
             cls=cls,
             ctx=ctx,
@@ -59,11 +63,12 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
 
         return (
             *grads,
@@ -75,6 +80,7 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
             None, # temperature
             None, # compiled
             None, # chunk_size
+            None, # return_soft_hard_loss
         )
 
 
@@ -88,6 +94,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -98,6 +105,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
 
     def forward(
         self,
@@ -108,7 +116,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return LigerFusedLinearCosineSimilarityFunction.apply(
             student_input,
             student_weight,
@@ -124,4 +132,5 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
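With return_soft_hard_loss=True, the module above returns the combined loss together with the soft (distillation) and hard (cross-entropy) components instead of a single scalar. A minimal usage sketch; the tensor sizes are made up, and the call assumes the (student_input, student_weight, teacher_input, teacher_weight, true_labels) argument order used by these distillation modules:

import torch

from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss

# Hypothetical sizes: 8 flattened tokens, hidden size 16, vocab size 32.
BT, H, V = 8, 16, 32
loss_fn = LigerFusedLinearCosineSimilarityLoss(compiled=False, return_soft_hard_loss=True)

student_input = torch.randn(BT, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(BT, H)
teacher_weight = torch.randn(V, H)
true_labels = torch.randint(0, V, (BT,))

# Returns a (loss, soft_loss, hard_loss) tuple instead of one scalar.
loss, soft_loss, hard_loss = loss_fn(
    student_input, student_weight, teacher_input, teacher_weight, true_labels
)
loss.backward()  # backward(ctx, grad_output, *args) absorbs the extra output grads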
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union
 
 import torch
 
@@ -157,8 +159,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,6 +183,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
@@ -187,6 +191,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
 
         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             )
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            if return_soft_hard_loss:
+                soft_loss_acc.add_(chunk_soft_loss)
+                hard_loss_acc.add_(chunk_hard_loss)
             return chunk_grad_input
 
         if compiled:
@@ -268,10 +277,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             grad_weight,
             grad_bias,
         )
+        if return_soft_hard_loss:
+            return loss_acc, soft_loss_acc, hard_loss_acc
         return loss_acc
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             grad_input = grad_input * grad_output
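The signature change to backward(ctx, grad_output, *args) is what lets the base class return the extra soft/hard loss tensors: a torch.autograd.Function receives one incoming gradient per forward output. A standalone toy illustration of that mechanism (not Liger code):

import torch

# When an autograd.Function returns several tensors, PyTorch passes one incoming
# gradient per output, so a single-output backward signature would raise a TypeError.
class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sum(), x.pow(2).sum()

    @staticmethod
    def backward(ctx, grad_main, *extra_grads):  # extra incoming grads are absorbed and ignored here
        (x,) = ctx.saved_tensors
        return grad_main * torch.ones_like(x)  # d(x.sum())/dx

main, aux = TwoOutputs.apply(torch.randn(4, requires_grad=True))
main.backward()  # works even though the function has two outputs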
liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -32,8 +32,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="bnpo",
+        loss_type="dapo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
@@ -59,7 +60,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low: Lower bound for clipping the importance sampling ratio
             epsilon_high: Upper bound for clipping the importance sampling ratio
             beta: Weight for the KL penalty
-            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
             max_completion_length: Maximum completion length required for "dr_grpo"
             temperature: Temperature for the logits
             compiled: Whether to use torch compile
@@ -92,6 +93,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,
@@ -242,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 
         return loss_acc, tuple(final_metrics)
 
+    @staticmethod
+    def _compute_dapo_normalizer(attention_mask):
+        """Global active tokens averaged per process."""
+        normalizer = attention_mask.to(torch.float32).sum()
+        world_size = 1
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            import torch.distributed as dist
+
+            normalizer = normalizer.clone()
+            dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+            world_size = dist.get_world_size()
+
+        normalizer = normalizer / world_size
+        return torch.clamp(normalizer, min=1.0)
+
     @staticmethod
     def _compute_chunk_loss(
         input_chunk,
@@ -259,8 +276,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="bnpo",
+        loss_type="dapo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,
@@ -292,6 +310,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
         )
 
         return chunk_loss, chunk_metrics
@@ -337,10 +356,11 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             None, # grad_epsilon_low
             None, # grad_epsilon_high
             None, # grad_beta
+            None, # grad_loss_type
+            None, # grad_max_completion_length
+            None, # grad_importance_sampling_level
             None, # grad_temperature
             None, # grad_compiled
             None, # grad_use_ref_model
             None, # grad_chunk_size
-            None, # grad_loss_type
-            None, # grad_max_completion_length
         )
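The new _compute_dapo_normalizer divides the summed per-token loss by the number of active completion tokens, averaged across data-parallel ranks when torch.distributed is initialized and clamped to at least 1. A single-process sketch of the arithmetic with made-up tensors (world_size stays 1 here):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])          # 5 active tokens
per_token_loss = torch.full((2, 4), 0.5)

normalizer = torch.clamp(attention_mask.float().sum(), min=1.0)   # -> 5.0
dapo_loss = (per_token_loss * attention_mask).sum() / normalizer  # -> 0.5
print(normalizer.item(), dapo_loss.item())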
liger_kernel/chunked_loss/grpo_loss.py
@@ -29,8 +29,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
+        loss_type="dapo", # ["grpo", "bnpo", "dr_grpo", "dapo"]
         max_completion_length=None, # Required for dr_grpo
+        importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -50,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 
         # Compute policy gradient loss with importance sampling ratio
         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        log_ratio = per_token_logps - old_per_token_logps
+
+        if importance_sampling_level == "token":
+            log_importance_weights = log_ratio
+        elif importance_sampling_level == "sequence":
+            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            raise ValueError(
+                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                "and 'sequence'."
+            )
+
+        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        coef_1 = torch.exp(log_importance_weights)
         coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
@@ -78,6 +94,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             if max_completion_length is None:
                 raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
             loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        elif loss_type == "dapo":
+            loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+            loss = (per_token_loss * attention_mask).sum() / loss_normalizer
         else:
             raise ValueError(f"Unknown loss type: {loss_type}")
 
@@ -85,9 +104,19 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         metrics = []
         if beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
-        )
+
+        # Adjust clipping metric calculation based on importance sampling level
+        if importance_sampling_level == "token":
+            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+            )
+        else: # sequence level
+            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+            )
+            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics
 
@@ -109,8 +138,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
-        loss_type="bnpo",
+        loss_type="dapo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=True,
@@ -130,8 +160,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -162,6 +193,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             compiled=compiled,
             use_ref_model=use_ref_model,
             chunk_size=chunk_size,
+            importance_sampling_level=importance_sampling_level,
         )
 
     @staticmethod
@@ -187,6 +219,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None, # grad_epsilon_high
             None, # grad_loss_type (string, not differentiable)
             None, # grad_max_completion_length (int, not differentiable)
+            None, # grad_importance_sampling_level (string, not differentiable)
             None, # grad_temperature
             None, # grad_compiled
             None, # grad_use_ref_model
@@ -205,8 +238,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         chunk_size: int = 1,
         epsilon_low: float = 0.2,
         epsilon_high: float = 0.2,
-        loss_type: str = "bnpo",
+        loss_type: str = "dapo",
         max_completion_length: Optional[int] = None,
+        importance_sampling_level: str = "token",
         temperature: float = 1.0,
     ):
         """
@@ -217,8 +251,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             chunk_size (int): Size of chunks for processing.
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
         """
         super().__init__()
@@ -230,6 +265,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         self.epsilon_high = epsilon_high
         self.loss_type = loss_type
         self.max_completion_length = max_completion_length
+        self.importance_sampling_level = importance_sampling_level
         self.temperature = temperature
 
     def forward(
@@ -263,6 +299,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.epsilon_high,
             self.loss_type,
             self.max_completion_length,
+            self.importance_sampling_level,
             self.temperature,
             self.compiled,
             self.use_ref_model,
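importance_sampling_level switches between per-token importance weights (the existing GRPO behaviour, shape (B, T)) and a single per-sequence weight obtained from the masked mean of the token log-ratios (the GSPO-style "sequence" level, shape (B, 1)). A toy sketch of the two branches, mirroring the log-ratio handling in ppo_loss_fn with invented values:

import torch

log_ratio = torch.tensor([[0.2, -0.1, 0.3],
                          [0.0,  0.4, 0.0]])
attention_mask = torch.tensor([[1.0, 1.0, 1.0],
                               [1.0, 1.0, 0.0]])

# "token": one weight per token, shape (B, T)
token_weights = torch.exp(log_ratio)

# "sequence": masked mean of the log ratio per sequence, then one weight per sequence, shape (B, 1)
seq_log_w = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
sequence_weights = torch.exp(seq_log_w).unsqueeze(-1)

print(token_weights.shape, sequence_weights.shape)  # torch.Size([2, 3]) torch.Size([2, 1])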
liger_kernel/chunked_loss/jsd_loss.py
@@ -1,3 +1,8 @@
+import math
+
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
@@ -25,8 +30,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
         else:
             # Compute probabilities (only required for mean calculation)
-            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
-            log_mean_probs = mean_probs.log()
+            log_mean_probs = torch.logsumexp(
+                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+            )
 
             student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
             teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
@@ -53,6 +59,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -69,8 +76,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
         """
         return super().forward(
             cls=cls,
@@ -89,11 +97,12 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
 
         return (
             *grads,
@@ -105,6 +114,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             None, # temperature
             None, # compiled
             None, # chunk_size
+            None, # return_soft_hard_loss
         )
 
 
@@ -122,6 +132,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Args:
@@ -132,6 +143,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -142,6 +154,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
 
     def forward(
         self,
@@ -152,7 +165,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Compute the JSD distillation loss.
 
@@ -164,7 +177,9 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             true_labels (torch.LongTensor): Target labels tensor
 
         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                If return_soft_hard_loss is False: Computed combined loss
+                If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
         """
         return LigerFusedLinearJSDFunction.apply(
             student_input,
@@ -181,4 +196,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
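The logsumexp rewrite above computes the same mixture log-probability log((1 - beta) * p_student + beta * p_teacher) as the old exp-then-log code, but without materializing the probabilities, so it stays finite where they would underflow. A small equivalence check under ordinary float32 inputs:

import math
import torch

beta = 0.5
student_log_probs = torch.log_softmax(torch.randn(4, 10), dim=-1)
teacher_log_probs = torch.log_softmax(torch.randn(4, 10), dim=-1)

naive = ((1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()).log()
stable = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)
print(torch.allclose(naive, stable, atol=1e-6))  # True for well-behaved inputs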