liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +44 -13
  9. liger_kernel/ops/__init__.py +141 -0
  10. liger_kernel/ops/backends/README.md +151 -0
  11. liger_kernel/ops/backends/__init__.py +13 -0
  12. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  13. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  15. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  16. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  17. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  18. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  19. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  20. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  21. liger_kernel/ops/backends/registry.py +61 -0
  22. liger_kernel/ops/cross_entropy.py +130 -64
  23. liger_kernel/ops/dyt.py +5 -4
  24. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  25. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  26. liger_kernel/ops/geglu.py +6 -4
  27. liger_kernel/ops/group_norm.py +7 -7
  28. liger_kernel/ops/grpo_loss.py +3 -1
  29. liger_kernel/ops/kl_div.py +8 -11
  30. liger_kernel/ops/layer_norm.py +135 -80
  31. liger_kernel/ops/llama4_rope.py +225 -0
  32. liger_kernel/ops/poly_norm.py +390 -0
  33. liger_kernel/ops/rms_norm.py +148 -71
  34. liger_kernel/ops/rope.py +1 -1
  35. liger_kernel/ops/swiglu.py +1 -1
  36. liger_kernel/ops/tiled_mlp.py +136 -0
  37. liger_kernel/ops/utils.py +14 -0
  38. liger_kernel/transformers/__init__.py +65 -0
  39. liger_kernel/transformers/auto_model.py +21 -0
  40. liger_kernel/transformers/cross_entropy.py +9 -4
  41. liger_kernel/transformers/dyt.py +1 -1
  42. liger_kernel/transformers/experimental/__init__.py +5 -0
  43. liger_kernel/transformers/experimental/embedding.py +1 -1
  44. liger_kernel/transformers/functional.py +56 -24
  45. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  46. liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
  47. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  48. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  49. liger_kernel/transformers/geglu.py +1 -1
  50. liger_kernel/transformers/group_norm.py +1 -1
  51. liger_kernel/transformers/grpo_loss.py +57 -2
  52. liger_kernel/transformers/jsd.py +1 -1
  53. liger_kernel/transformers/kl_div.py +1 -1
  54. liger_kernel/transformers/layer_norm.py +1 -1
  55. liger_kernel/transformers/llama4_rope.py +93 -0
  56. liger_kernel/transformers/model/exaone4.py +136 -0
  57. liger_kernel/transformers/model/falcon_h1.py +122 -0
  58. liger_kernel/transformers/model/gemma.py +28 -8
  59. liger_kernel/transformers/model/gemma2.py +34 -11
  60. liger_kernel/transformers/model/gemma3.py +102 -112
  61. liger_kernel/transformers/model/glm4.py +18 -5
  62. liger_kernel/transformers/model/glm4v.py +163 -0
  63. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  64. liger_kernel/transformers/model/gpt_oss.py +211 -0
  65. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  66. liger_kernel/transformers/model/internvl.py +157 -0
  67. liger_kernel/transformers/model/llama.py +26 -7
  68. liger_kernel/transformers/model/llama4.py +121 -0
  69. liger_kernel/transformers/model/llava.py +18 -6
  70. liger_kernel/transformers/model/loss_utils.py +34 -3
  71. liger_kernel/transformers/model/mistral.py +17 -10
  72. liger_kernel/transformers/model/mixtral.py +24 -9
  73. liger_kernel/transformers/model/mllama.py +18 -7
  74. liger_kernel/transformers/model/olmo2.py +18 -5
  75. liger_kernel/transformers/model/olmo3.py +142 -0
  76. liger_kernel/transformers/model/output_classes.py +147 -0
  77. liger_kernel/transformers/model/paligemma.py +42 -5
  78. liger_kernel/transformers/model/phi3.py +24 -159
  79. liger_kernel/transformers/model/qwen2.py +26 -4
  80. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  81. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  82. liger_kernel/transformers/model/qwen3.py +22 -6
  83. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  84. liger_kernel/transformers/model/qwen3_next.py +146 -0
  85. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  86. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  87. liger_kernel/transformers/model/smollm3.py +199 -0
  88. liger_kernel/transformers/model/smolvlm.py +158 -0
  89. liger_kernel/transformers/monkey_patch.py +1423 -100
  90. liger_kernel/transformers/multi_token_attention.py +2 -2
  91. liger_kernel/transformers/poly_norm.py +42 -0
  92. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  93. liger_kernel/transformers/rms_norm.py +15 -5
  94. liger_kernel/transformers/rope.py +45 -1
  95. liger_kernel/transformers/softmax.py +1 -1
  96. liger_kernel/transformers/sparsemax.py +1 -1
  97. liger_kernel/transformers/swiglu.py +18 -1
  98. liger_kernel/transformers/tiled_mlp.py +125 -0
  99. liger_kernel/transformers/tvd.py +1 -1
  100. liger_kernel/utils.py +52 -0
  101. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
  102. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  103. liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
  104. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py CHANGED
@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import element_mul_kernel
  from liger_kernel.ops.utils import is_hip
  from liger_kernel.utils import infer_device
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import tanh
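The new `is_npu_available()` guard keeps the `triton.language.extra.libdevice` import off Ascend NPU builds, where that extra module is presumably unavailable in the NPU Triton port. The helper itself lives in `liger_kernel/utils.py`, which this diff excerpt does not show; a minimal sketch of what such a check could look like, assuming the `torch_npu` Ascend adapter:

```python
# Hypothetical sketch only -- the real is_npu_available() is defined in
# liger_kernel/utils.py, which is not included in this excerpt.
def is_npu_available() -> bool:
    try:
        import torch
        import torch_npu  # noqa: F401  # Ascend adapter; patches torch with torch.npu
    except ImportError:
        return False
    return hasattr(torch, "npu") and torch.npu.is_available()
```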
@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
  loss_ptr,
  z_loss_ptr,
  loss_stride,
+ token_accuracy_ptr,
+ token_accuracy_stride,
  n_cols,
  n_non_ignore,
  sum_non_ignore_weight,
@@ -42,9 +45,11 @@ def liger_cross_entropy_kernel(
  reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
  softcap,
  RETURN_Z_LOSS: tl.constexpr,
+ RETURN_TOKEN_ACCURACY: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  HAS_WEIGHT: tl.constexpr,
  HAS_SOFTCAPPING: tl.constexpr,
+ HAS_GRADIENTS: tl.constexpr,
  ):
  """
  This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +64,8 @@ def liger_cross_entropy_kernel(
  loss_ptr: Pointer to tensor to store the loss.
  z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
  loss_stride (int): The stride of the loss tensor.
+ token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+ token_accuracy_stride (int): The stride of the token accuracy tensor.
  n_cols (int): The number of columns in the input tensor.
  n_non_ignore (float): The number of non-ignored elements in the batch.
  sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +75,12 @@ def liger_cross_entropy_kernel(
  lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
  reduction (str): The string for the reduction to apply
  softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
- RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
+ RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+ RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
  BLOCK_SIZE (int): The block size for Triton operations.
  HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+ HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
  """

  # https://github.com/triton-lang/triton/issues/1058
@@ -90,11 +99,17 @@ def liger_cross_entropy_kernel(
  for i in range(0, n_cols, BLOCK_SIZE):
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
  tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+ # For ignored tokens, set token accuracy to 0
+ if RETURN_TOKEN_ACCURACY:
+ token_accuracy_ptr += program_id * token_accuracy_stride
+ tl.store(token_accuracy_ptr, 0.0)
  return

  loss_ptr += program_id * loss_stride
  if RETURN_Z_LOSS:
  z_loss_ptr += program_id * loss_stride
+ if RETURN_TOKEN_ACCURACY:
+ token_accuracy_ptr += program_id * token_accuracy_stride

  if HAS_WEIGHT:
  weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +120,7 @@ def liger_cross_entropy_kernel(
  # 3. [Online softmax] first pass: find max + sum
  m = float("-inf") # m is the max value. use the notation from the paper
  d = 0.0 # d is the sum. use the notation from the paper
+ argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
  ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
  if HAS_SOFTCAPPING:
  ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +141,19 @@ def liger_cross_entropy_kernel(
  if HAS_SOFTCAPPING:
  X_block = softcap * tanh(X_block / softcap)
  block_max = tl.max(X_block)
+
+ # Track argmax for accuracy computation
+ if RETURN_TOKEN_ACCURACY:
+ # Find the index of the maximum value in this block
+ is_max_mask = X_block == block_max
+ # Mask out invalid indices with a value larger than n_cols
+ masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+ # Get the first (smallest) index where max occurs
+ current_block_argmax_idx = tl.min(masked_offsets)
+
+ is_new_max = block_max > m
+ argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
+
  if label_smoothing > 0:
  # scale X beforehand to avoid overflow
  if HAS_WEIGHT:
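The new `RETURN_TOKEN_ACCURACY` branch piggybacks on the online-softmax pass: while scanning the logits block by block it records the first index at which the running maximum was attained, so the argmax never has to be recomputed from materialized logits. A plain-Python sketch of that bookkeeping (illustrative only; softcapping and masking are omitted):

```python
# Illustrative sketch of the blockwise argmax tracking added above: stream over
# the logits in blocks and remember the first index of the running maximum,
# mirroring what argmax_idx / is_new_max do in the kernel.
def blockwise_argmax(logits, block_size):
    m = float("-inf")  # running max, same role as `m` in the kernel
    argmax_idx = 0     # index where the running max first occurred
    for start in range(0, len(logits), block_size):
        block = logits[start : start + block_size]
        block_max = max(block)
        if block_max > m:  # strictly greater keeps the earliest occurrence
            m = block_max
            argmax_idx = start + block.index(block_max)
    return argmax_idx

# The per-token value the kernel later stores is 1.0 if this argmax equals the
# target index y, and 0.0 otherwise.
```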
@@ -155,58 +184,58 @@ def liger_cross_entropy_kernel(
  # For 'sum' reduction, no normalization is applied:
  # dx_y = softmax(x_y) - 1
  # dx_i = softmax(x_i), for i ≠ y
-
- for i in range(0, n_cols, BLOCK_SIZE):
- X_offsets = i + tl.arange(0, BLOCK_SIZE)
- X_block = tl.load(
- X_ptr + X_offsets,
- mask=X_offsets < n_cols,
- other=float("-inf"),
- # Ensure float32 precision for softmax calculation
- ).cast(tl.float32)
- if HAS_SOFTCAPPING:
- intermediate = tanh(X_block / softcap)
- X_block = softcap * intermediate
-
- if not HAS_WEIGHT:
- # softmax(x_i)
- X_block = tl.exp(X_block - m) / d
- # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
- X_block += 2 * lse_square_scale * lse * X_block
- # smoothing term
- X_block += -eps
- # special handle dx_y
- X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
- # reduction scale
- if reduction == "mean":
- X_block = X_block / n_non_ignore
- else:
- weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
- softmax_X = tl.exp(X_block - m) / d
- # derivative of original_loss
- dloss_ori = (1 - label_smoothing) * softmax_X
- # specially handle dx_y
- dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
- dloss_ori = dloss_ori * weight_y
- # derivative of smooth_loss
- dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
- # derivative of z-loss
- dz_loss = 2 * lse_square_scale * lse * softmax_X
- # reduction scale
- if reduction == "mean":
- dloss_ori = dloss_ori / sum_non_ignore_weight
- dloss_smooth = dloss_smooth / sum_non_ignore_weight
- # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
- dz_loss = dz_loss / n_non_ignore
- # derivative of total_loss
- X_block = dloss_ori + dloss_smooth + dz_loss
-
- # chain rule softcapping
- # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
- if HAS_SOFTCAPPING:
- X_block = X_block * (1 - intermediate * intermediate)
-
- tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+ if HAS_GRADIENTS:
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(
+ X_ptr + X_offsets,
+ mask=X_offsets < n_cols,
+ other=float("-inf"),
+ # Ensure float32 precision for softmax calculation
+ ).cast(tl.float32)
+ if HAS_SOFTCAPPING:
+ intermediate = tanh(X_block / softcap)
+ X_block = softcap * intermediate
+
+ if not HAS_WEIGHT:
+ # softmax(x_i)
+ X_block = tl.exp(X_block - m) / d
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+ X_block += 2 * lse_square_scale * lse * X_block
+ # smoothing term
+ X_block += -eps
+ # special handle dx_y
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+ # reduction scale
+ if reduction == "mean":
+ X_block = X_block / n_non_ignore
+ else:
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+ softmax_X = tl.exp(X_block - m) / d
+ # derivative of original_loss
+ dloss_ori = (1 - label_smoothing) * softmax_X
+ # specially handle dx_y
+ dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+ dloss_ori = dloss_ori * weight_y
+ # derivative of smooth_loss
+ dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+ # derivative of z-loss
+ dz_loss = 2 * lse_square_scale * lse * softmax_X
+ # reduction scale
+ if reduction == "mean":
+ dloss_ori = dloss_ori / sum_non_ignore_weight
+ dloss_smooth = dloss_smooth / sum_non_ignore_weight
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+ dz_loss = dz_loss / n_non_ignore
+ # derivative of total_loss
+ X_block = dloss_ori + dloss_smooth + dz_loss
+
+ # chain rule softcapping
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+ if HAS_SOFTCAPPING:
+ X_block = X_block * (1 - intermediate * intermediate)
+
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

  # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
  # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -254,12 +283,22 @@ def liger_cross_entropy_kernel(
  tl.store(loss_ptr, loss)
  if RETURN_Z_LOSS:
  tl.store(z_loss_ptr, z_loss)
+ if RETURN_TOKEN_ACCURACY:
+ # Store 1.0 if prediction is correct, 0.0 otherwise
+ is_correct = 1.0 if argmax_idx == y else 0.0
+ tl.store(token_accuracy_ptr, is_correct)


  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
- MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
+ # the best size we found by manually tuning on xpu and npu.
+ if infer_device() == "xpu":
+ MAX_FUSED_SIZE = 4096
+ elif infer_device() == "npu":
+ MAX_FUSED_SIZE = 2048
+ else:
+ MAX_FUSED_SIZE = 65536 // 2


  def cross_entropy_forward(
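`MAX_FUSED_SIZE` only caps the block size; the per-launch block size is derived from the vocabulary width inside `cross_entropy_forward`, outside this hunk. A sketch of the usual selection, assuming the common `next_power_of_2` pattern:

```python
import triton

# Sketch of how the MAX_FUSED_SIZE cap is typically consumed; the concrete
# selection happens in cross_entropy_forward, which this hunk does not show.
MAX_FUSED_SIZE = 65536 // 2  # value chosen above for CUDA-like devices
V = 128256                   # example vocabulary size
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
# -> 32768 here: the power-of-2 ceiling of V, clamped by the cap
```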
@@ -272,8 +311,12 @@ def cross_entropy_forward(
  reduction,
  softcap,
  return_z_loss,
+ return_token_accuracy=False,
  ):
  assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+ assert isinstance(return_token_accuracy, bool), (
+ f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+ )

  BT, V = _input.shape
  n_rows = BT
@@ -283,6 +326,9 @@ def cross_entropy_forward(
  # unreduced loss
  loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
  z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+ token_accuracy_1d = (
+ torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+ )

  target_mask = target != ignore_index
  n_non_ignore = target_mask.sum().item()
@@ -319,6 +365,10 @@ def cross_entropy_forward(
  loss_ptr=loss_1d,
  z_loss_ptr=z_loss_1d,
  loss_stride=loss_1d.stride(-1), # always 1
+ token_accuracy_ptr=token_accuracy_1d,
+ token_accuracy_stride=token_accuracy_1d.stride(-1)
+ if return_token_accuracy
+ else 0, # always 1 if accuracy is enabled
  n_cols=V,
  n_non_ignore=n_non_ignore,
  sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +379,11 @@ def cross_entropy_forward(
  reduction=reduction,
  softcap=softcap,
  RETURN_Z_LOSS=return_z_loss,
+ RETURN_TOKEN_ACCURACY=return_token_accuracy,
  BLOCK_SIZE=BLOCK_SIZE,
  HAS_WEIGHT=True if weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
+ HAS_GRADIENTS=_input.requires_grad,
  # TODO: 32 seems to give the best performance
  # Performance is quite sensitive to num_warps
  num_warps=32 if not is_hip() else 16,
@@ -340,11 +392,14 @@ def cross_entropy_forward(
  if reduction == "none":
  loss = loss_1d
  z_loss = z_loss_1d if return_z_loss else None
+ token_accuracy = token_accuracy_1d if return_token_accuracy else None
  else:
  loss = torch.sum(loss_1d)
  z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+ # For accuracy, we compute the mean across all non-ignored tokens
+ token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

- return loss, z_loss, _input
+ return loss, z_loss, token_accuracy, _input


  def cross_entropy_backward(_input, grad_output):
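For reduced losses, the per-token 0/1 correctness values are averaged over non-ignored tokens only, so the returned `token_accuracy` is a single scalar. A small illustrative example (names mirror the diff above):

```python
import torch

# Illustrative only: 4 non-ignored tokens, 3 of which were predicted correctly.
token_accuracy_1d = torch.tensor([1.0, 0.0, 1.0, 1.0])
n_non_ignore = 4
token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore
print(token_accuracy)  # tensor(0.7500)
```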
@@ -392,6 +447,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  reduction: str = "mean",
  softcap: Optional[float] = None,
  return_z_loss: bool = False,
+ return_token_accuracy: bool = False,
  ):
  """
  The forward pass of the Liger Cross Entropy loss.
@@ -406,12 +462,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
  reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
  softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
- return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+ return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+ return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

  Returns:
- tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
+ tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
  """
- loss, z_loss, _input = cross_entropy_forward(
+ input_requires_grad = _input.requires_grad
+
+ loss, z_loss, token_accuracy, _input = cross_entropy_forward(
  _input,
  target,
  weight,
@@ -421,29 +480,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  reduction,
  softcap,
  return_z_loss,
+ return_token_accuracy,
  )
  # TODO: investigation
  # If we don't detach the _input tensor, the memory will double
  # Not sure why but seems that there will be a time both grad and value exist but in different location
- ctx.save_for_backward(_input.detach())
+ if input_requires_grad:
+ ctx.save_for_backward(_input.detach())
  ctx.return_z_loss = return_z_loss
+ ctx.return_token_accuracy = return_token_accuracy

- return loss, z_loss
+ return loss, z_loss, token_accuracy

  @staticmethod
- def backward(ctx, grad_output, grad_ouput2):
+ def backward(ctx, grad_output, grad_output2, grad_output3):
  """
  The backward pass of the Liger Cross Entropy loss.

  Parameters:
  ctx : The context object with saved tensors.
  grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
- grad_output2 (tenosr): No use.
+ grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+ grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
  Returns:
  tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
  """
  if ctx.return_z_loss:
- del grad_ouput2 # z_loss is only for logging
+ del grad_output2 # z_loss is only for logging
+ if ctx.return_token_accuracy:
+ del grad_output3 # token_accuracy is only for metrics

  (_input,) = ctx.saved_tensors
  _input = cross_entropy_backward(_input, grad_output)
@@ -457,4 +522,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  None,
  None,
  None,
+ None,
  )
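With the extra output and the widened backward signature, `LigerCrossEntropyFunction` now returns a `(loss, z_loss, token_accuracy)` triple. A hedged usage sketch; the positional argument order follows the forward signature shown above, and shapes and device are illustrative:

```python
import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # (BT, V)
target = torch.randint(0, 32000, (8,), device="cuda")

# Positional args per the forward() signature above: _input, target, weight,
# ignore_index, lse_square_scale, label_smoothing, reduction, softcap,
# return_z_loss, return_token_accuracy.
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
    logits, target, None, -100, 0.0, 0.0, "mean", None, False, True
)
loss.backward()        # gradients flow only through the loss output
print(token_accuracy)  # fraction of non-ignored tokens whose argmax == target
```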
liger_kernel/ops/dyt.py CHANGED
@@ -4,13 +4,13 @@ import torch
  import triton
  import triton.language as tl

- from triton.language.extra.libdevice import tanh
-
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
  from liger_kernel.ops.utils import infer_device
+ from liger_kernel.utils import get_npu_multi_processor_count
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import tanh
@@ -127,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
  NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
  elif device == "xpu":
  NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-
+ elif device == "npu":
+ NUM_SMS = get_npu_multi_processor_count()
  da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
  dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
  db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
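The added NPU branch is how `liger_dyt_bwd` now sizes its partial-reduction buffers. Pulled out as a standalone helper for clarity (a sketch: the helper name is hypothetical, and `get_npu_multi_processor_count` comes from `liger_kernel/utils.py`, which this diff excerpt does not show):

```python
import torch

from liger_kernel.ops.utils import infer_device
from liger_kernel.utils import get_npu_multi_processor_count


def num_parallel_units(x: torch.Tensor) -> int:
    # Mirrors the NUM_SMS dispatch in liger_dyt_bwd: SMs on CUDA/ROCm,
    # GPU subslices on Intel XPU, and an NPU-specific count on Ascend.
    device = infer_device()
    if device == "cuda":
        return torch.cuda.get_device_properties(x.device).multi_processor_count
    if device == "xpu":
        return torch.xpu.get_device_properties(x.device).gpu_subslice_count
    if device == "npu":
        return get_npu_multi_processor_count()
    raise ValueError(f"unsupported device type: {device}")
```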