liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.
Files changed (67)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +18 -5
  6. liger_kernel/ops/cross_entropy.py +65 -11
  7. liger_kernel/ops/dyt.py +5 -2
  8. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  9. liger_kernel/ops/fused_linear_cross_entropy.py +43 -13
  10. liger_kernel/ops/geglu.py +2 -1
  11. liger_kernel/ops/group_norm.py +2 -1
  12. liger_kernel/ops/grpo_loss.py +3 -1
  13. liger_kernel/ops/layer_norm.py +86 -66
  14. liger_kernel/ops/poly_norm.py +390 -0
  15. liger_kernel/ops/rms_norm.py +7 -2
  16. liger_kernel/ops/tiled_mlp.py +136 -0
  17. liger_kernel/ops/utils.py +2 -0
  18. liger_kernel/transformers/__init__.py +27 -0
  19. liger_kernel/transformers/cross_entropy.py +8 -3
  20. liger_kernel/transformers/functional.py +29 -6
  21. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  22. liger_kernel/transformers/grpo_loss.py +56 -1
  23. liger_kernel/transformers/model/falcon_h1.py +19 -5
  24. liger_kernel/transformers/model/gemma.py +17 -6
  25. liger_kernel/transformers/model/gemma2.py +14 -5
  26. liger_kernel/transformers/model/gemma3.py +25 -12
  27. liger_kernel/transformers/model/glm4.py +16 -4
  28. liger_kernel/transformers/model/glm4v.py +16 -4
  29. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  30. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  31. liger_kernel/transformers/model/internvl.py +12 -5
  32. liger_kernel/transformers/model/llama.py +14 -5
  33. liger_kernel/transformers/model/llama4.py +16 -4
  34. liger_kernel/transformers/model/llava.py +12 -4
  35. liger_kernel/transformers/model/loss_utils.py +31 -3
  36. liger_kernel/transformers/model/mistral.py +15 -6
  37. liger_kernel/transformers/model/mixtral.py +16 -7
  38. liger_kernel/transformers/model/mllama.py +12 -4
  39. liger_kernel/transformers/model/olmo2.py +16 -4
  40. liger_kernel/transformers/model/olmo3.py +142 -0
  41. liger_kernel/transformers/model/output_classes.py +147 -0
  42. liger_kernel/transformers/model/paligemma.py +22 -5
  43. liger_kernel/transformers/model/phi3.py +14 -7
  44. liger_kernel/transformers/model/qwen2.py +16 -3
  45. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  46. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  47. liger_kernel/transformers/model/qwen3.py +20 -5
  48. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  49. liger_kernel/transformers/model/qwen3_next.py +146 -0
  50. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  51. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  52. liger_kernel/transformers/model/smollm3.py +15 -6
  53. liger_kernel/transformers/model/smolvlm.py +158 -0
  54. liger_kernel/transformers/monkey_patch.py +594 -19
  55. liger_kernel/transformers/poly_norm.py +42 -0
  56. liger_kernel/transformers/rms_norm.py +7 -0
  57. liger_kernel/transformers/rope.py +43 -0
  58. liger_kernel/transformers/swiglu.py +17 -0
  59. liger_kernel/transformers/tiled_mlp.py +133 -0
  60. liger_kernel/utils.py +25 -0
  61. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +4 -1
  62. liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
  63. liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
  64. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
  65. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
  66. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
  67. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/cosine_similarity_loss.py CHANGED
@@ -1,3 +1,6 @@
+ from typing import Tuple
+ from typing import Union
+
  import torch
  import torch.nn.functional as F

@@ -41,7 +44,8 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
  temperature: float = 1.0,
  compiled: bool = True,
  chunk_size: int = 1024,
- ):
+ return_soft_hard_loss: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  return super().forward(
  cls=cls,
  ctx=ctx,
@@ -59,11 +63,12 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
  ignore_index=ignore_index,
  temperature=temperature,
  compiled=compiled,
+ return_soft_hard_loss=return_soft_hard_loss,
  )

  @staticmethod
- def backward(ctx, grad_output):
- grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+ def backward(ctx, grad_output, *args):
+ grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]

  return (
  *grads,
@@ -75,6 +80,7 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
  None, # temperature
  None, # compiled
  None, # chunk_size
+ None, # return_soft_hard_loss
  )


@@ -88,6 +94,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
  temperature: float = 1.0,
  compiled: bool = True,
  chunk_size: int = 1024,
+ return_soft_hard_loss: bool = False,
  ):
  super().__init__()
  assert temperature != 0, "Temperature cannot be 0."
@@ -98,6 +105,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
  self.compiled = compiled
  self.beta = beta
  self.chunk_size = chunk_size
+ self.return_soft_hard_loss = return_soft_hard_loss

  def forward(
  self,
@@ -108,7 +116,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
  true_labels: torch.LongTensor,
  student_bias: torch.Tensor = None,
  teacher_bias: torch.Tensor = None,
- ) -> torch.Tensor:
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  return LigerFusedLinearCosineSimilarityFunction.apply(
  student_input,
  student_weight,
@@ -124,4 +132,5 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
  self.temperature,
  self.compiled,
  self.chunk_size,
+ self.return_soft_hard_loss,
  )
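Usage note (not part of the diff): with return_soft_hard_loss=True the module returns a 3-tuple instead of a single tensor, so call sites unpack it. A minimal runnable sketch, assuming the usual (student_input, student_weight, teacher_input, teacher_weight, true_labels) positional order implied by the hunks above and default values for the remaining constructor arguments:

    import torch
    from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss

    BT, H, V = 4, 32, 64                    # tokens, hidden size, vocab size (illustrative)
    student_input = torch.randn(BT, H, requires_grad=True)
    student_weight = torch.randn(V, H, requires_grad=True)
    teacher_input = torch.randn(BT, H)
    teacher_weight = torch.randn(V, H)
    true_labels = torch.randint(0, V, (BT,))

    cosine_loss = LigerFusedLinearCosineSimilarityLoss(return_soft_hard_loss=True)
    loss, soft_loss, hard_loss = cosine_loss(
        student_input, student_weight, teacher_input, teacher_weight, true_labels
    )
    loss.backward()   # gradients flow through the combined loss; the extras are for logging

With the default return_soft_hard_loss=False, only the combined loss is returned, as before.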
liger_kernel/chunked_loss/fused_linear_distillation.py CHANGED
@@ -1,5 +1,7 @@
  from abc import abstractmethod
  from functools import partial
+ from typing import Tuple
+ from typing import Union

  import torch

@@ -157,8 +159,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  compute_ce_loss=True,
  temperature=1.0,
  compiled=True,
+ return_soft_hard_loss=False,
  **loss_kwargs,
- ):
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  """
  Base class for fused linear layer with distillation loss.
  Only need to compute gradients for student model.
@@ -180,6 +183,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  compute_ce_loss (bool): Whether to compute CE loss.
  temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
  compiled (bool): Whether to use torch compile for chunk accumulation.
+ return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
  loss_kwargs (dict): Other possible arguments that a loss function might need
  """
  CHUNK_SIZE = chunk_size
@@ -187,6 +191,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  grad_inputs = []
  grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
  loss_acc = torch.zeros((), device=student_input.device)
+ soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+ hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None

  loss_func_to_call = partial(
  LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  )
  grad_weight.add_(chunk_grad_weight)
  loss_acc.add_(chunk_loss)
+ if return_soft_hard_loss:
+ soft_loss_acc.add_(chunk_soft_loss)
+ hard_loss_acc.add_(chunk_hard_loss)
  return chunk_grad_input

  if compiled:
@@ -268,10 +277,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  grad_weight,
  grad_bias,
  )
+ if return_soft_hard_loss:
+ return loss_acc, soft_loss_acc, hard_loss_acc
  return loss_acc

  @staticmethod
- def backward(ctx, grad_output):
+ def backward(ctx, grad_output, *args):
  grad_input, grad_weight, grad_bias = ctx.saved_tensors
  if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
  grad_input = grad_input * grad_output
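Background on the backward signature change (an aside, not from the diff): torch.autograd passes one incoming gradient per tensor returned by forward, so once forward can return (loss, soft_loss, hard_loss) the backward has to absorb the two extra gradients even though they are unused. A minimal self-contained illustration of that PyTorch behavior:

    import torch

    class TwoOutputs(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x * 2, x * 3                 # second output is a logging-style extra

        @staticmethod
        def backward(ctx, grad_main, *extra_grads):  # one gradient arrives per forward output
            return grad_main * 2

    x = torch.ones(3, requires_grad=True)
    main, extra = TwoOutputs.apply(x)
    main.sum().backward()    # extra_grads receives (None,) and is ignored
    print(x.grad)            # tensor([2., 2., 2.])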
liger_kernel/chunked_loss/fused_linear_ppo.py CHANGED
@@ -32,7 +32,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  epsilon_low=0.2,
  epsilon_high=0.2,
  beta=0.04,
- loss_type="bnpo",
+ loss_type="dapo",
  max_completion_length=None,
  importance_sampling_level="token",
  temperature=1.0,
@@ -60,7 +60,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  epsilon_low: Lower bound for clipping the importance sampling ratio
  epsilon_high: Upper bound for clipping the importance sampling ratio
  beta: Weight for the KL penalty
- loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+ loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
  max_completion_length: Maximum completion length required for "dr_grpo"
  temperature: Temperature for the logits
  compiled: Whether to use torch compile
@@ -244,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):

  return loss_acc, tuple(final_metrics)

+ @staticmethod
+ def _compute_dapo_normalizer(attention_mask):
+ """Global active tokens averaged per process."""
+ normalizer = attention_mask.to(torch.float32).sum()
+ world_size = 1
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ import torch.distributed as dist
+
+ normalizer = normalizer.clone()
+ dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+ world_size = dist.get_world_size()
+
+ normalizer = normalizer / world_size
+ return torch.clamp(normalizer, min=1.0)
+
  @staticmethod
  def _compute_chunk_loss(
  input_chunk,
@@ -261,7 +276,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  epsilon_low=0.2,
  epsilon_high=0.2,
  beta=0.04,
- loss_type="bnpo",
+ loss_type="dapo",
  max_completion_length=None,
  importance_sampling_level="token",
  temperature=1.0,
@@ -341,10 +356,11 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  None, # grad_epsilon_low
  None, # grad_epsilon_high
  None, # grad_beta
+ None, # grad_loss_type
+ None, # grad_max_completion_length
+ None, # grad_importance_sampling_level
  None, # grad_temperature
  None, # grad_compiled
  None, # grad_use_ref_model
  None, # grad_chunk_size
- None, # grad_loss_type
- None, # grad_max_completion_length
  )
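For intuition (an aside, not part of the diff): _compute_dapo_normalizer divides the loss by the total number of active completion tokens summed across data-parallel ranks, averaged by the world size and clamped to at least 1. A single-process numeric sketch of the resulting scaling:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])                # 5 active tokens, world_size = 1
    normalizer = attention_mask.float().sum().clamp(min=1.0)     # -> 5.0
    per_token_loss = torch.full((2, 4), 0.5)
    loss = (per_token_loss * attention_mask).sum() / normalizer  # 2.5 / 5.0 = 0.5
    print(loss)

Under torch.distributed the token count is all-reduced first and divided by the world size, so every rank uses the same normalizer.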
liger_kernel/chunked_loss/grpo_loss.py CHANGED
@@ -29,7 +29,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  epsilon_low=0.2,
  epsilon_high=0.2,
  beta=0.04,
- loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
+ loss_type="dapo", # ["grpo", "bnpo", "dr_grpo", "dapo"]
  max_completion_length=None, # Required for dr_grpo
  importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
  **kwargs,
@@ -94,6 +94,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  if max_completion_length is None:
  raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
  loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+ elif loss_type == "dapo":
+ loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+ loss = (per_token_loss * attention_mask).sum() / loss_normalizer
  else:
  raise ValueError(f"Unknown loss type: {loss_type}")

@@ -135,7 +138,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  beta=0.04,
  epsilon_low=0.2,
  epsilon_high=0.2,
- loss_type="bnpo",
+ loss_type="dapo",
  max_completion_length=None,
  importance_sampling_level="token",
  temperature=1.0,
@@ -157,7 +160,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
  ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
  beta (float): Weight for the KL penalty
- loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+ loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
  max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
  importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
  temperature (float): Temperature for the logits
@@ -235,7 +238,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  chunk_size: int = 1,
  epsilon_low: float = 0.2,
  epsilon_high: float = 0.2,
- loss_type: str = "bnpo",
+ loss_type: str = "dapo",
  max_completion_length: Optional[int] = None,
  importance_sampling_level: str = "token",
  temperature: float = 1.0,
@@ -248,7 +251,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  chunk_size (int): Size of chunks for processing.
  epsilon_low (float): Lower bound for the importance sampling ratio.
  epsilon_high (float): Upper bound for the importance sampling ratio.
- loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+ loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
  max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
  importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
  temperature (float): Temperature for the logits.
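Construction sketch (illustrative, not from the diff; only arguments visible in the hunks above are passed and everything else is left at its default):

    from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

    # "dapo" is the new default; "grpo", "bnpo" (the previous default) and
    # "dr_grpo" remain selectable via loss_type.
    grpo_loss = LigerFusedLinearGRPOLoss(
        epsilon_low=0.2,
        epsilon_high=0.2,
        loss_type="dapo",
        importance_sampling_level="token",
    )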
liger_kernel/chunked_loss/jsd_loss.py CHANGED
@@ -1,5 +1,8 @@
  import math

+ from typing import Tuple
+ from typing import Union
+
  import torch
  import torch.nn.functional as F

@@ -56,6 +59,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
  temperature: float = 1.0,
  compiled: bool = True,
  chunk_size: int = 1024,
+ return_soft_hard_loss: bool = False,
  ):
  """
  Fused linear layer with JSD distillation loss.
@@ -72,8 +76,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
  temperature (float): Temperature for softening/sharpening distributions
  compiled (bool): Whether to use torch compile
  chunk_size (int): Size of chunks for processing.
+ return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
  Returns:
- torch.Tensor: Computed loss
+ torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
  """
  return super().forward(
  cls=cls,
@@ -92,11 +97,12 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
  ignore_index=ignore_index,
  temperature=temperature,
  compiled=compiled,
+ return_soft_hard_loss=return_soft_hard_loss,
  )

  @staticmethod
- def backward(ctx, grad_output):
- grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+ def backward(ctx, grad_output, *args):
+ grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]

  return (
  *grads,
@@ -108,6 +114,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
  None, # temperature
  None, # compiled
  None, # chunk_size
+ None, # return_soft_hard_loss
  )


@@ -125,6 +132,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
  temperature: float = 1.0,
  compiled: bool = True,
  chunk_size: int = 1024,
+ return_soft_hard_loss: bool = False,
  ):
  """
  Args:
@@ -135,6 +143,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
  compiled (bool): Whether to use torch compile
  beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
  chunk_size (int): Size of chunks for processing.
+ return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
  """
  super().__init__()
  assert temperature != 0, "Temperature cannot be 0."
@@ -145,6 +154,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
  self.compiled = compiled
  self.beta = beta
  self.chunk_size = chunk_size
+ self.return_soft_hard_loss = return_soft_hard_loss

  def forward(
  self,
@@ -155,7 +165,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
  true_labels: torch.LongTensor,
  student_bias: torch.Tensor = None,
  teacher_bias: torch.Tensor = None,
- ) -> torch.Tensor:
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  """
  Compute the JSD distillation loss.

@@ -167,7 +177,9 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
  true_labels (torch.LongTensor): Target labels tensor

  Returns:
- torch.Tensor: Computed loss
+ torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ If return_soft_hard_loss is False: Computed combined loss
+ If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
  """
  return LigerFusedLinearJSDFunction.apply(
  student_input,
@@ -184,4 +196,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
  self.temperature,
  self.compiled,
  self.chunk_size,
+ self.return_soft_hard_loss,
  )
liger_kernel/ops/cross_entropy.py CHANGED
@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import element_mul_kernel
  from liger_kernel.ops.utils import is_hip
  from liger_kernel.utils import infer_device
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import tanh
@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
  loss_ptr,
  z_loss_ptr,
  loss_stride,
+ token_accuracy_ptr,
+ token_accuracy_stride,
  n_cols,
  n_non_ignore,
  sum_non_ignore_weight,
@@ -42,6 +45,7 @@ def liger_cross_entropy_kernel(
  reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
  softcap,
  RETURN_Z_LOSS: tl.constexpr,
+ RETURN_TOKEN_ACCURACY: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  HAS_WEIGHT: tl.constexpr,
  HAS_SOFTCAPPING: tl.constexpr,
@@ -60,6 +64,8 @@ def liger_cross_entropy_kernel(
  loss_ptr: Pointer to tensor to store the loss.
  z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
  loss_stride (int): The stride of the loss tensor.
+ token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+ token_accuracy_stride (int): The stride of the token accuracy tensor.
  n_cols (int): The number of columns in the input tensor.
  n_non_ignore (float): The number of non-ignored elements in the batch.
  sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -69,7 +75,8 @@ def liger_cross_entropy_kernel(
  lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
  reduction (str): The string for the reduction to apply
  softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
- RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
+ RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+ RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
  BLOCK_SIZE (int): The block size for Triton operations.
  HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -92,11 +99,17 @@ def liger_cross_entropy_kernel(
  for i in range(0, n_cols, BLOCK_SIZE):
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
  tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+ # For ignored tokens, set token accuracy to 0
+ if RETURN_TOKEN_ACCURACY:
+ token_accuracy_ptr += program_id * token_accuracy_stride
+ tl.store(token_accuracy_ptr, 0.0)
  return

  loss_ptr += program_id * loss_stride
  if RETURN_Z_LOSS:
  z_loss_ptr += program_id * loss_stride
+ if RETURN_TOKEN_ACCURACY:
+ token_accuracy_ptr += program_id * token_accuracy_stride

  if HAS_WEIGHT:
  weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -107,6 +120,7 @@ def liger_cross_entropy_kernel(
  # 3. [Online softmax] first pass: find max + sum
  m = float("-inf") # m is the max value. use the notation from the paper
  d = 0.0 # d is the sum. use the notation from the paper
+ argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
  ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
  if HAS_SOFTCAPPING:
  ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -127,6 +141,16 @@ def liger_cross_entropy_kernel(
  if HAS_SOFTCAPPING:
  X_block = softcap * tanh(X_block / softcap)
  block_max = tl.max(X_block)
+
+ # Track argmax for accuracy computation
+ if RETURN_TOKEN_ACCURACY and block_max > m:
+ # Find the index of the maximum value in this block
+ is_max_mask = X_block == block_max
+ # Mask out invalid indices with a value larger than n_cols
+ masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+ # Get the first (smallest) index where max occurs
+ argmax_idx = tl.min(masked_offsets)
+
  if label_smoothing > 0:
  # scale X beforehand to avoid overflow
  if HAS_WEIGHT:
@@ -256,6 +280,10 @@ def liger_cross_entropy_kernel(
  tl.store(loss_ptr, loss)
  if RETURN_Z_LOSS:
  tl.store(z_loss_ptr, z_loss)
+ if RETURN_TOKEN_ACCURACY:
+ # Store 1.0 if prediction is correct, 0.0 otherwise
+ is_correct = 1.0 if argmax_idx == y else 0.0
+ tl.store(token_accuracy_ptr, is_correct)


  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -274,8 +302,12 @@ def cross_entropy_forward(
  reduction,
  softcap,
  return_z_loss,
+ return_token_accuracy=False,
  ):
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+ assert isinstance(return_token_accuracy, bool), (
+ f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+ )

  BT, V = _input.shape
  n_rows = BT
@@ -285,6 +317,9 @@ def cross_entropy_forward(
  # unreduced loss
  loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
  z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+ token_accuracy_1d = (
+ torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+ )

  target_mask = target != ignore_index
  n_non_ignore = target_mask.sum().item()
@@ -321,6 +356,10 @@ def cross_entropy_forward(
  loss_ptr=loss_1d,
  z_loss_ptr=z_loss_1d,
  loss_stride=loss_1d.stride(-1), # always 1
+ token_accuracy_ptr=token_accuracy_1d,
+ token_accuracy_stride=token_accuracy_1d.stride(-1)
+ if return_token_accuracy
+ else 0, # always 1 if accuracy is enabled
  n_cols=V,
  n_non_ignore=n_non_ignore,
  sum_non_ignore_weight=sum_non_ignore_weight,
@@ -331,6 +370,7 @@ def cross_entropy_forward(
  reduction=reduction,
  softcap=softcap,
  RETURN_Z_LOSS=return_z_loss,
+ RETURN_TOKEN_ACCURACY=return_token_accuracy,
  BLOCK_SIZE=BLOCK_SIZE,
  HAS_WEIGHT=True if weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -343,11 +383,14 @@ def cross_entropy_forward(
  if reduction == "none":
  loss = loss_1d
  z_loss = z_loss_1d if return_z_loss else None
+ token_accuracy = token_accuracy_1d if return_token_accuracy else None
  else:
  loss = torch.sum(loss_1d)
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+ # For accuracy, we compute the mean across all non-ignored tokens
+ token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

- return loss, z_loss, _input
+ return loss, z_loss, token_accuracy, _input


  def cross_entropy_backward(_input, grad_output):
@@ -395,6 +438,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  reduction: str = "mean",
  softcap: Optional[float] = None,
  return_z_loss: bool = False,
+ return_token_accuracy: bool = False,
  ):
  """
  The forward pass of the Liger Cross Entropy loss.
@@ -409,12 +453,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
  reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
  softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
- return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+ return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+ return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

  Returns:
- tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
+ tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
  """
- loss, z_loss, _input = cross_entropy_forward(
+ input_requires_grad = _input.requires_grad
+
+ loss, z_loss, token_accuracy, _input = cross_entropy_forward(
  _input,
  target,
  weight,
@@ -424,29 +471,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  reduction,
  softcap,
  return_z_loss,
+ return_token_accuracy,
  )
  # TODO: investigation
  # If we don't detach the _input tensor, the memory will double
  # Not sure why but seems that there will be a time both grad and value exist but in different location
- ctx.save_for_backward(_input.detach())
+ if input_requires_grad:
+ ctx.save_for_backward(_input.detach())
  ctx.return_z_loss = return_z_loss
+ ctx.return_token_accuracy = return_token_accuracy

- return loss, z_loss
+ return loss, z_loss, token_accuracy

  @staticmethod
- def backward(ctx, grad_output, grad_ouput2):
+ def backward(ctx, grad_output, grad_output2, grad_output3):
  """
  The backward pass of the Liger Cross Entropy loss.

  Parameters:
  ctx : The context object with saved tensors.
  grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
- grad_output2 (tenosr): No use.
+ grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+ grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
  Returns:
  tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
  """
  if ctx.return_z_loss:
- del grad_ouput2 # z_loss is only for logging
+ del grad_output2 # z_loss is only for logging
+ if ctx.return_token_accuracy:
+ del grad_output3 # token_accuracy is only for metrics

  (_input,) = ctx.saved_tensors
  _input = cross_entropy_backward(_input, grad_output)
@@ -460,4 +513,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  None,
  None,
  None,
+ None,
  )
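For reference (not part of the diff): per row the kernel stores 1.0 when the argmax logit equals the target and 0.0 otherwise, and under "mean"/"sum" reduction the wrapper averages that over non-ignored tokens. An eager-mode equivalent with illustrative names, useful for sanity-checking the fused path:

    import torch

    def token_accuracy_reference(logits, target, ignore_index=-100):
        # logits: (BT, V), target: (BT,)
        mask = target != ignore_index
        preds = logits.argmax(dim=-1)           # first occurrence of the max, matching the kernel
        correct = (preds == target) & mask
        return correct.sum().float() / mask.sum().clamp(min=1).float()

    logits = torch.randn(8, 128)
    target = torch.randint(0, 128, (8,))
    print(token_accuracy_reference(logits, target))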
liger_kernel/ops/dyt.py CHANGED
@@ -7,8 +7,10 @@ import triton.language as tl
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
  from liger_kernel.ops.utils import infer_device
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import tanh
@@ -125,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
  NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
  elif device == "xpu":
  NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-
+ elif device == "npu":
+ NUM_SMS = get_npu_multi_processor_count()
  da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
  dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
  db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
liger_kernel/ops/fused_add_rms_norm.py CHANGED
@@ -9,8 +9,10 @@ from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
  from liger_kernel.ops.utils import torch_to_triton_dtype
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import rsqrt
@@ -293,6 +295,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
  sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
  elif S.device.type == "xpu":
  sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
+ elif S.device.type == "npu":
+ sm_count = get_npu_multi_processor_count()

  # fp32 for numerical stability especially.
  _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
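The dyt.py and fused_add_rms_norm.py hunks follow the same pattern: skip the libdevice import path when an NPU is detected, and add an NPU branch when choosing the number of parallel units for the backward reduction buffers. A condensed sketch of that dispatch (the wrapper _parallel_units is hypothetical; get_npu_multi_processor_count and is_npu_available are the new helpers in liger_kernel/utils.py, whose internals are not shown in this diff, and the xpu attribute differs per kernel: gpu_subslice_count in dyt.py, gpu_eu_count in fused_add_rms_norm.py):

    import torch
    from liger_kernel.utils import get_npu_multi_processor_count  # added in this release

    def _parallel_units(device: torch.device) -> int:
        # Hypothetical consolidation of the per-kernel branches shown above.
        if device.type == "cuda":
            return torch.cuda.get_device_properties(device).multi_processor_count
        if device.type == "xpu":
            return torch.xpu.get_device_properties(device).gpu_eu_count
        if device.type == "npu":
            return get_npu_multi_processor_count()
        raise ValueError(f"unsupported device type: {device.type}")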