liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +25 -9
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +124 -64
  17. liger_kernel/ops/dyt.py +115 -180
  18. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  19. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  20. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  21. liger_kernel/ops/geglu.py +3 -2
  22. liger_kernel/ops/group_norm.py +2 -1
  23. liger_kernel/ops/grpo_loss.py +312 -0
  24. liger_kernel/ops/jsd.py +2 -1
  25. liger_kernel/ops/kl_div.py +13 -6
  26. liger_kernel/ops/layer_norm.py +146 -78
  27. liger_kernel/ops/llama4_rope.py +225 -0
  28. liger_kernel/ops/multi_token_attention.py +207 -0
  29. liger_kernel/ops/poly_norm.py +390 -0
  30. liger_kernel/ops/rms_norm.py +283 -56
  31. liger_kernel/ops/rope.py +1 -1
  32. liger_kernel/ops/softmax.py +201 -0
  33. liger_kernel/ops/sparsemax.py +179 -0
  34. liger_kernel/ops/swiglu.py +1 -1
  35. liger_kernel/ops/tiled_mlp.py +136 -0
  36. liger_kernel/ops/utils.py +2 -0
  37. liger_kernel/transformers/__init__.py +205 -19
  38. liger_kernel/transformers/cross_entropy.py +9 -4
  39. liger_kernel/transformers/dyt.py +6 -4
  40. liger_kernel/transformers/experimental/__init__.py +5 -0
  41. liger_kernel/transformers/experimental/embedding.py +1 -1
  42. liger_kernel/transformers/fsdp.py +55 -0
  43. liger_kernel/transformers/functional.py +122 -20
  44. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  45. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  46. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  47. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  48. liger_kernel/transformers/geglu.py +1 -1
  49. liger_kernel/transformers/group_norm.py +1 -1
  50. liger_kernel/transformers/grpo_loss.py +153 -0
  51. liger_kernel/transformers/jsd.py +1 -1
  52. liger_kernel/transformers/kl_div.py +1 -1
  53. liger_kernel/transformers/layer_norm.py +1 -1
  54. liger_kernel/transformers/llama4_rope.py +93 -0
  55. liger_kernel/transformers/model/falcon_h1.py +122 -0
  56. liger_kernel/transformers/model/gemma.py +50 -25
  57. liger_kernel/transformers/model/gemma2.py +55 -23
  58. liger_kernel/transformers/model/gemma3.py +117 -120
  59. liger_kernel/transformers/model/glm4.py +141 -0
  60. liger_kernel/transformers/model/glm4v.py +163 -0
  61. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  62. liger_kernel/transformers/model/gpt_oss.py +211 -0
  63. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  64. liger_kernel/transformers/model/internvl.py +157 -0
  65. liger_kernel/transformers/model/llama.py +102 -25
  66. liger_kernel/transformers/model/llama4.py +121 -0
  67. liger_kernel/transformers/model/llava.py +111 -136
  68. liger_kernel/transformers/model/loss_utils.py +50 -12
  69. liger_kernel/transformers/model/mistral.py +36 -23
  70. liger_kernel/transformers/model/mixtral.py +45 -25
  71. liger_kernel/transformers/model/mllama.py +39 -22
  72. liger_kernel/transformers/model/olmo2.py +40 -20
  73. liger_kernel/transformers/model/olmo3.py +142 -0
  74. liger_kernel/transformers/model/output_classes.py +147 -0
  75. liger_kernel/transformers/model/paligemma.py +50 -14
  76. liger_kernel/transformers/model/phi3.py +47 -177
  77. liger_kernel/transformers/model/qwen2.py +48 -21
  78. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  79. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  80. liger_kernel/transformers/model/qwen3.py +136 -0
  81. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  82. liger_kernel/transformers/model/qwen3_next.py +146 -0
  83. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  84. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  85. liger_kernel/transformers/model/smollm3.py +199 -0
  86. liger_kernel/transformers/model/smolvlm.py +158 -0
  87. liger_kernel/transformers/monkey_patch.py +1678 -160
  88. liger_kernel/transformers/multi_token_attention.py +64 -0
  89. liger_kernel/transformers/poly_norm.py +42 -0
  90. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  91. liger_kernel/transformers/rms_norm.py +48 -5
  92. liger_kernel/transformers/rope.py +45 -1
  93. liger_kernel/transformers/softmax.py +12 -0
  94. liger_kernel/transformers/sparsemax.py +16 -0
  95. liger_kernel/transformers/swiglu.py +39 -1
  96. liger_kernel/transformers/tiled_mlp.py +133 -0
  97. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  98. liger_kernel/transformers/tvd.py +1 -1
  99. liger_kernel/utils.py +36 -0
  100. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
  101. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  102. liger_kernel/transformers/gema3_rms.py +0 -8
  103. liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
  104. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py
@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,
@@ -42,9 +45,11 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +64,8 @@ def liger_cross_entropy_kernel(
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
+    token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+    token_accuracy_stride (int): The stride of the token accuracy tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (float): The number of non-ignored elements in the batch.
     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +75,12 @@ def liger_cross_entropy_kernel(
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
     HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+    HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """
 
     # https://github.com/triton-lang/triton/issues/1058
@@ -90,11 +99,17 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return
 
     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride
 
     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +120,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
     ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +141,16 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY and block_max > m:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            argmax_idx = tl.min(masked_offsets)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
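The hunk above tracks the running argmax inside the existing online-softmax max pass, so per-token accuracy comes essentially for free, with no second sweep over the vocabulary. The following is an illustrative plain-Python sketch of the same block-wise argmax idea, not part of the package diff (the function name and block_size parameter are assumptions):

import torch

def blockwise_argmax(row: torch.Tensor, block_size: int) -> int:
    # Mirror the kernel: keep a running max `m` and the first column index where it occurs.
    m = float("-inf")
    argmax_idx = 0
    for i in range(0, row.numel(), block_size):
        block = row[i : i + block_size]
        block_max = block.max().item()
        if block_max > m:
            # first (smallest) index in this block that attains the block max
            argmax_idx = i + int((block == block_max).nonzero()[0].item())
            m = block_max
    return argmax_idx

For each row, this returns the index that the kernel's argmax_idx holds once the first pass finishes.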
@@ -155,58 +181,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(
-            X_ptr + X_offsets,
-            mask=X_offsets < n_cols,
-            other=float("-inf"),
-            # Ensure float32 precision for softmax calculation
-        ).cast(tl.float32)
-        if HAS_SOFTCAPPING:
-            intermediate = tanh(X_block / softcap)
-            X_block = softcap * intermediate
-
-        if not HAS_WEIGHT:
-            # softmax(x_i)
-            X_block = tl.exp(X_block - m) / d
-            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-            X_block += 2 * lse_square_scale * lse * X_block
-            # smoothing term
-            X_block += -eps
-            # special handle dx_y
-            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-            # reduction scale
-            if reduction == "mean":
-                X_block = X_block / n_non_ignore
-        else:
-            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
-            softmax_X = tl.exp(X_block - m) / d
-            # derivative of original_loss
-            dloss_ori = (1 - label_smoothing) * softmax_X
-            # specially handle dx_y
-            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
-            dloss_ori = dloss_ori * weight_y
-            # derivative of smooth_loss
-            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
-            # derivative of z-loss
-            dz_loss = 2 * lse_square_scale * lse * softmax_X
-            # reduction scale
-            if reduction == "mean":
-                dloss_ori = dloss_ori / sum_non_ignore_weight
-                dloss_smooth = dloss_smooth / sum_non_ignore_weight
-                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
-                dz_loss = dz_loss / n_non_ignore
-            # derivative of total_loss
-            X_block = dloss_ori + dloss_smooth + dz_loss
-
-        # chain rule softcapping
-        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
-        if HAS_SOFTCAPPING:
-            X_block = X_block * (1 - intermediate * intermediate)
-
-        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
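For the common unweighted, unsmoothed, mean-reduced case, the in-place gradient that this loop writes reduces to softmax(x) minus a one-hot at the target, divided by the number of non-ignored rows. The following is a small eager-PyTorch reference of that special case, an illustrative sketch rather than package code (the function name is an assumption):

import torch

def ce_grad_reference(X: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # dX[i, j] = softmax(X[i])[j] - 1{j == y[i]}, scaled by 1 / n_rows for "mean" reduction
    probs = torch.softmax(X.float(), dim=-1)
    grad = probs.clone()
    grad[torch.arange(X.shape[0]), y] -= 1.0
    return grad / X.shape[0]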
@@ -254,6 +280,10 @@ def liger_cross_entropy_kernel(
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
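End to end, RETURN_TOKEN_ACCURACY stores 1.0 for rows whose argmax equals the target and 0.0 otherwise, with ignored rows also set to 0.0. A dense reference of the same quantity, useful for checking the fused result against eager PyTorch, is sketched below (illustrative only; the function name is an assumption):

import torch

def token_accuracy_reference(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Per-row correctness, then the mean over non-ignored rows. This matches the fused
    # path's "mean"/"sum" reduction, which divides the summed 0/1 flags by n_non_ignore.
    mask = target != ignore_index
    correct = (logits.argmax(dim=-1) == target) & mask
    return correct.sum().float() / mask.sum().clamp(min=1)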
@@ -272,8 +302,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
 
     BT, V = _input.shape
     n_rows = BT
@@ -283,6 +317,9 @@ def cross_entropy_forward(
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )
 
     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -319,6 +356,10 @@ def cross_entropy_forward(
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +370,11 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
@@ -340,18 +383,24 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
 
-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input
 
 
 def cross_entropy_backward(_input, grad_output):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         pass
-
+    # If reduction is 'none'
+    elif grad_output.ndim > 0:
+        _input = _input * grad_output.unsqueeze(dim=1)
+    # If reduction is ['mean', 'sum'], grad_output is just a scalar
     # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
     # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
     else:
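The new elif handles reduction="none": the forward pass then leaves one loss (and one stored gradient row) per token, so each row's gradient must be scaled by that row's incoming grad_output rather than by a single scalar. A sketch of the scaling rule follows (illustrative only, not package code; the helper name is an assumption):

import torch

def scale_stored_grad(grad_input: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    # grad_input: (BT, V) gradient written in-place by the forward kernel
    if grad_output.ndim == 0:
        # "mean" / "sum": a single scalar scales every row (skipped entirely when it equals 1.0)
        return grad_input * grad_output
    # "none": grad_output has shape (BT,); broadcast it over the vocab dimension
    return grad_input * grad_output.unsqueeze(dim=1)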
@@ -389,6 +438,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -403,12 +453,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
         softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
 
         Returns:
-        tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
-        loss, z_loss, _input = cross_entropy_forward(
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
             weight,
@@ -418,29 +471,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             reduction,
             softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy
 
-        return loss, z_loss
+        return loss, z_loss, token_accuracy
 
     @staticmethod
-    def backward(ctx, grad_output, grad_ouput2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.
 
         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-        grad_output2 (tenosr): No use.
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
-            del grad_ouput2  # z_loss is only for logging
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
 
         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
@@ -454,4 +513,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
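Taken together, LigerCrossEntropyFunction now returns (loss, z_loss, token_accuracy), and backward accepts the two extra, unused gradients. The following is a hedged usage sketch of the new flag, not part of the diff; the positional argument order and the values passed for weight, ignore_index, lse_square_scale, and label_smoothing are assumptions inferred from the signatures shown above:

import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")

loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
    logits,   # _input
    target,   # target
    None,     # weight
    -100,     # ignore_index
    0.0,      # lse_square_scale
    0.0,      # label_smoothing
    "mean",   # reduction
    None,     # softcap
    False,    # return_z_loss
    True,     # return_token_accuracy
)
loss.backward()  # token_accuracy is a metric only and carries no gradient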