liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.4.dev20251121224847__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of liger-kernel-nightly might be problematic.

Files changed (73)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +23 -7
  9. liger_kernel/ops/cross_entropy.py +118 -62
  10. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  11. liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
  12. liger_kernel/ops/geglu.py +1 -1
  13. liger_kernel/ops/grpo_loss.py +3 -1
  14. liger_kernel/ops/layer_norm.py +133 -79
  15. liger_kernel/ops/llama4_rope.py +225 -0
  16. liger_kernel/ops/poly_norm.py +386 -0
  17. liger_kernel/ops/rms_norm.py +2 -2
  18. liger_kernel/ops/rope.py +1 -1
  19. liger_kernel/ops/swiglu.py +1 -1
  20. liger_kernel/ops/tiled_mlp.py +136 -0
  21. liger_kernel/transformers/__init__.py +59 -0
  22. liger_kernel/transformers/cross_entropy.py +8 -3
  23. liger_kernel/transformers/experimental/__init__.py +5 -0
  24. liger_kernel/transformers/functional.py +38 -6
  25. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  26. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
  27. liger_kernel/transformers/grpo_loss.py +56 -1
  28. liger_kernel/transformers/llama4_rope.py +93 -0
  29. liger_kernel/transformers/model/falcon_h1.py +122 -0
  30. liger_kernel/transformers/model/gemma.py +28 -8
  31. liger_kernel/transformers/model/gemma2.py +31 -8
  32. liger_kernel/transformers/model/gemma3.py +100 -110
  33. liger_kernel/transformers/model/glm4.py +18 -5
  34. liger_kernel/transformers/model/glm4v.py +163 -0
  35. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  36. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  37. liger_kernel/transformers/model/internvl.py +157 -0
  38. liger_kernel/transformers/model/llama.py +26 -7
  39. liger_kernel/transformers/model/llama4.py +121 -0
  40. liger_kernel/transformers/model/llava.py +18 -6
  41. liger_kernel/transformers/model/loss_utils.py +34 -3
  42. liger_kernel/transformers/model/mistral.py +17 -10
  43. liger_kernel/transformers/model/mixtral.py +24 -9
  44. liger_kernel/transformers/model/mllama.py +18 -7
  45. liger_kernel/transformers/model/olmo2.py +18 -5
  46. liger_kernel/transformers/model/olmo3.py +142 -0
  47. liger_kernel/transformers/model/output_classes.py +147 -0
  48. liger_kernel/transformers/model/paligemma.py +41 -5
  49. liger_kernel/transformers/model/phi3.py +24 -159
  50. liger_kernel/transformers/model/qwen2.py +26 -4
  51. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  52. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  53. liger_kernel/transformers/model/qwen3.py +22 -6
  54. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  55. liger_kernel/transformers/model/qwen3_next.py +146 -0
  56. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  57. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  58. liger_kernel/transformers/model/smollm3.py +199 -0
  59. liger_kernel/transformers/model/smolvlm.py +158 -0
  60. liger_kernel/transformers/monkey_patch.py +1278 -116
  61. liger_kernel/transformers/multi_token_attention.py +1 -1
  62. liger_kernel/transformers/poly_norm.py +42 -0
  63. liger_kernel/transformers/rms_norm.py +7 -0
  64. liger_kernel/transformers/rope.py +43 -0
  65. liger_kernel/transformers/swiglu.py +17 -0
  66. liger_kernel/transformers/tiled_mlp.py +133 -0
  67. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/METADATA +29 -24
  68. liger_kernel_nightly-0.6.4.dev20251121224847.dist-info/RECORD +118 -0
  69. liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
  70. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/LICENSE +0 -0
  71. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/NOTICE +0 -0
  72. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/WHEEL +0 -0
  73. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py

@@ -32,6 +32,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,
@@ -42,9 +44,11 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +63,8 @@ def liger_cross_entropy_kernel(
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
+    token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+    token_accuracy_stride (int): The stride of the token accuracy tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (float): The number of non-ignored elements in the batch.
     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +74,12 @@ def liger_cross_entropy_kernel(
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
     HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+    HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """

     # https://github.com/triton-lang/triton/issues/1058
@@ -90,11 +98,17 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return

     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride

     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +119,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf") # m is the max value. use the notation from the paper
     d = 0.0 # d is the sum. use the notation from the paper
+    argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
     ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +140,16 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY and block_max > m:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            argmax_idx = tl.min(masked_offsets)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
@@ -155,58 +180,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(
-            X_ptr + X_offsets,
-            mask=X_offsets < n_cols,
-            other=float("-inf"),
-            # Ensure float32 precision for softmax calculation
-        ).cast(tl.float32)
-        if HAS_SOFTCAPPING:
-            intermediate = tanh(X_block / softcap)
-            X_block = softcap * intermediate
-
-        if not HAS_WEIGHT:
-            # softmax(x_i)
-            X_block = tl.exp(X_block - m) / d
-            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-            X_block += 2 * lse_square_scale * lse * X_block
-            # smoothing term
-            X_block += -eps
-            # special handle dx_y
-            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-            # reduction scale
-            if reduction == "mean":
-                X_block = X_block / n_non_ignore
-        else:
-            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
-            softmax_X = tl.exp(X_block - m) / d
-            # derivative of original_loss
-            dloss_ori = (1 - label_smoothing) * softmax_X
-            # specially handle dx_y
-            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
-            dloss_ori = dloss_ori * weight_y
-            # derivative of smooth_loss
-            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
-            # derivative of z-loss
-            dz_loss = 2 * lse_square_scale * lse * softmax_X
-            # reduction scale
-            if reduction == "mean":
-                dloss_ori = dloss_ori / sum_non_ignore_weight
-                dloss_smooth = dloss_smooth / sum_non_ignore_weight
-                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
-                dz_loss = dz_loss / n_non_ignore
-            # derivative of total_loss
-            X_block = dloss_ori + dloss_smooth + dz_loss
-
-        # chain rule softcapping
-        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
-        if HAS_SOFTCAPPING:
-            X_block = X_block * (1 - intermediate * intermediate)
-
-        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -254,6 +279,10 @@ def liger_cross_entropy_kernel(
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)


 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -272,8 +301,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )

     BT, V = _input.shape
     n_rows = BT
@@ -283,6 +316,9 @@ def cross_entropy_forward(
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -319,6 +355,10 @@ def cross_entropy_forward(
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1), # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0, # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +369,11 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
@@ -340,11 +382,14 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input


 def cross_entropy_backward(_input, grad_output):
@@ -392,6 +437,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -406,12 +452,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
         softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

         Returns:
-        tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
-        loss, z_loss, _input = cross_entropy_forward(
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
             weight,
@@ -421,29 +470,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             reduction,
             softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
        # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy

-        return loss, z_loss
+        return loss, z_loss, token_accuracy

     @staticmethod
-    def backward(ctx, grad_output, grad_ouput2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.

         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-        grad_output2 (tenosr): No use.
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
-            del grad_ouput2 # z_loss is only for logging
+            del grad_output2 # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3 # token_accuracy is only for metrics

         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
@@ -457,4 +512,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
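
Below is a minimal usage sketch (not part of the diff) of the new return_token_accuracy flag, calling LigerCrossEntropyFunction.apply with the positional argument order implied by the forward() signature shown above; the shapes, device, and hyperparameter values are illustrative assumptions.

# Sketch only: exercises the new (loss, z_loss, token_accuracy) return of
# LigerCrossEntropyFunction; requires a GPU build of liger-kernel-nightly.
import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # (BT, V)
targets = torch.randint(0, 32000, (8,), device="cuda")             # (BT,)

loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
    logits,   # _input
    targets,  # target
    None,     # weight
    -100,     # ignore_index
    0.0,      # lse_square_scale
    0.0,      # label_smoothing
    "mean",   # reduction
    None,     # softcap
    False,    # return_z_loss
    True,     # return_token_accuracy
)

loss.backward()               # gradients flow through loss only; token_accuracy is a metric
print(token_accuracy.item())  # fraction of non-ignored tokens whose argmax matches the target

With reduction="mean", cross_entropy_forward reduces the per-row accuracy buffer as torch.sum(token_accuracy_1d) / n_non_ignore, as shown in the hunks above, so the returned value is the mean accuracy over non-ignored tokens.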