liger-kernel-nightly 0.6.2.dev20250919191028__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (67)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +18 -5
  6. liger_kernel/ops/cross_entropy.py +120 -63
  7. liger_kernel/ops/dyt.py +5 -2
  8. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  9. liger_kernel/ops/fused_linear_cross_entropy.py +43 -12
  10. liger_kernel/ops/geglu.py +2 -1
  11. liger_kernel/ops/group_norm.py +2 -1
  12. liger_kernel/ops/grpo_loss.py +3 -1
  13. liger_kernel/ops/layer_norm.py +88 -70
  14. liger_kernel/ops/poly_norm.py +390 -0
  15. liger_kernel/ops/rms_norm.py +7 -2
  16. liger_kernel/ops/tiled_mlp.py +136 -0
  17. liger_kernel/ops/utils.py +2 -0
  18. liger_kernel/transformers/__init__.py +33 -0
  19. liger_kernel/transformers/cross_entropy.py +8 -3
  20. liger_kernel/transformers/functional.py +29 -6
  21. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  22. liger_kernel/transformers/grpo_loss.py +56 -1
  23. liger_kernel/transformers/model/falcon_h1.py +122 -0
  24. liger_kernel/transformers/model/gemma.py +19 -7
  25. liger_kernel/transformers/model/gemma2.py +22 -7
  26. liger_kernel/transformers/model/gemma3.py +52 -14
  27. liger_kernel/transformers/model/glm4.py +18 -5
  28. liger_kernel/transformers/model/glm4v.py +18 -5
  29. liger_kernel/transformers/model/glm4v_moe.py +25 -5
  30. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  31. liger_kernel/transformers/model/internvl.py +157 -0
  32. liger_kernel/transformers/model/llama.py +16 -6
  33. liger_kernel/transformers/model/llama4.py +18 -5
  34. liger_kernel/transformers/model/llava.py +18 -6
  35. liger_kernel/transformers/model/loss_utils.py +31 -3
  36. liger_kernel/transformers/model/mistral.py +17 -7
  37. liger_kernel/transformers/model/mixtral.py +24 -9
  38. liger_kernel/transformers/model/mllama.py +14 -5
  39. liger_kernel/transformers/model/olmo2.py +18 -5
  40. liger_kernel/transformers/model/olmo3.py +142 -0
  41. liger_kernel/transformers/model/output_classes.py +147 -0
  42. liger_kernel/transformers/model/paligemma.py +41 -5
  43. liger_kernel/transformers/model/phi3.py +16 -8
  44. liger_kernel/transformers/model/qwen2.py +18 -4
  45. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  46. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  47. liger_kernel/transformers/model/qwen3.py +22 -6
  48. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  49. liger_kernel/transformers/model/qwen3_next.py +146 -0
  50. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  51. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  52. liger_kernel/transformers/model/smollm3.py +17 -7
  53. liger_kernel/transformers/model/smolvlm.py +158 -0
  54. liger_kernel/transformers/monkey_patch.py +729 -4
  55. liger_kernel/transformers/poly_norm.py +42 -0
  56. liger_kernel/transformers/rms_norm.py +7 -0
  57. liger_kernel/transformers/rope.py +43 -0
  58. liger_kernel/transformers/swiglu.py +17 -0
  59. liger_kernel/transformers/tiled_mlp.py +133 -0
  60. liger_kernel/utils.py +25 -0
  61. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +13 -6
  62. liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
  63. liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/RECORD +0 -105
  64. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
  65. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
  66. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
  67. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py CHANGED
@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import element_mul_kernel
  from liger_kernel.ops.utils import is_hip
  from liger_kernel.utils import infer_device
+ from liger_kernel.utils import is_npu_available
 
- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import tanh
@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
  loss_ptr,
  z_loss_ptr,
  loss_stride,
+ token_accuracy_ptr,
+ token_accuracy_stride,
  n_cols,
  n_non_ignore,
  sum_non_ignore_weight,
@@ -42,9 +45,11 @@ def liger_cross_entropy_kernel(
  reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
  softcap,
  RETURN_Z_LOSS: tl.constexpr,
+ RETURN_TOKEN_ACCURACY: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  HAS_WEIGHT: tl.constexpr,
  HAS_SOFTCAPPING: tl.constexpr,
+ HAS_GRADIENTS: tl.constexpr,
  ):
  """
  This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +64,8 @@ def liger_cross_entropy_kernel(
  loss_ptr: Pointer to tensor to store the loss.
  z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
  loss_stride (int): The stride of the loss tensor.
+ token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+ token_accuracy_stride (int): The stride of the token accuracy tensor.
  n_cols (int): The number of columns in the input tensor.
  n_non_ignore (float): The number of non-ignored elements in the batch.
  sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +75,12 @@ def liger_cross_entropy_kernel(
  lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
  reduction (str): The string for the reduction to apply
  softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
- RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
+ RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+ RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
  BLOCK_SIZE (int): The block size for Triton operations.
  HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+ HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
  """
 
  # https://github.com/triton-lang/triton/issues/1058
@@ -90,11 +99,17 @@ def liger_cross_entropy_kernel(
  for i in range(0, n_cols, BLOCK_SIZE):
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
  tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+ # For ignored tokens, set token accuracy to 0
+ if RETURN_TOKEN_ACCURACY:
+ token_accuracy_ptr += program_id * token_accuracy_stride
+ tl.store(token_accuracy_ptr, 0.0)
  return
 
  loss_ptr += program_id * loss_stride
  if RETURN_Z_LOSS:
  z_loss_ptr += program_id * loss_stride
+ if RETURN_TOKEN_ACCURACY:
+ token_accuracy_ptr += program_id * token_accuracy_stride
 
  if HAS_WEIGHT:
  weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +120,7 @@ def liger_cross_entropy_kernel(
  # 3. [Online softmax] first pass: find max + sum
  m = float("-inf") # m is the max value. use the notation from the paper
  d = 0.0 # d is the sum. use the notation from the paper
+ argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
  ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
  if HAS_SOFTCAPPING:
  ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +141,16 @@ def liger_cross_entropy_kernel(
  if HAS_SOFTCAPPING:
  X_block = softcap * tanh(X_block / softcap)
  block_max = tl.max(X_block)
+
+ # Track argmax for accuracy computation
+ if RETURN_TOKEN_ACCURACY and block_max > m:
+ # Find the index of the maximum value in this block
+ is_max_mask = X_block == block_max
+ # Mask out invalid indices with a value larger than n_cols
+ masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+ # Get the first (smallest) index where max occurs
+ argmax_idx = tl.min(masked_offsets)
+
  if label_smoothing > 0:
  # scale X beforehand to avoid overflow
  if HAS_WEIGHT:
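
The hunk above piggybacks the argmax bookkeeping on the online-softmax max pass: whenever a block raises the running max, the smallest offset attaining that block max is recorded, and ignored tokens are written as 0.0 in the early-return branch. What the RETURN_TOKEN_ACCURACY path ultimately produces can be written in plain PyTorch roughly as follows (an illustrative sketch, not code from this package; the function name is made up):

    import torch

    def token_accuracy_reference(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100):
        # argmax over the vocabulary; the kernel breaks ties toward the smallest index
        # (tl.min over the tied offsets), which torch.argmax also does in practice
        pred = logits.argmax(dim=-1)
        correct = (pred == target).float()
        # ignored tokens are recorded as 0.0, mirroring the early-return branch above
        correct = torch.where(target != ignore_index, correct, torch.zeros_like(correct))
        n_non_ignore = (target != ignore_index).sum().clamp(min=1)
        return correct, correct.sum() / n_non_ignore  # per-token 0/1 values and their mean
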
@@ -155,58 +181,58 @@ def liger_cross_entropy_kernel(
  # For 'sum' reduction, no normalization is applied:
  # dx_y = softmax(x_y) - 1
  # dx_i = softmax(x_i), for i ≠ y
-
- for i in range(0, n_cols, BLOCK_SIZE):
- X_offsets = i + tl.arange(0, BLOCK_SIZE)
- X_block = tl.load(
- X_ptr + X_offsets,
- mask=X_offsets < n_cols,
- other=float("-inf"),
- # Ensure float32 precision for softmax calculation
- ).cast(tl.float32)
- if HAS_SOFTCAPPING:
- intermediate = tanh(X_block / softcap)
- X_block = softcap * intermediate
-
- if not HAS_WEIGHT:
- # softmax(x_i)
- X_block = tl.exp(X_block - m) / d
- # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
- X_block += 2 * lse_square_scale * lse * X_block
- # smoothing term
- X_block += -eps
- # special handle dx_y
- X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
- # reduction scale
- if reduction == "mean":
- X_block = X_block / n_non_ignore
- else:
- weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
- softmax_X = tl.exp(X_block - m) / d
- # derivative of original_loss
- dloss_ori = (1 - label_smoothing) * softmax_X
- # specially handle dx_y
- dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
- dloss_ori = dloss_ori * weight_y
- # derivative of smooth_loss
- dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
- # derivative of z-loss
- dz_loss = 2 * lse_square_scale * lse * softmax_X
- # reduction scale
- if reduction == "mean":
- dloss_ori = dloss_ori / sum_non_ignore_weight
- dloss_smooth = dloss_smooth / sum_non_ignore_weight
- # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
- dz_loss = dz_loss / n_non_ignore
- # derivative of total_loss
- X_block = dloss_ori + dloss_smooth + dz_loss
-
- # chain rule softcapping
- # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
- if HAS_SOFTCAPPING:
- X_block = X_block * (1 - intermediate * intermediate)
-
- tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+ if HAS_GRADIENTS:
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(
+ X_ptr + X_offsets,
+ mask=X_offsets < n_cols,
+ other=float("-inf"),
+ # Ensure float32 precision for softmax calculation
+ ).cast(tl.float32)
+ if HAS_SOFTCAPPING:
+ intermediate = tanh(X_block / softcap)
+ X_block = softcap * intermediate
+
+ if not HAS_WEIGHT:
+ # softmax(x_i)
+ X_block = tl.exp(X_block - m) / d
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+ X_block += 2 * lse_square_scale * lse * X_block
+ # smoothing term
+ X_block += -eps
+ # special handle dx_y
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+ # reduction scale
+ if reduction == "mean":
+ X_block = X_block / n_non_ignore
+ else:
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+ softmax_X = tl.exp(X_block - m) / d
+ # derivative of original_loss
+ dloss_ori = (1 - label_smoothing) * softmax_X
+ # specially handle dx_y
+ dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+ dloss_ori = dloss_ori * weight_y
+ # derivative of smooth_loss
+ dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+ # derivative of z-loss
+ dz_loss = 2 * lse_square_scale * lse * softmax_X
+ # reduction scale
+ if reduction == "mean":
+ dloss_ori = dloss_ori / sum_non_ignore_weight
+ dloss_smooth = dloss_smooth / sum_non_ignore_weight
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+ dz_loss = dz_loss / n_non_ignore
+ # derivative of total_loss
+ X_block = dloss_ori + dloss_smooth + dz_loss
+
+ # chain rule softcapping
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+ if HAS_SOFTCAPPING:
+ X_block = X_block * (1 - intermediate * intermediate)
+
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
  # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
  # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
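
The only functional change in this block is the new HAS_GRADIENTS guard: the second pass that overwrites the logits with their gradients is skipped entirely when the input does not require grad. For the plain case (no class weights, no label smoothing, no z-loss), the gradient the guarded loop writes is the one stated in the comments above, roughly (an illustrative sketch, not code from this package):

    import torch

    def ce_grad_reference(logits_row: torch.Tensor, y: int, n_non_ignore: int) -> torch.Tensor:
        # dx_i = softmax(x_i) for i != y, and dx_y = softmax(x_y) - 1,
        # scaled by 1 / n_non_ignore under "mean" reduction
        grad = torch.softmax(logits_row.float(), dim=-1)
        grad[y] -= 1.0
        return grad / n_non_ignore
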
@@ -254,6 +280,10 @@ def liger_cross_entropy_kernel(
  tl.store(loss_ptr, loss)
  if RETURN_Z_LOSS:
  tl.store(z_loss_ptr, z_loss)
+ if RETURN_TOKEN_ACCURACY:
+ # Store 1.0 if prediction is correct, 0.0 otherwise
+ is_correct = 1.0 if argmax_idx == y else 0.0
+ tl.store(token_accuracy_ptr, is_correct)
 
 
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -272,8 +302,12 @@ def cross_entropy_forward(
  reduction,
  softcap,
  return_z_loss,
+ return_token_accuracy=False,
  ):
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+ assert isinstance(return_token_accuracy, bool), (
+ f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+ )
 
  BT, V = _input.shape
  n_rows = BT
@@ -283,6 +317,9 @@ def cross_entropy_forward(
  # unreduced loss
  loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
  z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+ token_accuracy_1d = (
+ torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+ )
 
  target_mask = target != ignore_index
  n_non_ignore = target_mask.sum().item()
@@ -319,6 +356,10 @@ def cross_entropy_forward(
  loss_ptr=loss_1d,
  z_loss_ptr=z_loss_1d,
  loss_stride=loss_1d.stride(-1), # always 1
+ token_accuracy_ptr=token_accuracy_1d,
+ token_accuracy_stride=token_accuracy_1d.stride(-1)
+ if return_token_accuracy
+ else 0, # always 1 if accuracy is enabled
  n_cols=V,
  n_non_ignore=n_non_ignore,
  sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +370,11 @@ def cross_entropy_forward(
  reduction=reduction,
  softcap=softcap,
  RETURN_Z_LOSS=return_z_loss,
+ RETURN_TOKEN_ACCURACY=return_token_accuracy,
  BLOCK_SIZE=BLOCK_SIZE,
  HAS_WEIGHT=True if weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
+ HAS_GRADIENTS=_input.requires_grad,
  # TODO: 32 seems to give the best performance
  # Performance is quite sensitive to num_warps
  num_warps=32 if not is_hip() else 16,
@@ -340,11 +383,14 @@ def cross_entropy_forward(
  if reduction == "none":
  loss = loss_1d
  z_loss = z_loss_1d if return_z_loss else None
+ token_accuracy = token_accuracy_1d if return_token_accuracy else None
  else:
  loss = torch.sum(loss_1d)
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+ # For accuracy, we compute the mean across all non-ignored tokens
+ token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
 
- return loss, z_loss, _input
+ return loss, z_loss, token_accuracy, _input
 
 
  def cross_entropy_backward(_input, grad_output):
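
One detail worth noting in cross_entropy_forward: the loss follows the requested reduction, but the returned token accuracy is either the raw per-token 0/1 tensor (reduction="none") or the mean over non-ignored tokens, even when the loss itself is summed. Roughly (an illustrative sketch, not code from this package):

    import torch

    def reduce_token_accuracy(token_accuracy_1d: torch.Tensor, n_non_ignore: int, reduction: str):
        if reduction == "none":
            return token_accuracy_1d                   # one 0.0/1.0 entry per token
        return token_accuracy_1d.sum() / n_non_ignore  # fraction of correct, non-ignored tokens
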
@@ -392,6 +438,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  reduction: str = "mean",
  softcap: Optional[float] = None,
  return_z_loss: bool = False,
+ return_token_accuracy: bool = False,
  ):
  """
  The forward pass of the Liger Cross Entropy loss.
@@ -406,12 +453,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
  reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
  softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
- return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+ return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+ return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
 
  Returns:
- tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
+ tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
  """
- loss, z_loss, _input = cross_entropy_forward(
+ input_requires_grad = _input.requires_grad
+
+ loss, z_loss, token_accuracy, _input = cross_entropy_forward(
  _input,
  target,
  weight,
@@ -421,29 +471,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  reduction,
  softcap,
  return_z_loss,
+ return_token_accuracy,
  )
  # TODO: investigation
  # If we don't detach the _input tensor, the memory will double
  # Not sure why but seems that there will be a time both grad and value exist but in different location
- ctx.save_for_backward(_input.detach())
+ if input_requires_grad:
+ ctx.save_for_backward(_input.detach())
  ctx.return_z_loss = return_z_loss
+ ctx.return_token_accuracy = return_token_accuracy
 
- return loss, z_loss
+ return loss, z_loss, token_accuracy
 
  @staticmethod
- def backward(ctx, grad_output, grad_ouput2):
+ def backward(ctx, grad_output, grad_output2, grad_output3):
  """
  The backward pass of the Liger Cross Entropy loss.
 
  Parameters:
  ctx : The context object with saved tensors.
  grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
- grad_output2 (tenosr): No use.
+ grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+ grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
  Returns:
  tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
  """
  if ctx.return_z_loss:
- del grad_ouput2 # z_loss is only for logging
+ del grad_output2 # z_loss is only for logging
+ if ctx.return_token_accuracy:
+ del grad_output3 # token_accuracy is only for metrics
 
  (_input,) = ctx.saved_tensors
  _input = cross_entropy_backward(_input, grad_output)
@@ -457,4 +513,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  None,
  None,
  None,
+ None,
  )
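
With these changes LigerCrossEntropyFunction.apply returns a 3-tuple and its backward accepts three incoming gradients, of which only the loss gradient is used. A hypothetical call might look like the following (the argument order is taken from the forward signature shown above plus the usual preceding parameters; treat the exact positions as an assumption and check the full signature before relying on it):

    import torch
    from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

    # requires a GPU; the kernels are Triton
    logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
    target = torch.randint(0, 32000, (8,), device="cuda")

    loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
        logits,   # _input
        target,   # target
        None,     # weight
        -100,     # ignore_index
        0.0,      # lse_square_scale
        0.0,      # label_smoothing
        "mean",   # reduction
        None,     # softcap
        False,    # return_z_loss
        True,     # return_token_accuracy
    )
    loss.backward()                # only the loss output carries a gradient
    print(float(token_accuracy))   # fraction of non-ignored tokens predicted correctly
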
liger_kernel/ops/dyt.py CHANGED
@@ -7,8 +7,10 @@ import triton.language as tl
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
  from liger_kernel.ops.utils import infer_device
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available
 
- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import tanh
@@ -125,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
  NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
  elif device == "xpu":
  NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-
+ elif device == "npu":
+ NUM_SMS = get_npu_multi_processor_count()
  da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
  dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
  db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
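
The same pattern recurs across the ops files in this release: the libdevice import is bypassed on NPU, and the number of parallel units used to size partial-reduction buffers is looked up per backend. A condensed sketch of that dispatch (illustrative only; the NPU branch relies on the get_npu_multi_processor_count helper added to liger_kernel/utils.py in this release):

    import torch

    def num_parallel_units(device: torch.device) -> int:
        if device.type == "cuda":
            return torch.cuda.get_device_properties(device).multi_processor_count
        if device.type == "xpu":
            return torch.xpu.get_device_properties(device).gpu_subslice_count
        if device.type == "npu":
            # helper added to liger_kernel/utils.py in this release
            from liger_kernel.utils import get_npu_multi_processor_count
            return get_npu_multi_processor_count()
        raise ValueError(f"unsupported device type: {device.type}")
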
liger_kernel/ops/fused_add_rms_norm.py CHANGED
@@ -9,8 +9,10 @@ from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
  from liger_kernel.ops.utils import torch_to_triton_dtype
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available
 
- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import rsqrt
@@ -293,6 +295,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
  sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
  elif S.device.type == "xpu":
  sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
+ elif S.device.type == "npu":
+ sm_count = get_npu_multi_processor_count()
 
  # fp32 for numerical stability especially.
  _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -27,10 +27,16 @@ def fused_linear_cross_entropy_forward(
  return_z_loss=False,
  accum_dtype=None,
  use_token_scaling=False,
+ return_token_accuracy=False,
  ):
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+ assert isinstance(return_token_accuracy, bool), (
+ f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+ )
  device = _input.device
 
+ input_requires_grad = _input.requires_grad
+
  # inputs have shape: BT x H
  # materialized activations will have shape: BT x V
  # the increase in memory = BT x V
@@ -49,15 +55,20 @@ def fused_linear_cross_entropy_forward(
  grad_input = torch.zeros_like(_input, device=device)
 
  # we use fp32 for loss and gradients accumulator
- if accum_dtype is None:
- grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
- grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+ if input_requires_grad:
+ if accum_dtype is None:
+ grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+ grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+ else:
+ grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+ grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
  else:
- grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
- grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+ grad_weight = None
+ grad_bias = None
 
  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
  z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+ token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
 
  # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
  target_mask = target != ignore_index
@@ -123,6 +134,7 @@ def fused_linear_cross_entropy_forward(
  # unreduced loss
  loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
  z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+ token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
 
  # ensure _input and target are contiguous
  logits_chunk = logits_chunk.contiguous()
@@ -138,6 +150,10 @@ def fused_linear_cross_entropy_forward(
  loss_ptr=loss_1d_slice,
  z_loss_ptr=z_loss_1d_slice,
  loss_stride=loss_1d_slice.stride(-1), # always 1
+ token_accuracy_ptr=token_accuracy_1d_slice,
+ token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+ if return_token_accuracy
+ else 0, # always 1 if accuracy is enabled
  n_cols=V,
  n_non_ignore=total_n_non_ignore,
  sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -148,8 +164,10 @@ def fused_linear_cross_entropy_forward(
  reduction=reduction,
  softcap=softcap,
  RETURN_Z_LOSS=return_z_loss,
+ RETURN_TOKEN_ACCURACY=return_token_accuracy,
  HAS_WEIGHT=True if ce_weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
+ HAS_GRADIENTS=input_requires_grad,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=32 if not is_hip() else 16,
  )
@@ -163,6 +181,8 @@ def fused_linear_cross_entropy_forward(
  loss_1d[start_idx:end_idx] = loss_1d_slice
  if return_z_loss:
  z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+ if return_token_accuracy:
+ token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
  grad_logits_chunk = logits_chunk # chunk_size x V
 
  # Apply token scaling to gradients if requested
@@ -171,12 +191,13 @@ def fused_linear_cross_entropy_forward(
  scaling_factors_expanded = scaling_factors.unsqueeze(-1) # chunk_size x 1
  grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
 
- grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+ if input_requires_grad:
+ grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
- if grad_weight is not None:
+ if grad_weight is not None and input_requires_grad:
  grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
 
- if bias is not None:
+ if bias is not None and input_requires_grad:
  torch.add(
  input=grad_bias,
  other=grad_logits_chunk.sum(dim=0),
@@ -193,15 +214,18 @@ def fused_linear_cross_entropy_forward(
  # Return per-token losses
  loss = loss_1d
  z_loss = z_loss_1d if return_z_loss else None
+ token_accuracy = token_accuracy_1d if return_token_accuracy else None
  else:
  loss = torch.sum(loss_1d)
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+ # For accuracy, we compute the mean across all non-ignored tokens
+ token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
 
  # Cast back to original dtype
  grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
  grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
 
- return loss, z_loss, grad_input, grad_weight, grad_bias
+ return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias
 
 
  def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
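
In the chunk loop, every piece of gradient work is now behind input_requires_grad, so an inference-only call pays for the loss (and, optionally, the token accuracy) but not for grad_input, grad_weight, or grad_bias. A simplified sketch of what one chunk iteration accumulates when gradients are needed (illustrative only; the real code also handles accum_dtype and uses torch.add with an alpha for the bias):

    import torch

    def accumulate_chunk(grad_logits_chunk, input_chunk, weight,
                         grad_input_slice, grad_weight, grad_bias,
                         input_requires_grad: bool):
        if not input_requires_grad:
            return  # evaluation path: skip all gradient accumulation
        # d_loss / d_input for this chunk: (chunk_size x V) @ (V x H)
        grad_input_slice.copy_(grad_logits_chunk @ weight)
        if grad_weight is not None:
            # d_loss / d_weight accumulated across chunks: (V x chunk_size) @ (chunk_size x H)
            grad_weight += torch.mm(grad_logits_chunk.t(), input_chunk).float()
        if grad_bias is not None:
            # d_loss / d_bias accumulated across chunks
            grad_bias += grad_logits_chunk.sum(dim=0)
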
@@ -269,6 +293,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  return_z_loss: bool = False,
  accum_dtype=None,
  use_token_scaling: bool = False,
+ return_token_accuracy: bool = False,
  ):
  """
  Fusing the last linear layer with cross-entropy loss
@@ -292,9 +317,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
  When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
  Default: False.
+ return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
  """
 
- loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+ loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
  _input=_input,
  weight=weight,
  target=target,
@@ -308,6 +334,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  return_z_loss=return_z_loss,
  accum_dtype=accum_dtype,
  use_token_scaling=use_token_scaling,
+ return_token_accuracy=return_token_accuracy,
  )
  # downcast to dtype and store for backward
  ctx.save_for_backward(
@@ -316,13 +343,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  grad_bias.detach() if bias is not None else None,
  )
  ctx.return_z_loss = return_z_loss
- return loss, z_loss
+ ctx.return_token_accuracy = return_token_accuracy
+ return loss, z_loss, token_accuracy
 
  @staticmethod
  @amp_custom_bwd
- def backward(ctx, grad_output, grad_output2):
+ def backward(ctx, grad_output, grad_output2, grad_output3):
  if ctx.return_z_loss:
  del grad_output2 # z_loss is only for logging
+ if ctx.return_token_accuracy:
+ del grad_output3 # token_accuracy is only for metrics
  (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
  grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
  grad_output, grad_input, grad_weight, grad_bias
@@ -341,4 +371,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  None,
  None,
  None, # use_token_scaling
+ None, # return_token_accuracy
  )
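
Taken together, these hunks make the chunked loss usable as a pure metric: when the inputs do not require grad, no gradient buffers are allocated and the kernel runs with HAS_GRADIENTS=False. A sketch of such an evaluation-time call, using the keyword names visible in the call site above (the remaining arguments are assumed to keep their defaults; verify against the actual signature before relying on it):

    import torch
    from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward

    with torch.no_grad():  # hidden states do not require grad, so the gradient path is skipped
        hidden = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)      # BT x H
        lm_head = torch.randn(32000, 4096, device="cuda", dtype=torch.bfloat16)  # V x H
        labels = torch.randint(0, 32000, (16,), device="cuda")

        loss, z_loss, token_accuracy, *_ = fused_linear_cross_entropy_forward(
            _input=hidden,
            weight=lm_head,
            target=labels,
            return_token_accuracy=True,
        )
    print(loss.item(), token_accuracy.item())
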
liger_kernel/ops/geglu.py CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
  from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available
 
- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import tanh
liger_kernel/ops/group_norm.py CHANGED
@@ -6,8 +6,9 @@ import triton.language as tl
 
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available
 
- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import rsqrt
liger_kernel/ops/grpo_loss.py CHANGED
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
  per_token_loss1 = coef_1 * advantage
  per_token_loss2 = coef_2 * advantage
  per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
- is_clipped = per_token_loss1 < per_token_loss2
+ is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+ is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+ is_clipped = is_low_clipped | is_high_clipped
 
  if BETA != 0.0:
  REF_LOGP += off_b * L + off_l
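
The rewritten indicator matches the usual PPO/GRPO bookkeeping: a token counts as clipped only when the importance ratio (coef_1 in the kernel) actually left the trust region on the side where clipping changes the objective. As a plain PyTorch reference for the logging metric this feeds (an illustrative sketch, not code from this package):

    import torch

    def clip_fraction(coef_1: torch.Tensor, advantage: torch.Tensor,
                      eps_low: float = 0.2, eps_high: float = 0.2) -> torch.Tensor:
        # low side: ratio pushed below 1 - eps_low while the advantage is negative
        is_low_clipped = (coef_1 < 1 - eps_low) & (advantage < 0)
        # high side: ratio pushed above 1 + eps_high while the advantage is positive
        is_high_clipped = (coef_1 > 1 + eps_high) & (advantage > 0)
        is_clipped = is_low_clipped | is_high_clipped
        return is_clipped.float().mean()  # fraction of clipped tokens
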