liger-kernel-nightly 0.6.3.dev20251105224413__py3-none-any.whl → 0.6.3.dev20251105235313__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/ops/cross_entropy.py +59 -9
  2. liger_kernel/ops/fused_linear_cross_entropy.py +27 -4
  3. liger_kernel/transformers/cross_entropy.py +8 -3
  4. liger_kernel/transformers/functional.py +24 -6
  5. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  6. liger_kernel/transformers/model/falcon_h1.py +19 -5
  7. liger_kernel/transformers/model/gemma.py +17 -6
  8. liger_kernel/transformers/model/gemma2.py +14 -5
  9. liger_kernel/transformers/model/gemma3.py +25 -12
  10. liger_kernel/transformers/model/glm4.py +16 -4
  11. liger_kernel/transformers/model/glm4v.py +16 -4
  12. liger_kernel/transformers/model/glm4v_moe.py +19 -4
  13. liger_kernel/transformers/model/internvl.py +12 -5
  14. liger_kernel/transformers/model/llama.py +14 -5
  15. liger_kernel/transformers/model/llama4.py +16 -4
  16. liger_kernel/transformers/model/llava.py +12 -4
  17. liger_kernel/transformers/model/loss_utils.py +31 -3
  18. liger_kernel/transformers/model/mistral.py +15 -6
  19. liger_kernel/transformers/model/mixtral.py +16 -7
  20. liger_kernel/transformers/model/mllama.py +12 -4
  21. liger_kernel/transformers/model/olmo2.py +16 -4
  22. liger_kernel/transformers/model/output_classes.py +147 -0
  23. liger_kernel/transformers/model/paligemma.py +22 -5
  24. liger_kernel/transformers/model/phi3.py +14 -7
  25. liger_kernel/transformers/model/qwen2.py +16 -3
  26. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  27. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  28. liger_kernel/transformers/model/qwen3.py +18 -5
  29. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  30. liger_kernel/transformers/model/qwen3_next.py +17 -5
  31. liger_kernel/transformers/model/qwen3_vl.py +11 -5
  32. liger_kernel/transformers/model/qwen3_vl_moe.py +12 -5
  33. liger_kernel/transformers/model/smollm3.py +15 -6
  34. {liger_kernel_nightly-0.6.3.dev20251105224413.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/METADATA +1 -1
  35. {liger_kernel_nightly-0.6.3.dev20251105224413.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/RECORD +39 -38
  36. {liger_kernel_nightly-0.6.3.dev20251105224413.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/LICENSE +0 -0
  37. {liger_kernel_nightly-0.6.3.dev20251105224413.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/NOTICE +0 -0
  38. {liger_kernel_nightly-0.6.3.dev20251105224413.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/WHEEL +0 -0
  39. {liger_kernel_nightly-0.6.3.dev20251105224413.dist-info → liger_kernel_nightly-0.6.3.dev20251105235313.dist-info}/top_level.txt +0 -0

liger_kernel/ops/cross_entropy.py
@@ -32,6 +32,8 @@ def liger_cross_entropy_kernel(
  loss_ptr,
  z_loss_ptr,
  loss_stride,
+ token_accuracy_ptr,
+ token_accuracy_stride,
  n_cols,
  n_non_ignore,
  sum_non_ignore_weight,
@@ -42,6 +44,7 @@ def liger_cross_entropy_kernel(
  reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
  softcap,
  RETURN_Z_LOSS: tl.constexpr,
+ RETURN_TOKEN_ACCURACY: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  HAS_WEIGHT: tl.constexpr,
  HAS_SOFTCAPPING: tl.constexpr,
@@ -60,6 +63,8 @@ def liger_cross_entropy_kernel(
  loss_ptr: Pointer to tensor to store the loss.
  z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
  loss_stride (int): The stride of the loss tensor.
+ token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+ token_accuracy_stride (int): The stride of the token accuracy tensor.
  n_cols (int): The number of columns in the input tensor.
  n_non_ignore (float): The number of non-ignored elements in the batch.
  sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -69,7 +74,8 @@ def liger_cross_entropy_kernel(
  lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
  reduction (str): The string for the reduction to apply
  softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
- RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
+ RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+ RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
  BLOCK_SIZE (int): The block size for Triton operations.
  HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -92,11 +98,17 @@ def liger_cross_entropy_kernel(
  for i in range(0, n_cols, BLOCK_SIZE):
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
  tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+ # For ignored tokens, set token accuracy to 0
+ if RETURN_TOKEN_ACCURACY:
+ token_accuracy_ptr += program_id * token_accuracy_stride
+ tl.store(token_accuracy_ptr, 0.0)
  return

  loss_ptr += program_id * loss_stride
  if RETURN_Z_LOSS:
  z_loss_ptr += program_id * loss_stride
+ if RETURN_TOKEN_ACCURACY:
+ token_accuracy_ptr += program_id * token_accuracy_stride

  if HAS_WEIGHT:
  weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -107,6 +119,7 @@ def liger_cross_entropy_kernel(
  # 3. [Online softmax] first pass: find max + sum
  m = float("-inf") # m is the max value. use the notation from the paper
  d = 0.0 # d is the sum. use the notation from the paper
+ argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
  ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
  if HAS_SOFTCAPPING:
  ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -127,6 +140,16 @@ def liger_cross_entropy_kernel(
  if HAS_SOFTCAPPING:
  X_block = softcap * tanh(X_block / softcap)
  block_max = tl.max(X_block)
+
+ # Track argmax for accuracy computation
+ if RETURN_TOKEN_ACCURACY and block_max > m:
+ # Find the index of the maximum value in this block
+ is_max_mask = X_block == block_max
+ # Mask out invalid indices with a value larger than n_cols
+ masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+ # Get the first (smallest) index where max occurs
+ argmax_idx = tl.min(masked_offsets)
+
  if label_smoothing > 0:
  # scale X beforehand to avoid overflow
  if HAS_WEIGHT:
@@ -256,6 +279,10 @@ def liger_cross_entropy_kernel(
  tl.store(loss_ptr, loss)
  if RETURN_Z_LOSS:
  tl.store(z_loss_ptr, z_loss)
+ if RETURN_TOKEN_ACCURACY:
+ # Store 1.0 if prediction is correct, 0.0 otherwise
+ is_correct = 1.0 if argmax_idx == y else 0.0
+ tl.store(token_accuracy_ptr, is_correct)


  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -274,8 +301,12 @@ def cross_entropy_forward(
  reduction,
  softcap,
  return_z_loss,
+ return_token_accuracy=False,
  ):
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+ assert isinstance(return_token_accuracy, bool), (
+ f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+ )

  BT, V = _input.shape
  n_rows = BT
@@ -285,6 +316,9 @@ def cross_entropy_forward(
  # unreduced loss
  loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
  z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+ token_accuracy_1d = (
+ torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+ )

  target_mask = target != ignore_index
  n_non_ignore = target_mask.sum().item()
@@ -321,6 +355,10 @@
  loss_ptr=loss_1d,
  z_loss_ptr=z_loss_1d,
  loss_stride=loss_1d.stride(-1), # always 1
+ token_accuracy_ptr=token_accuracy_1d,
+ token_accuracy_stride=token_accuracy_1d.stride(-1)
+ if return_token_accuracy
+ else 0, # always 1 if accuracy is enabled
  n_cols=V,
  n_non_ignore=n_non_ignore,
  sum_non_ignore_weight=sum_non_ignore_weight,
@@ -331,6 +369,7 @@
  reduction=reduction,
  softcap=softcap,
  RETURN_Z_LOSS=return_z_loss,
+ RETURN_TOKEN_ACCURACY=return_token_accuracy,
  BLOCK_SIZE=BLOCK_SIZE,
  HAS_WEIGHT=True if weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -343,11 +382,14 @@
  if reduction == "none":
  loss = loss_1d
  z_loss = z_loss_1d if return_z_loss else None
+ token_accuracy = token_accuracy_1d if return_token_accuracy else None
  else:
  loss = torch.sum(loss_1d)
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+ # For accuracy, we compute the mean across all non-ignored tokens
+ token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

- return loss, z_loss, _input
+ return loss, z_loss, token_accuracy, _input


  def cross_entropy_backward(_input, grad_output):
@@ -395,6 +437,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  reduction: str = "mean",
  softcap: Optional[float] = None,
  return_z_loss: bool = False,
+ return_token_accuracy: bool = False,
  ):
  """
  The forward pass of the Liger Cross Entropy loss.
@@ -409,14 +452,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
  reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
  softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
- return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+ return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+ return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

  Returns:
- tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
+ tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
  """
  input_requires_grad = _input.requires_grad

- loss, z_loss, _input = cross_entropy_forward(
+ loss, z_loss, token_accuracy, _input = cross_entropy_forward(
  _input,
  target,
  weight,
@@ -426,6 +470,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  reduction,
  softcap,
  return_z_loss,
+ return_token_accuracy,
  )
  # TODO: investigation
  # If we don't detach the _input tensor, the memory will double
@@ -433,23 +478,27 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  if input_requires_grad:
  ctx.save_for_backward(_input.detach())
  ctx.return_z_loss = return_z_loss
+ ctx.return_token_accuracy = return_token_accuracy

- return loss, z_loss
+ return loss, z_loss, token_accuracy

  @staticmethod
- def backward(ctx, grad_output, grad_ouput2):
+ def backward(ctx, grad_output, grad_output2, grad_output3):
  """
  The backward pass of the Liger Cross Entropy loss.

  Parameters:
  ctx : The context object with saved tensors.
  grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
- grad_output2 (tenosr): No use.
+ grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+ grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
  Returns:
  tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
  """
  if ctx.return_z_loss:
- del grad_ouput2 # z_loss is only for logging
+ del grad_output2 # z_loss is only for logging
+ if ctx.return_token_accuracy:
+ del grad_output3 # token_accuracy is only for metrics

  (_input,) = ctx.saved_tensors
  _input = cross_entropy_backward(_input, grad_output)
@@ -463,4 +512,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  None,
  None,
  None,
+ None,
  )
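
Taken together, these hunks add a return_token_accuracy flag to the Triton kernel and thread it through cross_entropy_forward and LigerCrossEntropyFunction, whose apply() now yields a third element. Below is a minimal sketch of how the new interface could be exercised; the positional argument order follows my reading of the forward() signature implied by this diff, and the shapes and CUDA device are illustrative assumptions, not part of the release notes.

    import torch

    from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

    # Illustrative shapes: 8 tokens over a 128-class vocabulary (CUDA assumed,
    # since the kernel is Triton-based).
    logits = torch.randn(8, 128, device="cuda", requires_grad=True)
    targets = torch.randint(0, 128, (8,), device="cuda")

    # Assumed positional order: (_input, target, weight, ignore_index,
    # lse_square_scale, label_smoothing, reduction, softcap,
    # return_z_loss, return_token_accuracy).
    loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
        logits, targets, None, -100, 0.0, 0.0, "mean", None, False, True
    )

    loss.backward()  # gradients flow through the loss only
    # The kernel's token accuracy should agree with an eager argmax check:
    ref_acc = (logits.detach().argmax(dim=-1) == targets).float().mean()
    print(token_accuracy.item(), ref_acc.item())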

liger_kernel/ops/fused_linear_cross_entropy.py
@@ -27,8 +27,12 @@ def fused_linear_cross_entropy_forward(
  return_z_loss=False,
  accum_dtype=None,
  use_token_scaling=False,
+ return_token_accuracy=False,
  ):
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+ assert isinstance(return_token_accuracy, bool), (
+ f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+ )
  device = _input.device

  input_requires_grad = _input.requires_grad
@@ -64,6 +68,7 @@ def fused_linear_cross_entropy_forward(

  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
  z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+ token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None

  # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
  target_mask = target != ignore_index
@@ -129,6 +134,7 @@ def fused_linear_cross_entropy_forward(
  # unreduced loss
  loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
  z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+ token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None

  # ensure _input and target are contiguous
  logits_chunk = logits_chunk.contiguous()
@@ -144,6 +150,10 @@ def fused_linear_cross_entropy_forward(
  loss_ptr=loss_1d_slice,
  z_loss_ptr=z_loss_1d_slice,
  loss_stride=loss_1d_slice.stride(-1), # always 1
+ token_accuracy_ptr=token_accuracy_1d_slice,
+ token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+ if return_token_accuracy
+ else 0, # always 1 if accuracy is enabled
  n_cols=V,
  n_non_ignore=total_n_non_ignore,
  sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -154,6 +164,7 @@ def fused_linear_cross_entropy_forward(
  reduction=reduction,
  softcap=softcap,
  RETURN_Z_LOSS=return_z_loss,
+ RETURN_TOKEN_ACCURACY=return_token_accuracy,
  HAS_WEIGHT=True if ce_weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
  HAS_GRADIENTS=input_requires_grad,
@@ -170,6 +181,8 @@ def fused_linear_cross_entropy_forward(
  loss_1d[start_idx:end_idx] = loss_1d_slice
  if return_z_loss:
  z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+ if return_token_accuracy:
+ token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
  grad_logits_chunk = logits_chunk # chunk_size x V

  # Apply token scaling to gradients if requested
@@ -201,15 +214,18 @@ def fused_linear_cross_entropy_forward(
  # Return per-token losses
  loss = loss_1d
  z_loss = z_loss_1d if return_z_loss else None
+ token_accuracy = token_accuracy_1d if return_token_accuracy else None
  else:
  loss = torch.sum(loss_1d)
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+ # For accuracy, we compute the mean across all non-ignored tokens
+ token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None

  # Cast back to original dtype
  grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
  grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None

- return loss, z_loss, grad_input, grad_weight, grad_bias
+ return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias


  def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -277,6 +293,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  return_z_loss: bool = False,
  accum_dtype=None,
  use_token_scaling: bool = False,
+ return_token_accuracy: bool = False,
  ):
  """
  Fusing the last linear layer with cross-entropy loss
@@ -300,9 +317,10 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
  When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
  Default: False.
+ return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
  """

- loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+ loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
  _input=_input,
  weight=weight,
  target=target,
@@ -316,6 +334,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  return_z_loss=return_z_loss,
  accum_dtype=accum_dtype,
  use_token_scaling=use_token_scaling,
+ return_token_accuracy=return_token_accuracy,
  )
  # downcast to dtype and store for backward
  ctx.save_for_backward(
@@ -324,13 +343,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  grad_bias.detach() if bias is not None else None,
  )
  ctx.return_z_loss = return_z_loss
- return loss, z_loss
+ ctx.return_token_accuracy = return_token_accuracy
+ return loss, z_loss, token_accuracy

  @staticmethod
  @amp_custom_bwd
- def backward(ctx, grad_output, grad_output2):
+ def backward(ctx, grad_output, grad_output2, grad_output3):
  if ctx.return_z_loss:
  del grad_output2 # z_loss is only for logging
+ if ctx.return_token_accuracy:
+ del grad_output3 # token_accuracy is only for metrics
  (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
  grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
  grad_output, grad_input, grad_weight, grad_bias
@@ -349,4 +371,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  None,
  None,
  None, # use_token_scaling
+ None, # return_token_accuracy
  )
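
The fused-linear path threads the same flag through each chunked kernel launch. Because forward now returns three outputs (loss, z_loss, token_accuracy), backward must accept three incoming gradients and still return one entry per forward argument, which is why an extra None is appended above. A toy sketch of that autograd contract, independent of Liger internals and using only standard PyTorch:

    import torch

    class LossWithMetric(torch.autograd.Function):
        # Mirrors the contract used above: forward returns a differentiable
        # loss plus a non-differentiable metric; backward receives one
        # gradient per output and returns one entry per forward input.
        @staticmethod
        def forward(ctx, x, return_metric):
            ctx.save_for_backward(x)
            loss = (x ** 2).mean()
            metric = (x > 0).float().mean().detach() if return_metric else None
            return loss, metric

        @staticmethod
        def backward(ctx, grad_loss, grad_metric):
            # grad_metric is unused, just like grad_output3 for token_accuracy.
            (x,) = ctx.saved_tensors
            return grad_loss * 2 * x / x.numel(), None  # one entry per input

    x = torch.randn(4, requires_grad=True)
    loss, metric = LossWithMetric.apply(x, True)
    loss.backward()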

liger_kernel/transformers/cross_entropy.py
@@ -3,6 +3,7 @@ from typing import Optional
  import torch

  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+ from liger_kernel.transformers.functional import CrossEntropyOutput


  class LigerCrossEntropyLoss(torch.nn.Module):
@@ -15,6 +16,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
  reduction: str = "mean",
  softcap: Optional[float] = None,
  return_z_loss: bool = False,
+ return_token_accuracy: bool = False,
  ):
  super().__init__()
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -33,9 +35,10 @@ class LigerCrossEntropyLoss(torch.nn.Module):
  self.reduction = reduction
  self.softcap = softcap
  self.return_z_loss = return_z_loss
+ self.return_token_accuracy = return_token_accuracy

  def forward(self, _input: torch.Tensor, target: torch.Tensor):
- loss, z_loss = LigerCrossEntropyFunction.apply(
+ loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
  _input,
  target,
  self.weight,
@@ -45,7 +48,9 @@ class LigerCrossEntropyLoss(torch.nn.Module):
  self.reduction,
  self.softcap,
  self.return_z_loss,
+ self.return_token_accuracy,
  )
- if not self.return_z_loss:
+ if not self.return_z_loss and not self.return_token_accuracy:
  return loss
- return loss, z_loss
+
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
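
At the module level, LigerCrossEntropyLoss keeps returning a bare loss tensor when neither flag is set, and otherwise returns the new CrossEntropyOutput dataclass. A minimal usage sketch based on the interface shown in this diff; the CUDA device and tensor sizes are illustrative assumptions:

    import torch

    from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

    ce = LigerCrossEntropyLoss(return_token_accuracy=True)

    logits = torch.randn(16, 32000, device="cuda", requires_grad=True)
    targets = torch.randint(0, 32000, (16,), device="cuda")

    out = ce(logits, targets)         # CrossEntropyOutput, not a bare tensor
    out.loss.backward()
    print(out.token_accuracy.item())  # mean accuracy over non-ignored tokens

    # With the default flags, forward still returns a plain loss tensor,
    # so existing call sites are unchanged.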

liger_kernel/transformers/functional.py
@@ -1,5 +1,8 @@
+ from dataclasses import dataclass
  from typing import Optional

+ import torch
+
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
  from liger_kernel.ops.dyt import LigerDyTFunction
  from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
@@ -22,6 +25,13 @@ from liger_kernel.ops.swiglu import LigerSiLUMulFunction
  from liger_kernel.ops.tvd import LigerTVDLossFunction


+ @dataclass
+ class CrossEntropyOutput:
+ loss: torch.Tensor
+ z_loss: Optional[torch.Tensor] = None
+ token_accuracy: Optional[torch.Tensor] = None
+
+
  # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
  # `weight` and `size_average` are placeholders and not implemented yet
  def liger_cross_entropy(
@@ -36,8 +46,9 @@ def liger_cross_entropy(
  lse_square_scale: float = 0.0,
  softcap: Optional[float] = None,
  return_z_loss: bool = False,
+ return_token_accuracy: bool = False,
  ):
- loss, z_loss = LigerCrossEntropyFunction.apply(
+ loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
  input,
  target,
  weight,
@@ -47,10 +58,13 @@ def liger_cross_entropy(
  reduction,
  softcap,
  return_z_loss,
+ return_token_accuracy,
  )
- if not return_z_loss:
+
+ if not return_z_loss and not return_token_accuracy:
  return loss
- return loss, z_loss
+
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)


  def liger_fused_linear_cross_entropy(
@@ -67,8 +81,9 @@ def liger_fused_linear_cross_entropy(
  return_z_loss: bool = False,
  accum_dtype=None,
  use_token_scaling: bool = False,
+ return_token_accuracy: bool = False,
  ):
- loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
+ loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
  input,
  weight,
  target,
@@ -82,10 +97,13 @@ def liger_fused_linear_cross_entropy(
  return_z_loss,
  accum_dtype,
  use_token_scaling,
+ return_token_accuracy,
  )
- if not return_z_loss:
+
+ if not return_z_loss and not return_token_accuracy:
  return loss
- return loss, z_loss
+
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)


  def liger_fused_linear_jsd(

liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -3,6 +3,7 @@ from typing import Optional
  import torch

  from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
+ from liger_kernel.transformers.functional import CrossEntropyOutput


  class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
@@ -17,6 +18,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
  return_z_loss: bool = False,
  accum_dtype: Optional[torch.dtype] = None,
  use_token_scaling: bool = False,
+ return_token_accuracy: bool = False,
  ):
  super().__init__()
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -37,9 +39,10 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
  self.return_z_loss = return_z_loss
  self.accum_dtype = accum_dtype
  self.use_token_scaling = use_token_scaling
+ self.return_token_accuracy = return_token_accuracy

  def forward(self, lin_weight, _input, target, bias=None):
- loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
+ loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
  _input,
  lin_weight,
  target,
@@ -53,7 +56,9 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
  self.return_z_loss,
  self.accum_dtype,
  self.use_token_scaling,
+ self.return_token_accuracy,
  )
- if not self.return_z_loss:
+ if not self.return_z_loss and not self.return_token_accuracy:
  return loss
- return loss, z_loss
+
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
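
LigerFusedLinearCrossEntropyLoss follows the same pattern but fuses the lm_head projection into the loss, so the full (tokens, vocab) logits are never materialized even while accuracy is computed. A hedged sketch with illustrative sizes, following the forward(lin_weight, _input, target, bias=None) signature shown above:

    import torch

    from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

    flce = LigerFusedLinearCrossEntropyLoss(return_token_accuracy=True)

    # Illustrative sizes: hidden size 4096, vocab 32000, 16 flattened tokens.
    hidden = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    lm_head_weight = torch.randn(32000, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    targets = torch.randint(0, 32000, (16,), device="cuda")

    out = flce(lm_head_weight, hidden, targets)
    out.loss.backward()               # the (16, 32000) logits never exist in memory
    print(out.token_accuracy.item())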

liger_kernel/transformers/model/falcon_h1.py
@@ -4,12 +4,12 @@ from typing import Union

  import torch

- from transformers.modeling_outputs import CausalLMOutputWithPast
-
  if TYPE_CHECKING:
  from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


  def lce_forward(
@@ -26,8 +26,9 @@ def lce_forward(
  cache_position: Optional[torch.LongTensor] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
  skip_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
  **kwargs,
- ) -> Union[tuple, CausalLMOutputWithPast]:
+ ) -> Union[tuple, LigerCausalLMOutputWithPast]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -54,6 +55,7 @@ def lce_forward(
  output_hidden_states = (
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict

  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  outputs = self.model(
@@ -77,6 +79,8 @@ def lce_forward(
  shift_labels = kwargs.pop("shift_labels", None)
  logits = None
  loss = None
+ token_accuracy = None
+
  # if in training mode, don't materialize logits
  if skip_logits and labels is None:
  raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -85,8 +89,9 @@ def lce_forward(
  # By default, if in training mode, don't materialize logits
  skip_logits = self.training and labels is not None

+ # Compute loss
  if skip_logits:
- loss = LigerForCausalLMLoss(
+ result = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
@@ -94,15 +99,24 @@ def lce_forward(
  hidden_size=self.config.hidden_size,
  **kwargs,
  )
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
  else:
  logits = self.lm_head(kept_hidden_states)
  if labels is not None or shift_labels is not None:
  loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

- return CausalLMOutputWithPast(
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ output = ((loss,) + output) if loss is not None else output
+ output = output + (token_accuracy,) if token_accuracy is not None else output
+ return output
+
+ # Return custom output class with token_accuracy field
+ return LigerCausalLMOutputWithPast(
  loss=loss,
  logits=logits,
  past_key_values=outputs.past_key_values,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
+ token_accuracy=token_accuracy,
  )

liger_kernel/transformers/model/gemma.py
@@ -12,6 +12,8 @@ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


  def lce_forward_deprecated(
@@ -147,7 +149,7 @@ def lce_forward(
  logits_to_keep: Union[int, torch.Tensor] = 0,
  skip_logits: Optional[bool] = None,
  **kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
  r"""
  Args:
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -209,6 +211,7 @@ def lce_forward(
  shift_labels = kwargs.pop("shift_labels", None)
  logits = None
  loss = None
+ token_accuracy = None

  if skip_logits and labels is None and shift_labels is None:
  raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -217,8 +220,9 @@ def lce_forward(
  # By default, if in training mode, don't materialize logits
  skip_logits = self.training and (labels is not None or shift_labels is not None)

+ # Compute loss
  if skip_logits:
- loss = LigerForCausalLMLoss(
+ result = LigerForCausalLMLoss(
  hidden_states=kept_hidden_states,
  lm_head_weight=self.lm_head.weight,
  labels=labels,
@@ -226,6 +230,7 @@ def lce_forward(
  hidden_size=self.config.hidden_size,
  **kwargs,
  )
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
  else:
  logits = self.lm_head(kept_hidden_states)
  if labels is not None or shift_labels is not None:
@@ -238,13 +243,19 @@ def lce_forward(
  )

  if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
+ output_tuple = (logits,) + outputs[1:]
+ if loss is not None:
+ output_tuple = (loss,) + output_tuple
+ if token_accuracy is not None:
+ output_tuple = output_tuple + (token_accuracy,)
+ return output_tuple
+
+ # Return custom output class with token_accuracy field
+ return LigerCausalLMOutputWithPast(
  loss=loss,
  logits=logits,
  past_key_values=outputs.past_key_values,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
+ token_accuracy=token_accuracy,
  )
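
The remaining model files in the list above (glm4, llama, mistral, qwen, and so on) follow the same recipe: unpack whatever LigerForCausalLMLoss returns via unpack_cross_entropy_result and surface it as a token_accuracy field on LigerCausalLMOutputWithPast. Below is a hedged sketch of reading that field from a patched Gemma model; whether token_accuracy is actually populated depends on how loss_utils enables the flag, which is outside the hunks shown here, and the model name and sizes are only illustrative.

    import torch

    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers import apply_liger_kernel_to_gemma

    apply_liger_kernel_to_gemma()  # patches GemmaForCausalLM.forward with lce_forward
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b").cuda().train()

    input_ids = torch.randint(0, model.config.vocab_size, (1, 128), device="cuda")
    out = model(input_ids=input_ids, labels=input_ids)  # LigerCausalLMOutputWithPast

    out.loss.backward()
    if out.token_accuracy is not None:
        print("token accuracy:", out.token_accuracy.item())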