liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly has been flagged as possibly problematic.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +46 -15
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +134 -65
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +9 -5
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +398 -99
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +208 -17
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +57 -27
- liger_kernel/transformers/model/gemma2.py +65 -28
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +109 -27
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +51 -34
- liger_kernel/transformers/model/mixtral.py +50 -29
- liger_kernel/transformers/model/mllama.py +46 -24
- liger_kernel/transformers/model/olmo2.py +47 -22
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -172
- liger_kernel/transformers/model/qwen2.py +55 -23
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2018 -244
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +54 -6
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +63 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py

@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh

@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,

@@ -42,9 +45,11 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.

@@ -59,6 +64,8 @@ def liger_cross_entropy_kernel(
         loss_ptr: Pointer to tensor to store the loss.
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
+        token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+        token_accuracy_stride (int): The stride of the token accuracy tensor.
         n_cols (int): The number of columns in the input tensor.
         n_non_ignore (float): The number of non-ignored elements in the batch.
         sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.

@@ -68,10 +75,12 @@ def liger_cross_entropy_kernel(
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         reduction (str): The string for the reduction to apply
         softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        RETURN_Z_LOSS (int): The boolean value to decide whether
+        RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+        RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
         BLOCK_SIZE (int): The block size for Triton operations.
         HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+        HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """
 
     # https://github.com/triton-lang/triton/issues/1058

@@ -90,11 +99,17 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return
 
     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride
 
     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)

@@ -105,6 +120,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
     ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)

@@ -125,6 +141,19 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            current_block_argmax_idx = tl.min(masked_offsets)
+
+            is_new_max = block_max > m
+            argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
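The RETURN_TOKEN_ACCURACY block above folds an argmax into the kernel's online-softmax max/sum pass: each BLOCK_SIZE chunk contributes the smallest offset holding its chunk maximum, and that offset replaces the running argmax only when the chunk maximum beats the running maximum. A minimal eager-mode sketch of the same bookkeeping (the helper name and block size are illustrative, not part of the package):

```python
import torch

def blockwise_argmax(x: torch.Tensor, block_size: int) -> int:
    """Running argmax over fixed-size chunks, keeping the first index at which
    the global maximum occurs (the tl.min / tl.where pattern in the diff)."""
    m = float("-inf")  # running maximum, as in the online-softmax pass
    argmax_idx = 0     # index of the running maximum
    for i in range(0, x.numel(), block_size):
        block = x[i : i + block_size]
        block_max = block.max().item()
        # smallest offset inside this chunk where the chunk maximum occurs
        block_argmax = i + int((block == block_max).nonzero()[0])
        if block_max > m:  # ties keep the earlier index, as in the kernel
            m = block_max
            argmax_idx = block_argmax
    return argmax_idx

logits = torch.randn(1000)
assert blockwise_argmax(logits, block_size=128) == int(torch.argmax(logits))
```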
@@ -155,58 +184,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
- [52 removed lines (old 158-209) not rendered in the source diff]
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34

@@ -254,12 +283,22 @@ def liger_cross_entropy_kernel(
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
- [removed line (old 262) not rendered in the source diff]
+# the best size we found by manually tuning on xpu and npu.
+if infer_device() == "xpu":
+    MAX_FUSED_SIZE = 4096
+elif infer_device() == "npu":
+    MAX_FUSED_SIZE = 2048
+else:
+    MAX_FUSED_SIZE = 65536 // 2
 
 
 def cross_entropy_forward(
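The HAS_GRADIENTS loop above overwrites the logits with their gradient during the forward pass, and is skipped when gradients are not needed. For the plain case described in the surrounding comments (reduction "sum", no label smoothing, class weights, or z-loss) it reduces to dx_i = softmax(x_i) with dx_y = softmax(x_y) - 1; a quick PyTorch check of that identity (assumed toy shapes, not package code):

```python
import torch
import torch.nn.functional as F

# dX of sum-reduced cross entropy (no smoothing, weights, or z-loss):
#   dx_i = softmax(x_i)      for i != y
#   dx_y = softmax(x_y) - 1
torch.manual_seed(0)
X = torch.randn(4, 10, dtype=torch.float64, requires_grad=True)
y = torch.randint(0, 10, (4,))

F.cross_entropy(X, y, reduction="sum").backward()

expected = torch.softmax(X.detach(), dim=-1)
expected[torch.arange(4), y] -= 1.0
assert torch.allclose(X.grad, expected)
```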
@@ -272,8 +311,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
 
     BT, V = _input.shape
     n_rows = BT

@@ -283,6 +326,9 @@ def cross_entropy_forward(
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )
 
     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()

@@ -319,6 +365,10 @@ def cross_entropy_forward(
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,

@@ -329,9 +379,11 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,

@@ -340,18 +392,24 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
 
-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input
 
 
 def cross_entropy_backward(_input, grad_output):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         pass
- [removed line (old 354) not rendered in the source diff]
+    # If reduction is 'none'
+    elif grad_output.ndim > 0:
+        _input = _input * grad_output.unsqueeze(dim=1)
+    # If reduction is ['mean', 'sum'], grad_output is just a scalar
     # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
     # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
     else:

@@ -389,6 +447,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.

@@ -403,12 +462,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
         softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
 
         Returns:
-        tuple: A tuple with the
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
- [removed line (old 411) not rendered in the source diff]
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
             weight,

@@ -418,29 +480,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             reduction,
             softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
- [removed line (old 425) not rendered in the source diff]
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy
 
-        return loss, z_loss
+        return loss, z_loss, token_accuracy
 
     @staticmethod
-    def backward(ctx, grad_output,
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.
 
         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-        grad_output2 (
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
-            del
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
 
         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)

@@ -454,4 +522,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
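Taken together, the cross_entropy.py changes add a return_token_accuracy option: the kernel writes 1.0 per row when its argmax equals the target (0.0 for incorrect or ignored rows), and for reduced losses cross_entropy_forward averages that buffer over the non-ignored tokens. A plain-PyTorch reference with the same semantics (token_accuracy_reference is an illustrative helper, not the package API):

```python
import torch

def token_accuracy_reference(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Fraction of non-ignored tokens whose argmax prediction equals the target,
    mirroring the reduced token_accuracy returned by cross_entropy_forward."""
    mask = target != ignore_index                       # n_non_ignore = mask.sum()
    correct = (logits.argmax(dim=-1) == target) & mask  # 1.0/0.0 per row, 0.0 for ignored rows
    return correct.sum().float() / mask.sum().float()

# logits flattened to (BT, V), targets of shape (BT,) with ignore_index marking padding
logits = torch.randn(8, 128)
target = torch.randint(0, 128, (8,))
target[:2] = -100  # ignored rows contribute 0 and are excluded from the denominator
print(token_accuracy_reference(logits, target))
```

The updated docstring notes that the fused path computes this per-token accuracy without materializing the logits for a separate argmax pass.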