liger-kernel 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  5. liger_kernel/chunked_loss/jsd_loss.py +23 -7
  6. liger_kernel/ops/cross_entropy.py +118 -62
  7. liger_kernel/ops/fused_linear_cross_entropy.py +97 -13
  8. liger_kernel/ops/grpo_loss.py +3 -1
  9. liger_kernel/ops/layer_norm.py +86 -69
  10. liger_kernel/ops/poly_norm.py +386 -0
  11. liger_kernel/ops/tiled_mlp.py +136 -0
  12. liger_kernel/transformers/__init__.py +36 -0
  13. liger_kernel/transformers/cross_entropy.py +8 -3
  14. liger_kernel/transformers/functional.py +31 -6
  15. liger_kernel/transformers/fused_linear_cross_entropy.py +13 -4
  16. liger_kernel/transformers/grpo_loss.py +56 -1
  17. liger_kernel/transformers/model/falcon_h1.py +122 -0
  18. liger_kernel/transformers/model/gemma.py +19 -7
  19. liger_kernel/transformers/model/gemma2.py +22 -7
  20. liger_kernel/transformers/model/gemma3.py +52 -14
  21. liger_kernel/transformers/model/glm4.py +18 -5
  22. liger_kernel/transformers/model/glm4v.py +19 -6
  23. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  24. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  25. liger_kernel/transformers/model/internvl.py +157 -0
  26. liger_kernel/transformers/model/llama.py +16 -6
  27. liger_kernel/transformers/model/llama4.py +18 -5
  28. liger_kernel/transformers/model/llava.py +18 -6
  29. liger_kernel/transformers/model/loss_utils.py +32 -3
  30. liger_kernel/transformers/model/mistral.py +17 -7
  31. liger_kernel/transformers/model/mixtral.py +24 -9
  32. liger_kernel/transformers/model/mllama.py +14 -5
  33. liger_kernel/transformers/model/olmo2.py +18 -5
  34. liger_kernel/transformers/model/olmo3.py +142 -0
  35. liger_kernel/transformers/model/output_classes.py +147 -0
  36. liger_kernel/transformers/model/paligemma.py +41 -5
  37. liger_kernel/transformers/model/phi3.py +16 -8
  38. liger_kernel/transformers/model/qwen2.py +18 -4
  39. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  40. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  41. liger_kernel/transformers/model/qwen3.py +22 -6
  42. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  43. liger_kernel/transformers/model/qwen3_next.py +146 -0
  44. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  45. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  46. liger_kernel/transformers/model/smollm3.py +17 -7
  47. liger_kernel/transformers/model/smolvlm.py +158 -0
  48. liger_kernel/transformers/monkey_patch.py +830 -3
  49. liger_kernel/transformers/multi_token_attention.py +1 -1
  50. liger_kernel/transformers/poly_norm.py +42 -0
  51. liger_kernel/transformers/rms_norm.py +7 -0
  52. liger_kernel/transformers/rope.py +43 -0
  53. liger_kernel/transformers/swiglu.py +17 -0
  54. liger_kernel/transformers/tiled_mlp.py +133 -0
  55. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/METADATA +16 -10
  56. liger_kernel-0.6.4.dist-info/RECORD +118 -0
  57. liger_kernel-0.6.2.dist-info/RECORD +0 -104
  58. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/WHEEL +0 -0
  59. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/licenses/LICENSE +0 -0
  60. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/licenses/NOTICE +0 -0
  61. {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py

@@ -32,6 +32,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,
@@ -42,9 +44,11 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +63,8 @@ def liger_cross_entropy_kernel(
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
+    token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+    token_accuracy_stride (int): The stride of the token accuracy tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (float): The number of non-ignored elements in the batch.
     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +74,12 @@ def liger_cross_entropy_kernel(
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
     HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+    HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """

     # https://github.com/triton-lang/triton/issues/1058
@@ -90,11 +98,17 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return

     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride

     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +119,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
     ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +140,16 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY and block_max > m:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            argmax_idx = tl.min(masked_offsets)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
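Note on the hunk above: the kernel now tracks a running argmax during the online-softmax max/sum pass so per-token accuracy can be reported without a second scan over the vocabulary. A minimal PyTorch sketch of the same blockwise bookkeeping (the function name and block size are illustrative, not part of the kernel):

    import torch

    def blockwise_argmax(logits_row: torch.Tensor, block_size: int = 128) -> int:
        """Reference for the blockwise argmax tracking above: scan one row of logits
        in blocks, keeping the running max and the first index that attains it."""
        m = float("-inf")  # running max, as in the kernel
        argmax_idx = 0     # first index at which the running max occurs
        for start in range(0, logits_row.numel(), block_size):
            block = logits_row[start : start + block_size]
            block_max = block.max().item()
            if block_max > m:
                # smallest offset in this block where the block max occurs
                argmax_idx = start + int((block == block_max).nonzero()[0])
                m = block_max
        return argmax_idx

    x = torch.randn(1000)
    assert blockwise_argmax(x) == int(x.argmax())

Ties resolve toward the smallest index, matching the tl.min over the masked offsets.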
@@ -155,58 +180,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(
-            X_ptr + X_offsets,
-            mask=X_offsets < n_cols,
-            other=float("-inf"),
-            # Ensure float32 precision for softmax calculation
-        ).cast(tl.float32)
-        if HAS_SOFTCAPPING:
-            intermediate = tanh(X_block / softcap)
-            X_block = softcap * intermediate
-
-        if not HAS_WEIGHT:
-            # softmax(x_i)
-            X_block = tl.exp(X_block - m) / d
-            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-            X_block += 2 * lse_square_scale * lse * X_block
-            # smoothing term
-            X_block += -eps
-            # special handle dx_y
-            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-            # reduction scale
-            if reduction == "mean":
-                X_block = X_block / n_non_ignore
-        else:
-            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
-            softmax_X = tl.exp(X_block - m) / d
-            # derivative of original_loss
-            dloss_ori = (1 - label_smoothing) * softmax_X
-            # specially handle dx_y
-            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
-            dloss_ori = dloss_ori * weight_y
-            # derivative of smooth_loss
-            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
-            # derivative of z-loss
-            dz_loss = 2 * lse_square_scale * lse * softmax_X
-            # reduction scale
-            if reduction == "mean":
-                dloss_ori = dloss_ori / sum_non_ignore_weight
-                dloss_smooth = dloss_smooth / sum_non_ignore_weight
-                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
-                dz_loss = dz_loss / n_non_ignore
-            # derivative of total_loss
-            X_block = dloss_ori + dloss_smooth + dz_loss
-
-        # chain rule softcapping
-        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
-        if HAS_SOFTCAPPING:
-            X_block = X_block * (1 - intermediate * intermediate)
-
-        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -254,6 +279,10 @@ def liger_cross_entropy_kernel(
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)


 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -272,8 +301,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )

     BT, V = _input.shape
     n_rows = BT
@@ -283,6 +316,9 @@ def cross_entropy_forward(
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -319,6 +355,10 @@ def cross_entropy_forward(
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +369,11 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
@@ -340,11 +382,14 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input


 def cross_entropy_backward(_input, grad_output):
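When a reduction is applied, the token_accuracy returned above is a scalar: the number of correct argmax predictions divided by the number of non-ignored targets. A rough eager-mode equivalent of that reported quantity (assumes nothing beyond standard argmax accuracy with an ignore_index mask):

    import torch

    def token_accuracy_reference(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
        """Eager-mode equivalent of the reduced token accuracy: fraction of
        non-ignored positions whose argmax matches the target."""
        mask = target != ignore_index
        correct = (logits.argmax(dim=-1) == target) & mask
        return correct.sum().float() / mask.sum().clamp(min=1)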
@@ -392,6 +437,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -406,12 +452,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
         softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

         Returns:
-        tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
-        loss, z_loss, _input = cross_entropy_forward(
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
             weight,
@@ -421,29 +470,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             reduction,
             softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy

-        return loss, z_loss
+        return loss, z_loss, token_accuracy

     @staticmethod
-    def backward(ctx, grad_output, grad_ouput2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.

         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-        grad_output2 (tenosr): No use.
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
-            del grad_ouput2  # z_loss is only for logging
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics

         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
@@ -457,4 +512,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
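With the new flag, LigerCrossEntropyFunction returns a three-tuple instead of a pair. A usage sketch follows; the positional argument order is assumed from the full forward signature, of which only the tail appears in the hunks above, torch.autograd.Function.apply takes positional arguments only, and a CUDA device is needed for the Triton kernel:

    import torch
    from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

    logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
    target = torch.randint(0, 32000, (8,), device="cuda")

    # assumed positional order: (_input, target, weight, ignore_index, lse_square_scale,
    # label_smoothing, reduction, softcap, return_z_loss, return_token_accuracy)
    loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
        logits, target, None, -100, 0.0, 0.0, "mean", None, False, True
    )
    loss.backward()  # token_accuracy and z_loss carry no gradient; they are metrics/logging only
    print(float(token_accuracy))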
liger_kernel/ops/fused_linear_cross_entropy.py

@@ -26,10 +26,17 @@ def fused_linear_cross_entropy_forward(
     softcap=None,
     return_z_loss=False,
     accum_dtype=None,
+    use_token_scaling=False,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
     device = _input.device

+    input_requires_grad = _input.requires_grad
+
     # inputs have shape: BT x H
     # materialized activations will have shape: BT x V
     # the increase in memory = BT x V
@@ -48,15 +55,20 @@ def fused_linear_cross_entropy_forward(
     grad_input = torch.zeros_like(_input, device=device)

     # we use fp32 for loss and gradients accumulator
-    if accum_dtype is None:
-        grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
     else:
-        grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+        grad_weight = None
+        grad_bias = None

     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None

     # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
     target_mask = target != ignore_index
@@ -89,9 +101,40 @@ def fused_linear_cross_entropy_forward(

         n_rows = logits_chunk.shape[0]

+        # Compute predicted probabilities for token scaling if needed
+        if use_token_scaling:
+            # Compute softmax probabilities for scaling
+            # We need to compute this before the cross entropy kernel modifies logits_chunk
+            logits_for_softmax = logits_chunk.detach().clone()  # Detach to avoid gradient flow
+            if softcap is not None:
+                logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
+
+            # Compute softmax to get predicted probabilities
+            probs = torch.softmax(logits_for_softmax, dim=-1)
+
+            # Get predicted probabilities for token scaling, handling ignored targets
+            valid_target_mask = target_chunk != ignore_index
+            valid_targets = target_chunk[valid_target_mask]
+
+            if len(valid_targets) > 0:
+                # Gather probabilities only for valid targets
+                valid_probs = probs[valid_target_mask]
+                pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
+
+                # Create full tensor with zeros for ignored targets
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+                pred_probs[valid_target_mask] = pred_probs_valid
+            else:
+                # All targets are ignored
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+
+            # Store the scaling factors
+            scaling_factors = pred_probs.detach()  # Detach to ensure no gradient flow
+
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+        token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None

         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
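use_token_scaling weights each token's loss, and later its gradient, by the detached probability the model assigns to that token's true class. An eager sketch of the scaling factor computed in the block above (the helper name and signature are illustrative):

    from typing import Optional

    import torch

    def token_scaling_factors(
        logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100, softcap: Optional[float] = None
    ) -> torch.Tensor:
        """Detached p_model(target class) per token; 0.0 at ignored positions.
        Mirrors the pred_probs / scaling_factors computation in the hunk above."""
        if softcap is not None:
            logits = softcap * torch.tanh(logits / softcap)
        probs = torch.softmax(logits.detach().float(), dim=-1)
        valid = target != ignore_index
        # gather with a safe index so ignore_index (-100) never reaches torch.gather
        safe_target = torch.where(valid, target, torch.zeros_like(target))
        gathered = probs.gather(-1, safe_target.unsqueeze(-1)).squeeze(-1)
        return torch.where(valid, gathered, torch.zeros_like(gathered))

    # per-token scaled loss (requires unreduced, i.e. reduction="none", losses):
    # scaled_loss = per_token_loss * token_scaling_factors(logits, target)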
@@ -107,6 +150,10 @@ def fused_linear_cross_entropy_forward(
             loss_ptr=loss_1d_slice,
             z_loss_ptr=z_loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
+            token_accuracy_ptr=token_accuracy_1d_slice,
+            token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+            if return_token_accuracy
+            else 0,  # always 1 if accuracy is enabled
             n_cols=V,
             n_non_ignore=total_n_non_ignore,
             sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -117,26 +164,43 @@ def fused_linear_cross_entropy_forward(
             reduction=reduction,
             softcap=softcap,
             RETURN_Z_LOSS=return_z_loss,
+            RETURN_TOKEN_ACCURACY=return_token_accuracy,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
+            HAS_GRADIENTS=input_requires_grad,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )

+        # Apply token scaling if requested
+        if use_token_scaling:
+            loss_1d_slice = loss_1d_slice * scaling_factors
+            if return_z_loss:
+                z_loss_1d_slice = z_loss_1d_slice * scaling_factors
+
         loss_1d[start_idx:end_idx] = loss_1d_slice
         if return_z_loss:
             z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        if return_token_accuracy:
+            token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
         grad_logits_chunk = logits_chunk  # chunk_size x V

-        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+        # Apply token scaling to gradients if requested
+        if use_token_scaling:
+            # Expand scaling factors to match gradient dimensions
+            scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
+            grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

-        if grad_weight is not None:
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+
+        if grad_weight is not None and input_requires_grad:
             grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

-        if bias is not None:
+        if bias is not None and input_requires_grad:
             torch.add(
                 input=grad_bias,
-                other=logits_chunk.sum(dim=0),
+                other=grad_logits_chunk.sum(dim=0),
                 out=grad_bias,
                 alpha=1.0,
             )
@@ -146,15 +210,22 @@ def fused_linear_cross_entropy_forward(
     # loss = loss_1d
     # z_loss = z_loss_1d if return_z_loss else None

+    if reduction == "none":
+        # Return per-token losses
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None

     # Cast back to original dtype
     grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
     grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None

-    return loss, z_loss, grad_input, grad_weight, grad_bias
+    return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias


 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -221,6 +292,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         softcap=None,
         return_z_loss: bool = False,
         accum_dtype=None,
+        use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -241,9 +314,13 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         reduction: reduction to apply
         accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
             Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
+        use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
+            When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
+            Default: False.
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
         """

-        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+        loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
             _input=_input,
             weight=weight,
             target=target,
@@ -256,6 +333,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             softcap=softcap,
             return_z_loss=return_z_loss,
             accum_dtype=accum_dtype,
+            use_token_scaling=use_token_scaling,
+            return_token_accuracy=return_token_accuracy,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -264,13 +343,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             grad_bias.detach() if bias is not None else None,
         )
         ctx.return_z_loss = return_z_loss
-        return loss, z_loss
+        ctx.return_token_accuracy = return_token_accuracy
+        return loss, z_loss, token_accuracy

     @staticmethod
     @amp_custom_bwd
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         if ctx.return_z_loss:
             del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
@@ -288,4 +370,6 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,  # use_token_scaling
+            None,  # return_token_accuracy
         )
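For reference, the unfused equivalent of this fused path materializes the full BT x V logits tensor before reducing; the fused function reports the same loss and token accuracy without ever holding those logits. A naive sketch useful only for sanity checks, assuming mean reduction and no token scaling:

    import torch
    import torch.nn.functional as F

    def unfused_reference(x, lin_weight, target, bias=None, ignore_index=-100):
        """Naive counterpart of the fused path, for sanity checks only: materializes
        the full BT x V logits, then computes mean loss and token accuracy."""
        logits = F.linear(x, lin_weight, bias).float()  # BT x V, the buffer the fused kernel never builds
        loss = F.cross_entropy(logits, target, ignore_index=ignore_index)
        mask = target != ignore_index
        correct = (logits.argmax(dim=-1) == target) & mask
        token_accuracy = correct.sum().float() / mask.sum().clamp(min=1)
        return loss, token_accuracy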
liger_kernel/ops/grpo_loss.py

@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
     per_token_loss1 = coef_1 * advantage
     per_token_loss2 = coef_2 * advantage
     per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
-    is_clipped = per_token_loss1 < per_token_loss2
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped

     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
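The new indicator counts a token as clipped only when its importance ratio coef_1 falls outside [1 - EPS_LOW, 1 + EPS_HIGH] on the side that binds for the sign of its advantage, replacing the earlier loss-comparison test. An eager sketch of the same condition (names and defaults are illustrative):

    import torch

    def clip_fraction(log_probs, old_log_probs, advantages, eps_low=0.2, eps_high=0.2):
        """Fraction of tokens whose importance ratio is clipped, mirroring the
        is_low_clipped / is_high_clipped logic added above."""
        coef_1 = torch.exp(log_probs - old_log_probs)  # importance sampling ratio
        is_low_clipped = (coef_1 < 1 - eps_low) & (advantages < 0)
        is_high_clipped = (coef_1 > 1 + eps_high) & (advantages > 0)
        return (is_low_clipped | is_high_clipped).float().mean()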