liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,25 @@
+ import operator
+
+ from typing import Optional
+
  import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import element_mul_kernel, is_hip
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip
+ from liger_kernel.utils import infer_device

- _TRUE = tl.constexpr(1)
- _FALSE = tl.constexpr(0)
+ if compare_version("triton", operator.ge, "3.0.0"):
+     try:
+         # typical import path with dispatch available
+         from triton.language.extra.libdevice import tanh
+     except ModuleNotFoundError:
+         # for working with NGC containers
+         from triton.language.extra.cuda.libdevice import tanh
+ else:
+     from triton.language.math import tanh


  @triton.jit
@@ -14,17 +28,27 @@ def liger_cross_entropy_kernel(
      X_stride,
      Y_ptr,
      Y_stride,
+     weight_ptr,
      loss_ptr,
      z_loss_ptr,
      loss_stride,
+     token_accuracy_ptr,
+     token_accuracy_stride,
      n_cols,
      n_non_ignore,
+     sum_non_ignore_weight,
+     weight_sum,
      ignore_index,
      lse_square_scale: tl.constexpr,
      label_smoothing: tl.constexpr,
      reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
+     softcap,
      RETURN_Z_LOSS: tl.constexpr,
+     RETURN_TOKEN_ACCURACY: tl.constexpr,
      BLOCK_SIZE: tl.constexpr,
+     HAS_WEIGHT: tl.constexpr,
+     HAS_SOFTCAPPING: tl.constexpr,
+     HAS_GRADIENTS: tl.constexpr,
  ):
      """
      This kernel computes both cross entropy loss and the gradient of the input.
@@ -35,17 +59,27 @@ def liger_cross_entropy_kernel(
      X_stride (int): The stride of the input tensor.
      Y_ptr: Pointer to target tensor.
      Y_stride (int): The stride of the target tensor.
+     weight_ptr: Pointer to weight tensor.
      loss_ptr: Pointer to tensor to store the loss.
      z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
      loss_stride (int): The stride of the loss tensor.
+     token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+     token_accuracy_stride (int): The stride of the token accuracy tensor.
      n_cols (int): The number of columns in the input tensor.
-     n_non_ignore (int): The number of non-ignored elements in the batch.
+     n_non_ignore (float): The number of non-ignored elements in the batch.
+     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
+     weight_sum (float): The sum of weight tensor.
      ignore_index (int): The index to ignore in the target.
      label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
      lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
-     RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
      reduction (str): The string for the reduction to apply
+     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+     RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+     RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
      BLOCK_SIZE (int): The block size for Triton operations.
+     HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
+     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+     HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
      """

      # https://github.com/triton-lang/triton/issues/1058
@@ -64,10 +98,20 @@ def liger_cross_entropy_kernel(
          for i in range(0, n_cols, BLOCK_SIZE):
              X_offsets = i + tl.arange(0, BLOCK_SIZE)
              tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+         # For ignored tokens, set token accuracy to 0
+         if RETURN_TOKEN_ACCURACY:
+             token_accuracy_ptr += program_id * token_accuracy_stride
+             tl.store(token_accuracy_ptr, 0.0)
          return

      loss_ptr += program_id * loss_stride
-     z_loss_ptr += program_id * loss_stride
+     if RETURN_Z_LOSS:
+         z_loss_ptr += program_id * loss_stride
+     if RETURN_TOKEN_ACCURACY:
+         token_accuracy_ptr += program_id * token_accuracy_stride
+
+     if HAS_WEIGHT:
+         weight_y = tl.load(weight_ptr + y).cast(tl.float32)

      # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
      # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -75,9 +119,10 @@ def liger_cross_entropy_kernel(
      # 3. [Online softmax] first pass: find max + sum
      m = float("-inf") # m is the max value. use the notation from the paper
      d = 0.0 # d is the sum. use the notation from the paper
-     ori_X_y = tl.load(
-         X_ptr + y
-     ) # we need to store the original value of X_y for the loss calculation
+     argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
+     ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
+     if HAS_SOFTCAPPING:
+         ori_X_y = softcap * tanh(ori_X_y / softcap)

      # Label smoothing is a general case of normal cross entropy
      # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
@@ -87,12 +132,31 @@ def liger_cross_entropy_kernel(
      for i in range(0, n_cols, BLOCK_SIZE):
          X_offsets = i + tl.arange(0, BLOCK_SIZE)
          X_block = tl.load(
-             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-         )
+             X_ptr + X_offsets,
+             mask=X_offsets < n_cols,
+             other=float("-inf"),
+             # Ensure float32 precision for softmax calculation
+         ).cast(tl.float32)
+         if HAS_SOFTCAPPING:
+             X_block = softcap * tanh(X_block / softcap)
          block_max = tl.max(X_block)
+
+         # Track argmax for accuracy computation
+         if RETURN_TOKEN_ACCURACY and block_max > m:
+             # Find the index of the maximum value in this block
+             is_max_mask = X_block == block_max
+             # Mask out invalid indices with a value larger than n_cols
+             masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+             # Get the first (smallest) index where max occurs
+             argmax_idx = tl.min(masked_offsets)
+
          if label_smoothing > 0:
              # scale X beforehand to avoid overflow
-             scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
+             if HAS_WEIGHT:
+                 weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                 scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
+             else:
+                 scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
          m_new = tl.maximum(m, block_max)
          d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
          m = m_new
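Note on the hunk above: the first pass is the standard online-softmax recurrence (Algorithm 3 of https://arxiv.org/pdf/1805.02867) — keep a running max m and a running denominator d, and rescale d whenever the max grows, so the logsumexp is obtained in one streaming sweep over the row. A minimal PyTorch sketch of that recurrence for a single row (illustration only, not part of the package; the helper name online_logsumexp is ad hoc):

    import math
    import torch

    def online_logsumexp(x: torch.Tensor, block: int = 4) -> float:
        """Blockwise running (max, sum) recurrence; mirrors the kernel's first pass."""
        m = float("-inf")  # running max (the kernel's `m`)
        d = 0.0            # running sum of exp(x - m) (the kernel's `d`)
        for i in range(0, x.numel(), block):
            x_block = x[i : i + block]
            block_max = x_block.max().item()
            m_new = max(m, block_max)
            # rescale the old sum to the new max, then add this block's contribution
            d = d * math.exp(m - m_new) + torch.exp(x_block - m_new).sum().item()
            m = m_new
        return m + math.log(d)  # lse, without ever materializing softmax(x)

    x = torch.randn(1000)
    assert abs(online_logsumexp(x) - torch.logsumexp(x, dim=0).item()) < 1e-4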
@@ -116,23 +180,58 @@ def liger_cross_entropy_kernel(
      # For 'sum' reduction, no normalization is applied:
      # dx_y = softmax(x_y) - 1
      # dx_i = softmax(x_i), for i ≠ y
-
-     for i in range(0, n_cols, BLOCK_SIZE):
-         X_offsets = i + tl.arange(0, BLOCK_SIZE)
-         X_block = tl.load(
-             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-         )
-         # softmax(x_i)
-         X_block = tl.exp(X_block - m) / d
-         # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-         X_block += 2 * lse_square_scale * lse * X_block
-         # smoothing term
-         X_block += -eps
-         # reduction scale
-         if reduction == "mean":
-             X_block = X_block / (n_non_ignore)
-
-         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+     if HAS_GRADIENTS:
+         for i in range(0, n_cols, BLOCK_SIZE):
+             X_offsets = i + tl.arange(0, BLOCK_SIZE)
+             X_block = tl.load(
+                 X_ptr + X_offsets,
+                 mask=X_offsets < n_cols,
+                 other=float("-inf"),
+                 # Ensure float32 precision for softmax calculation
+             ).cast(tl.float32)
+             if HAS_SOFTCAPPING:
+                 intermediate = tanh(X_block / softcap)
+                 X_block = softcap * intermediate
+
+             if not HAS_WEIGHT:
+                 # softmax(x_i)
+                 X_block = tl.exp(X_block - m) / d
+                 # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                 X_block += 2 * lse_square_scale * lse * X_block
+                 # smoothing term
+                 X_block += -eps
+                 # special handle dx_y
+                 X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                 # reduction scale
+                 if reduction == "mean":
+                     X_block = X_block / n_non_ignore
+             else:
+                 weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                 softmax_X = tl.exp(X_block - m) / d
+                 # derivative of original_loss
+                 dloss_ori = (1 - label_smoothing) * softmax_X
+                 # specially handle dx_y
+                 dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                 dloss_ori = dloss_ori * weight_y
+                 # derivative of smooth_loss
+                 dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                 # derivative of z-loss
+                 dz_loss = 2 * lse_square_scale * lse * softmax_X
+                 # reduction scale
+                 if reduction == "mean":
+                     dloss_ori = dloss_ori / sum_non_ignore_weight
+                     dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                     # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                     dz_loss = dz_loss / n_non_ignore
+                 # derivative of total_loss
+                 X_block = dloss_ori + dloss_smooth + dz_loss
+
+             # chain rule softcapping
+             # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+             if HAS_SOFTCAPPING:
+                 X_block = X_block * (1 - intermediate * intermediate)
+
+             tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

      # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
      # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
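The softcapping branch in the hunk above follows from the chain rule: if z = softcap * tanh(x / softcap), then dz/dx = 1 - tanh^2(x / softcap), which is the (1 - intermediate * intermediate) factor applied to the gradient. A quick autograd sanity check of that factor (illustrative sketch only; the values are arbitrary):

    import torch

    softcap = 30.0
    x = torch.randn(8, dtype=torch.float64, requires_grad=True)
    z = softcap * torch.tanh(x / softcap)
    z.sum().backward()
    # autograd's dz/dx should match the closed-form factor the kernel applies
    expected = 1 - torch.tanh(x.detach() / softcap) ** 2
    assert torch.allclose(x.grad, expected)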
@@ -146,71 +245,68 @@ def liger_cross_entropy_kernel(
      # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
      # So we can safely calculate log (softmax(X_y)) without overflow
      loss = lse - ori_X_y
+     if HAS_WEIGHT:
+         loss = weight_y * loss

      # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
      # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
      # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
      # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
-     # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
+     # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
      # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
      # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
      # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
      if label_smoothing > 0:
-         smooth_loss = scaled_x_sum + label_smoothing * lse
+         if HAS_WEIGHT:
+             smooth_loss = scaled_x_sum + eps * lse * weight_sum
+         else:
+             smooth_loss = scaled_x_sum + label_smoothing * lse
          loss = loss * (1 - label_smoothing) + smooth_loss

      # An auxiliary loss, z_loss
      # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
      z_loss = lse_square_scale * lse * lse
-     loss += z_loss
      # Normalize the loss by the number of non-ignored elements if reduction is "mean"
      if reduction == "mean":
+         if HAS_WEIGHT:
+             loss = loss / sum_non_ignore_weight
+         else:
+             loss = loss / n_non_ignore
+         # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
          z_loss = z_loss / n_non_ignore
-         loss = loss / n_non_ignore
-
-     # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-     X_y = tl.load(X_ptr + y)
-     if reduction == "mean":
-         X_y += -(1 - label_smoothing) / (n_non_ignore)
-     else:
-         X_y += -(1 - label_smoothing)
+     loss += z_loss

      tl.store(loss_ptr, loss)
-     if RETURN_Z_LOSS == _TRUE:
+     if RETURN_Z_LOSS:
          tl.store(z_loss_ptr, z_loss)
-     tl.store(X_ptr + y, X_y)
+     if RETURN_TOKEN_ACCURACY:
+         # Store 1.0 if prediction is correct, 0.0 otherwise
+         is_correct = 1.0 if argmax_idx == y else 0.0
+         tl.store(token_accuracy_ptr, is_correct)


  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
- MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
-
-
- _bool_to_return_z_loss = {
-     True: _TRUE.value,
-     False: _FALSE.value,
- }
+ MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning


  def cross_entropy_forward(
      _input,
      target,
+     weight,
      ignore_index,
      lse_square_scale,
      label_smoothing,
      reduction,
+     softcap,
      return_z_loss,
+     return_token_accuracy=False,
  ):
-     if not isinstance(return_z_loss, int):
-         assert (
-             return_z_loss in _bool_to_return_z_loss
-         ), f"return_z_loss must be True or False. Got: {return_z_loss}"
-         return_z_loss = _bool_to_return_z_loss[return_z_loss]
-     else:
-         assert (
-             return_z_loss in _bool_to_return_z_loss
-         ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+     assert isinstance(return_token_accuracy, bool), (
+         f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+     )

      BT, V = _input.shape
      n_rows = BT
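For reference, the label-smoothing decomposition quoted in the comments of the hunk above can be written out as follows, with alpha = label_smoothing, eps = alpha / V, and lse = m + log d:

    H(q', p) = (1 - \alpha)\,H(q, p) + \alpha\,H(u, p),
    \qquad
    \alpha\,H(u, p) = \frac{\alpha}{V}\sum_i\big(\mathrm{lse} - x_i\big) = \sum_i(-\epsilon\,x_i) + \alpha\,(m + \log d)

which is the scaled_x_sum + label_smoothing * lse term in the unweighted branch; the weighted branch uses eps * lse * weight_sum instead, and this reduces to the same expression when every class weight is 1 (so weight_sum = V).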
@@ -219,12 +315,29 @@ def cross_entropy_forward(

      # unreduced loss
      loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
-     if return_z_loss == _TRUE.value:
-         z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
-     else:
-         z_loss_1d = loss_1d # dummy ptr when return_z_loss == False
+     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+     token_accuracy_1d = (
+         torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+     )

-     n_non_ignore = (target != ignore_index).sum().item()
+     target_mask = target != ignore_index
+     n_non_ignore = target_mask.sum().item()
+     assert (target * target_mask).max() < _input.shape[-1], (
+         f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
+     )
+     assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
+     sum_non_ignore_weight = n_non_ignore
+     weight_sum = 0.0
+     if weight is not None:
+         assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+         assert torch.is_floating_point(weight), (
+             f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+         )
+         sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+         weight_sum = weight.sum().item()
+         # ensure weight is contiguous
+         if weight.stride(-1) != 1:
+             weight = weight.contiguous()

      # ensure _input and target are contiguous in the last dimension
      if _input.stride(-1) != 1:
@@ -238,36 +351,55 @@ def cross_entropy_forward(
          X_stride=_input.stride(-2),
          Y_ptr=target,
          Y_stride=target.stride(-1), # always 1
+         weight_ptr=weight, # dummy if None
          loss_ptr=loss_1d,
          z_loss_ptr=z_loss_1d,
          loss_stride=loss_1d.stride(-1), # always 1
+         token_accuracy_ptr=token_accuracy_1d,
+         token_accuracy_stride=token_accuracy_1d.stride(-1)
+         if return_token_accuracy
+         else 0, # always 1 if accuracy is enabled
          n_cols=V,
          n_non_ignore=n_non_ignore,
+         sum_non_ignore_weight=sum_non_ignore_weight,
          ignore_index=ignore_index,
+         weight_sum=weight_sum,
          lse_square_scale=lse_square_scale,
          label_smoothing=label_smoothing,
          reduction=reduction,
+         softcap=softcap,
          RETURN_Z_LOSS=return_z_loss,
+         RETURN_TOKEN_ACCURACY=return_token_accuracy,
          BLOCK_SIZE=BLOCK_SIZE,
+         HAS_WEIGHT=True if weight is not None else False,
+         HAS_SOFTCAPPING=True if softcap is not None else False,
+         HAS_GRADIENTS=_input.requires_grad,
          # TODO: 32 seems to give the best performance
          # Performance is quite sensitive to num_warps
          num_warps=32 if not is_hip() else 16,
      )

-     loss = torch.sum(loss_1d)
-     if return_z_loss == _TRUE.value:
-         z_loss = torch.sum(z_loss_1d)
+     if reduction == "none":
+         loss = loss_1d
+         z_loss = z_loss_1d if return_z_loss else None
+         token_accuracy = token_accuracy_1d if return_token_accuracy else None
      else:
-         z_loss = None
+         loss = torch.sum(loss_1d)
+         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+         # For accuracy, we compute the mean across all non-ignored tokens
+         token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

-     return loss, z_loss, _input
+     return loss, z_loss, token_accuracy, _input


  def cross_entropy_backward(_input, grad_output):
      # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
      if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
          pass
-
+     # If reduction is 'none'
+     elif grad_output.ndim > 0:
+         _input = _input * grad_output.unsqueeze(dim=1)
+     # If reduction is ['mean', 'sum'], grad_output is just a scalar
      # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
      # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
      else:
@@ -296,13 +428,16 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
      @staticmethod
      def forward(
          ctx,
-         _input,
-         target,
-         ignore_index=-100,
-         lse_square_scale=0.0,
-         label_smoothing=0.0,
-         reduction="mean",
-         return_z_loss=False,
+         _input: torch.Tensor,
+         target: torch.Tensor,
+         weight: Optional[torch.FloatTensor],
+         ignore_index: int = -100,
+         lse_square_scale: float = 0.0,
+         label_smoothing: float = 0.0,
+         reduction: str = "mean",
+         softcap: Optional[float] = None,
+         return_z_loss: bool = False,
+         return_token_accuracy: bool = False,
      ):
          """
          The forward pass of the Liger Cross Entropy loss.
@@ -311,46 +446,59 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
          ctx : The context object.
          _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
          target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+         weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
          ignore_index (int): The index to ignore in the target.
          lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
          label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
          reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
-         return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+         softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+         return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+         return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

          Returns:
-         tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
+         tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
          """
-         loss, z_loss, _input = cross_entropy_forward(
+         input_requires_grad = _input.requires_grad
+
+         loss, z_loss, token_accuracy, _input = cross_entropy_forward(
              _input,
              target,
+             weight,
              ignore_index,
              lse_square_scale,
              label_smoothing,
              reduction,
+             softcap,
              return_z_loss,
+             return_token_accuracy,
          )
          # TODO: investigation
          # If we don't detach the _input tensor, the memory will double
          # Not sure why but seems that there will be a time both grad and value exist but in different location
-         ctx.save_for_backward(_input.detach())
+         if input_requires_grad:
+             ctx.save_for_backward(_input.detach())
          ctx.return_z_loss = return_z_loss
+         ctx.return_token_accuracy = return_token_accuracy

-         return loss, z_loss
+         return loss, z_loss, token_accuracy

      @staticmethod
-     def backward(ctx, grad_output, grad_ouput2):
+     def backward(ctx, grad_output, grad_output2, grad_output3):
          """
          The backward pass of the Liger Cross Entropy loss.

          Parameters:
          ctx : The context object with saved tensors.
          grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-         grad_output2 (tenosr): No use.
+         grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+         grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
          Returns:
          tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
          """
          if ctx.return_z_loss:
-             del grad_ouput2 # z_loss is only for logging
+             del grad_output2 # z_loss is only for logging
+         if ctx.return_token_accuracy:
+             del grad_output3 # token_accuracy is only for metrics

          (_input,) = ctx.saved_tensors
          _input = cross_entropy_backward(_input, grad_output)
@@ -362,4 +510,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
              None,
              None,
              None,
+             None,
+             None,
+             None,
          )
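For orientation, a call against the new forward signature might look like the sketch below. This is an illustration based only on the signature shown in this diff, with arguments passed positionally in the order _input, target, weight, ignore_index, lse_square_scale, label_smoothing, reduction, softcap, return_z_loss, return_token_accuracy; it assumes a CUDA device is available:

    import torch
    from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

    V = 128
    logits = torch.randn(16, V, device="cuda", requires_grad=True)
    target = torch.randint(0, V, (16,), device="cuda")

    # loss is differentiable; z_loss and token_accuracy are logging/metric-only extras
    loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
        logits, target, None, -100, 0.0, 0.0, "mean", None, False, True
    )
    loss.backward()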