liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of liger-kernel-nightly has been flagged as potentially problematic.
Files changed (97)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
  12. liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  17. liger_kernel/ops/backends/registry.py +61 -0
  18. liger_kernel/ops/cross_entropy.py +75 -12
  19. liger_kernel/ops/dyt.py +5 -2
  20. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  21. liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
  22. liger_kernel/ops/geglu.py +5 -3
  23. liger_kernel/ops/group_norm.py +2 -1
  24. liger_kernel/ops/grpo_loss.py +3 -1
  25. liger_kernel/ops/layer_norm.py +86 -66
  26. liger_kernel/ops/poly_norm.py +390 -0
  27. liger_kernel/ops/rms_norm.py +131 -49
  28. liger_kernel/ops/tiled_mlp.py +136 -0
  29. liger_kernel/ops/utils.py +14 -0
  30. liger_kernel/transformers/__init__.py +30 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +9 -4
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +48 -25
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +57 -2
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/falcon_h1.py +19 -5
  48. liger_kernel/transformers/model/gemma.py +17 -6
  49. liger_kernel/transformers/model/gemma2.py +14 -5
  50. liger_kernel/transformers/model/gemma3.py +26 -12
  51. liger_kernel/transformers/model/glm4.py +16 -4
  52. liger_kernel/transformers/model/glm4v.py +16 -4
  53. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  54. liger_kernel/transformers/model/gpt_oss.py +211 -0
  55. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  56. liger_kernel/transformers/model/internvl.py +12 -5
  57. liger_kernel/transformers/model/llama.py +14 -5
  58. liger_kernel/transformers/model/llama4.py +16 -4
  59. liger_kernel/transformers/model/llava.py +12 -4
  60. liger_kernel/transformers/model/loss_utils.py +31 -3
  61. liger_kernel/transformers/model/mistral.py +15 -6
  62. liger_kernel/transformers/model/mixtral.py +16 -7
  63. liger_kernel/transformers/model/mllama.py +12 -4
  64. liger_kernel/transformers/model/olmo2.py +16 -4
  65. liger_kernel/transformers/model/olmo3.py +142 -0
  66. liger_kernel/transformers/model/output_classes.py +147 -0
  67. liger_kernel/transformers/model/paligemma.py +23 -5
  68. liger_kernel/transformers/model/phi3.py +14 -7
  69. liger_kernel/transformers/model/qwen2.py +16 -3
  70. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  71. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  72. liger_kernel/transformers/model/qwen3.py +20 -5
  73. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  74. liger_kernel/transformers/model/qwen3_next.py +146 -0
  75. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  76. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  77. liger_kernel/transformers/model/smollm3.py +15 -6
  78. liger_kernel/transformers/model/smolvlm.py +158 -0
  79. liger_kernel/transformers/monkey_patch.py +702 -48
  80. liger_kernel/transformers/multi_token_attention.py +1 -1
  81. liger_kernel/transformers/poly_norm.py +42 -0
  82. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  83. liger_kernel/transformers/rms_norm.py +15 -3
  84. liger_kernel/transformers/rope.py +45 -1
  85. liger_kernel/transformers/softmax.py +1 -1
  86. liger_kernel/transformers/sparsemax.py +1 -1
  87. liger_kernel/transformers/swiglu.py +18 -1
  88. liger_kernel/transformers/tiled_mlp.py +133 -0
  89. liger_kernel/transformers/tvd.py +1 -1
  90. liger_kernel/utils.py +52 -0
  91. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
  92. liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
  93. liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
  94. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
  95. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
  96. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
  97. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/cosine_similarity_loss.py
@@ -1,3 +1,6 @@
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
@@ -6,7 +9,13 @@ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinear
 
 class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
     @staticmethod
-    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+        target=None,
+        ignore_index=None,
+        beta=1.0,
+    ):
         """
         Compute Cosine loss (Cosine Similarity Loss).
         Args:
@@ -41,7 +50,8 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
-    ):
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return super().forward(
             cls=cls,
             ctx=ctx,
@@ -59,11 +69,12 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
 
         return (
             *grads,
@@ -75,6 +86,7 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )
 
 
@@ -88,6 +100,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -98,6 +111,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
 
     def forward(
         self,
@@ -108,7 +122,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        return LigerFusedLinearCosineSimilarityFunction.apply(
            student_input,
            student_weight,
@@ -124,4 +138,5 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
            self.temperature,
            self.compiled,
            self.chunk_size,
+           self.return_soft_hard_loss,
        )
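For readers tracking the new return_soft_hard_loss flag, here is a minimal usage sketch based only on the signatures visible in this diff; the positional argument order of forward and the remaining constructor defaults are assumptions, not confirmed by the hunks above.

# Hypothetical usage sketch of return_soft_hard_loss (argument order and other defaults assumed).
import torch

from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss

BT, H, V = 8, 64, 128  # (batch*seq_len, hidden, vocab), arbitrary sizes for the sketch
loss_fn = LigerFusedLinearCosineSimilarityLoss(return_soft_hard_loss=True)

student_input = torch.randn(BT, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(BT, H)
teacher_weight = torch.randn(V, H)
true_labels = torch.randint(0, V, (BT,))

# With the flag enabled, forward is expected to return a 3-tuple instead of a single scalar.
combined, soft, hard = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
combined.backward()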
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union
 
 import torch
 
@@ -11,6 +13,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
     def distillation_loss_fn(
         student_logits,
         teacher_logits,
+        target=None,
+        ignore_index=None,
     ):
         """
         Compute distillation loss.
@@ -130,10 +134,15 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
            )
            student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
 
-        hard_loss /= full_target.shape[0]
+        num_valid_tokens = (full_target != ignore_index).sum()
+        num_valid_tokens = num_valid_tokens.clamp_min(1)  # to avoid division by zero
 
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
-        soft_loss /= full_target.shape[0]
+        hard_loss /= num_valid_tokens
+
+        soft_loss = distillation_loss_fn(
+            student_logits_chunk, teacher_logits_chunk, target=target_chunk, ignore_index=ignore_index, **loss_kwargs
+        )
+        soft_loss /= num_valid_tokens
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
         return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
@@ -157,8 +166,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,6 +190,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
            compute_ce_loss (bool): Whether to compute CE loss.
            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
            compiled (bool): Whether to use torch compile for chunk accumulation.
+           return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
            loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
@@ -187,6 +198,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
 
         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +260,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
            )
            grad_weight.add_(chunk_grad_weight)
            loss_acc.add_(chunk_loss)
+           if return_soft_hard_loss:
+               soft_loss_acc.add_(chunk_soft_loss)
+               hard_loss_acc.add_(chunk_hard_loss)
            return chunk_grad_input
 
        if compiled:
@@ -268,10 +284,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
            grad_weight,
            grad_bias,
        )
+       if return_soft_hard_loss:
+           return loss_acc, soft_loss_acc, hard_loss_acc
        return loss_acc
 
    @staticmethod
-   def backward(ctx, grad_output):
+   def backward(ctx, grad_output, *args):
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
            grad_input = grad_input * grad_output
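The normalization change in the hunk above replaces division by full_target.shape[0] with division by the count of non-ignored tokens. A standalone sketch of that normalizer in plain PyTorch (names here are illustrative, not taken from the package):

import torch

ignore_index = -100
full_target = torch.tensor([3, 7, ignore_index, 2, ignore_index])

num_valid_tokens = (full_target != ignore_index).sum().clamp_min(1)  # 3 valid tokens, never zero
per_token_hard_loss = torch.tensor([0.9, 1.1, 0.0, 0.4, 0.0])        # ignored positions contribute 0

hard_loss = per_token_hard_loss.sum() / num_valid_tokens  # 2.4 / 3 = 0.8, a mean over valid tokens only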
liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -32,7 +32,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="bnpo",
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -60,7 +60,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
            epsilon_low: Lower bound for clipping the importance sampling ratio
            epsilon_high: Upper bound for clipping the importance sampling ratio
            beta: Weight for the KL penalty
-           loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+           loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
            max_completion_length: Maximum completion length required for "dr_grpo"
            temperature: Temperature for the logits
            compiled: Whether to use torch compile
@@ -244,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 
         return loss_acc, tuple(final_metrics)
 
+    @staticmethod
+    def _compute_dapo_normalizer(attention_mask):
+        """Global active tokens averaged per process."""
+        normalizer = attention_mask.to(torch.float32).sum()
+        world_size = 1
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            import torch.distributed as dist
+
+            normalizer = normalizer.clone()
+            dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+            world_size = dist.get_world_size()
+
+        normalizer = normalizer / world_size
+        return torch.clamp(normalizer, min=1.0)
+
     @staticmethod
     def _compute_chunk_loss(
         input_chunk,
@@ -261,7 +276,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="bnpo",
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -341,10 +356,11 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
            None,  # grad_epsilon_low
            None,  # grad_epsilon_high
            None,  # grad_beta
+           None,  # grad_loss_type
+           None,  # grad_max_completion_length
+           None,  # grad_importance_sampling_level
            None,  # grad_temperature
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_chunk_size
-           None,  # grad_loss_type
-           None,  # grad_max_completion_length
        )
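The new _compute_dapo_normalizer averages the global count of active completion tokens across processes. A single-process sketch of the same arithmetic (no torch.distributed initialized, so world_size is effectively 1; names are illustrative):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])                # 5 active completion tokens

normalizer = attention_mask.to(torch.float32).sum()          # tensor(5.)
normalizer = torch.clamp(normalizer / 1, min=1.0)            # divided by world_size (1 here), floored at 1.0

per_token_loss = torch.randn(2, 4)
dapo_loss = (per_token_loss * attention_mask).sum() / normalizer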
liger_kernel/chunked_loss/grpo_loss.py
@@ -29,7 +29,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
+        loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
         max_completion_length=None,  # Required for dr_grpo
         importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
@@ -94,6 +94,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
            if max_completion_length is None:
                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+       elif loss_type == "dapo":
+           loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+           loss = (per_token_loss * attention_mask).sum() / loss_normalizer
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")
 
@@ -135,7 +138,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
-        loss_type="bnpo",
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -157,7 +160,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
            beta (float): Weight for the KL penalty
-           loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+           loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
            temperature (float): Temperature for the logits
@@ -235,7 +238,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         chunk_size: int = 1,
         epsilon_low: float = 0.2,
         epsilon_high: float = 0.2,
-        loss_type: str = "bnpo",
+        loss_type: str = "dapo",
         max_completion_length: Optional[int] = None,
         importance_sampling_level: str = "token",
         temperature: float = 1.0,
@@ -248,7 +251,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
            chunk_size (int): Size of chunks for processing.
            epsilon_low (float): Lower bound for the importance sampling ratio.
            epsilon_high (float): Upper bound for the importance sampling ratio.
-           loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+           loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
            temperature (float): Temperature for the logits.
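Since the default loss_type for the GRPO path changes from "bnpo" to "dapo" in this release, callers that depended on the old default may want to pin it explicitly. A hedged sketch, assuming the remaining constructor arguments keep their defaults:

from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

grpo_new_default = LigerFusedLinearGRPOLoss()                  # now normalizes with the DAPO global-token count
grpo_old_default = LigerFusedLinearGRPOLoss(loss_type="bnpo")  # restores the previous default behavior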
liger_kernel/chunked_loss/jsd_loss.py
@@ -1,5 +1,8 @@
 import math
 
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
@@ -8,35 +11,50 @@ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinear
 
 class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
     @staticmethod
-    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5, target=None, ignore_index=-100):
         """
         Compute JSD loss (Jensen-Shannon Divergence Loss).
         Args:
            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+           target (torch.Tensor): Target labels for masking. Shape: (chunk_size,).
+           ignore_index (int): Index to ignore in loss computation.
        Returns:
            torch.Tensor: Jensen-Shannon Divergence loss
+       Note:
+           - Uses reduction="none" to preserve per-token losses for masking
+           - KL divergence requires summing over vocab dimension (not mean)
+           - Masking excludes padding/prompt tokens from loss computation
        """
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
 
        if beta == 0:
-           jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+           jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
        elif beta == 1:
-           jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+           jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
        else:
            # Compute probabilities (only required for mean calculation)
            log_mean_probs = torch.logsumexp(
                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
            )
 
-           student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
-           teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+           student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="none", log_target=True)
+           teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="none", log_target=True)
 
            # JSD is the weighted average of the KL divergences
            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
-       return jsd_loss
+
+       # Sum over vocab dimension (KL divergence definition)
+       jsd_loss = jsd_loss.sum(dim=-1)  # (chunk_size,)
+
+       # Apply ignore_index mask
+       if target is not None:
+           mask = target != ignore_index
+           jsd_loss = jsd_loss.masked_fill(~mask, 0.0)
+
+       return jsd_loss.sum()
 
     @classmethod
     def forward(
@@ -56,6 +74,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -72,8 +91,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
            temperature (float): Temperature for softening/sharpening distributions
            compiled (bool): Whether to use torch compile
            chunk_size (int): Size of chunks for processing.
+           return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
        Returns:
-           torch.Tensor: Computed loss
+           torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
        """
        return super().forward(
            cls=cls,
@@ -92,11 +112,12 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
            ignore_index=ignore_index,
            temperature=temperature,
            compiled=compiled,
+           return_soft_hard_loss=return_soft_hard_loss,
        )
 
    @staticmethod
-   def backward(ctx, grad_output):
-       grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+   def backward(ctx, grad_output, *args):
+       grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
 
        return (
            *grads,
@@ -108,6 +129,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
            None,  # temperature
            None,  # compiled
            None,  # chunk_size
+           None,  # return_soft_hard_loss
        )
 
 
@@ -125,6 +147,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Args:
@@ -135,6 +158,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
            compiled (bool): Whether to use torch compile
            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
            chunk_size (int): Size of chunks for processing.
+           return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
        """
        super().__init__()
        assert temperature != 0, "Temperature cannot be 0."
@@ -145,6 +169,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
        self.compiled = compiled
        self.beta = beta
        self.chunk_size = chunk_size
+       self.return_soft_hard_loss = return_soft_hard_loss
 
    def forward(
        self,
@@ -155,7 +180,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
        true_labels: torch.LongTensor,
        student_bias: torch.Tensor = None,
        teacher_bias: torch.Tensor = None,
-   ) -> torch.Tensor:
+   ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Compute the JSD distillation loss.
 
@@ -167,7 +192,9 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
            true_labels (torch.LongTensor): Target labels tensor
 
        Returns:
-           torch.Tensor: Computed loss
+           torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+               If return_soft_hard_loss is False: Computed combined loss
+               If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
        """
        return LigerFusedLinearJSDFunction.apply(
            student_input,
@@ -184,4 +211,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
            self.temperature,
            self.compiled,
            self.chunk_size,
+           self.return_soft_hard_loss,
        )
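The JSD change above switches from reduction="sum" to reduction="none" so that padding/prompt tokens can be masked before the final sum. A self-contained sketch of that masking pattern for the beta == 0 branch (toy shapes, illustrative names):

import torch
import torch.nn.functional as F

ignore_index = -100
student_logits = torch.randn(4, 16)                  # (tokens, vocab)
teacher_logits = torch.randn(4, 16)
target = torch.tensor([5, ignore_index, 2, 9])       # second token is padding

student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

per_elem = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
per_token = per_elem.sum(dim=-1)                     # sum over vocab -> shape (4,)
per_token = per_token.masked_fill(target == ignore_index, 0.0)
loss = per_token.sum()                               # the padded token contributes nothing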
liger_kernel/ops/__init__.py
@@ -0,0 +1,141 @@
+"""
+Liger-Kernel operators with automatic vendor-specific replacement.
+
+This module provides two ways to import operators:
+
+1. Import from this package (recommended for Function classes):
+       from liger_kernel.ops import LigerGELUMulFunction
+
+   This automatically uses vendor-specific implementation if available.
+
+2. Import from submodules (for kernel functions or specific access):
+       from liger_kernel.ops.geglu import geglu_forward, geglu_backward
+
+   This always uses the default implementation (no auto-replacement).
+
+The replacement mechanism:
+1. Default implementations are imported from individual modules (e.g., geglu.py)
+2. On module load, device is detected via infer_device()
+3. If running on a supported vendor device (npu, xpu, etc.), the default
+   implementations are replaced with vendor-specific ones
+4. All subsequent imports from this package get the replaced versions
+
+Note: Direct imports from submodules (e.g., from liger_kernel.ops.geglu import ...)
+are NOT affected by the replacement mechanism.
+"""
+
+# =============================================================================
+# Import default implementations
+# Both Function classes and kernel functions are imported here.
+# All of these can be replaced by vendor-specific implementations.
+# =============================================================================
+
+from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction  # noqa: F401
+from liger_kernel.ops.cross_entropy import cross_entropy_backward  # noqa: F401
+from liger_kernel.ops.cross_entropy import cross_entropy_forward  # noqa: F401
+from liger_kernel.ops.dyt import LigerDyTFunction  # noqa: F401
+from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction  # noqa: F401
+from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction  # noqa: F401
+from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_backward  # noqa: F401
+from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_forward  # noqa: F401
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction  # noqa: F401
+from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_backward  # noqa: F401
+from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward  # noqa: F401
+from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction  # noqa: F401
+from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_backward  # noqa: F401
+from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_forward  # noqa: F401
+from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction  # noqa: F401
+from liger_kernel.ops.geglu import LigerGELUMulFunction  # noqa: F401
+from liger_kernel.ops.geglu import geglu_backward  # noqa: F401
+from liger_kernel.ops.geglu import geglu_forward  # noqa: F401
+from liger_kernel.ops.group_norm import LigerGroupNormFunction  # noqa: F401
+from liger_kernel.ops.group_norm import group_norm_backward  # noqa: F401
+from liger_kernel.ops.group_norm import group_norm_forward  # noqa: F401
+from liger_kernel.ops.grpo_loss import GrpoLossFunction  # noqa: F401
+from liger_kernel.ops.jsd import LigerJSDFunction  # noqa: F401
+from liger_kernel.ops.jsd import jsd_backward  # noqa: F401
+from liger_kernel.ops.jsd import jsd_forward  # noqa: F401
+from liger_kernel.ops.kl_div import LigerKLDivLossFunction  # noqa: F401
+from liger_kernel.ops.layer_norm import LigerLayerNormFunction  # noqa: F401
+from liger_kernel.ops.layer_norm import layer_norm_backward  # noqa: F401
+from liger_kernel.ops.layer_norm import layer_norm_forward  # noqa: F401
+from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction  # noqa: F401
+from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction  # noqa: F401
+from liger_kernel.ops.poly_norm import LigerPolyNormFunction  # noqa: F401
+from liger_kernel.ops.poly_norm import poly_norm_backward  # noqa: F401
+from liger_kernel.ops.poly_norm import poly_norm_forward  # noqa: F401
+from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction  # noqa: F401
+from liger_kernel.ops.rms_norm import LigerRMSNormFunction  # noqa: F401
+from liger_kernel.ops.rms_norm import rms_norm_backward  # noqa: F401
+from liger_kernel.ops.rms_norm import rms_norm_forward  # noqa: F401
+from liger_kernel.ops.rope import LigerRopeFunction  # noqa: F401
+from liger_kernel.ops.rope import rope_backward  # noqa: F401
+from liger_kernel.ops.rope import rope_forward  # noqa: F401
+from liger_kernel.ops.softmax import LigerSoftmaxFunction  # noqa: F401
+from liger_kernel.ops.sparsemax import LigerSparsemaxFunction  # noqa: F401
+from liger_kernel.ops.swiglu import LigerSiLUMulFunction  # noqa: F401
+from liger_kernel.ops.swiglu import swiglu_backward  # noqa: F401
+from liger_kernel.ops.swiglu import swiglu_forward  # noqa: F401
+from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction  # noqa: F401
+from liger_kernel.ops.tiled_mlp import apply_tiled_mlp  # noqa: F401
+from liger_kernel.ops.tvd import LigerTVDLossFunction  # noqa: F401
+
+# NOTE: __all__ is intentionally NOT defined.
+# - Import from this package (liger_kernel.ops) -> subject to vendor replacement
+# - Import from submodules (liger_kernel.ops.geglu) -> always use default implementation
+
+
+# =============================================================================
+# Vendor-specific replacement logic
+# =============================================================================
+
+
+def _replace_with_vendor_ops():
+    """
+    Replace/add vendor-specific operator implementations.
+
+    This function is called automatically on module load. It:
+    1. Detects the current device (cuda, npu, xpu, etc.)
+    2. Looks up the vendor for that device via VENDOR_REGISTRY
+    3. Loads and applies vendor-specific implementations
+
+    Vendor implementations should be placed in:
+        liger_kernel/ops/backends/_<vendor>/ops/
+
+    If the vendor module defines __all__, only those symbols are exported.
+    Otherwise, all public symbols (not starting with _) are auto-discovered.
+
+    Note: Vendor can both override existing ops AND add new vendor-specific ops.
+    """
+    from liger_kernel.ops.backends import get_vendor_for_device
+    from liger_kernel.utils import infer_device
+
+    device = infer_device()
+
+    # Look up vendor info for this device
+    vendor_info = get_vendor_for_device(device)
+    if vendor_info is None:
+        return
+
+    try:
+        import importlib
+
+        vendor_ops = importlib.import_module(vendor_info.module_path)
+
+        # Get names to export: use __all__ if defined, otherwise auto-discover
+        names_to_export = getattr(vendor_ops, "__all__", None)
+
+        if names_to_export is None:
+            # Auto-discover: find all public symbols (classes and functions)
+            names_to_export = [name for name in dir(vendor_ops) if not name.startswith("_")]
+
+        # Replace or add to this module's globals
+        for name in names_to_export:
+            globals()[name] = getattr(vendor_ops, name)
+
+    except ImportError:
+        # Vendor module not available, use default implementations
+        pass
+
+
+_replace_with_vendor_ops()
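The new liger_kernel/ops/__init__.py distinguishes package-level imports (subject to vendor replacement) from submodule imports (always the default implementation). A small sketch of what that distinction looks like to a caller; whether the two names differ depends on the detected device and the registered backends:

from liger_kernel.ops import LigerGELUMulFunction                           # may be a vendor override (e.g. on NPU/XPU)
from liger_kernel.ops.geglu import LigerGELUMulFunction as DefaultGELUMul   # always the default implementation

# On CUDA (no registered vendor override) both names are expected to point at the same class;
# on a vendor backend the package-level symbol may have been swapped out at import time.
print(LigerGELUMulFunction is DefaultGELUMul)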