liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +25 -9
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +124 -64
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +3 -2
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +13 -6
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +283 -56
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +205 -19
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +50 -25
- liger_kernel/transformers/model/gemma2.py +55 -23
- liger_kernel/transformers/model/gemma3.py +117 -120
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +102 -25
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +36 -23
- liger_kernel/transformers/model/mixtral.py +45 -25
- liger_kernel/transformers/model/mllama.py +39 -22
- liger_kernel/transformers/model/olmo2.py +40 -20
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -177
- liger_kernel/transformers/model/qwen2.py +48 -21
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1678 -160
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +48 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +36 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- liger_kernel/transformers/gema3_rms.py +0 -8
- liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py

@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,
@@ -42,9 +45,11 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +64,8 @@ def liger_cross_entropy_kernel(
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
+    token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+    token_accuracy_stride (int): The stride of the token accuracy tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (float): The number of non-ignored elements in the batch.
     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +75,12 @@ def liger_cross_entropy_kernel(
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-    RETURN_Z_LOSS (int): The boolean value to decide whether
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
     HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+    HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """
 
     # https://github.com/triton-lang/triton/issues/1058
@@ -90,11 +99,17 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return
 
     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride
 
     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +120,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
    ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +141,16 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY and block_max > m:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            argmax_idx = tl.min(masked_offsets)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
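
The block above threads a running argmax through the same online pass that already maintains the running max `m`, so the predicted class index is available at the end of the first pass without a separate sweep over the vocabulary. A minimal PyTorch sketch of that bookkeeping, for reference only (the function and variable names here are illustrative, not part of the package):

import torch

def blockwise_argmax(logits: torch.Tensor, block_size: int = 128) -> int:
    # Reference for the running-argmax bookkeeping: scan fixed-size blocks and
    # keep the first (smallest) index of the best maximum seen so far.
    m = float("-inf")  # running max, as in the online-softmax pass
    argmax_idx = 0
    for i in range(0, logits.numel(), block_size):
        block = logits[i : i + block_size]
        block_max = block.max().item()
        if block_max > m:
            # first offset inside the block where the block max occurs
            argmax_idx = i + int((block == block_max).nonzero()[0])
            m = block_max
    return argmax_idx

x = torch.randn(1000)
assert blockwise_argmax(x) == torch.argmax(x).item()
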
@@ -155,58 +181,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(
-            X_ptr + X_offsets,
-            mask=X_offsets < n_cols,
-            other=float("-inf"),
-            # Ensure float32 precision for softmax calculation
-        ).cast(tl.float32)
-        if HAS_SOFTCAPPING:
-            intermediate = tanh(X_block / softcap)
-            X_block = softcap * intermediate
-
-        if not HAS_WEIGHT:
-            # softmax(x_i)
-            X_block = tl.exp(X_block - m) / d
-            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-            X_block += 2 * lse_square_scale * lse * X_block
-            # smoothing term
-            X_block += -eps
-            # special handle dx_y
-            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-            # reduction scale
-            if reduction == "mean":
-                X_block = X_block / n_non_ignore
-        else:
-            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
-            softmax_X = tl.exp(X_block - m) / d
-            # derivative of original_loss
-            dloss_ori = (1 - label_smoothing) * softmax_X
-            # specially handle dx_y
-            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
-            dloss_ori = dloss_ori * weight_y
-            # derivative of smooth_loss
-            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
-            # derivative of z-loss
-            dz_loss = 2 * lse_square_scale * lse * softmax_X
-            # reduction scale
-            if reduction == "mean":
-                dloss_ori = dloss_ori / sum_non_ignore_weight
-                dloss_smooth = dloss_smooth / sum_non_ignore_weight
-                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
-                dz_loss = dz_loss / n_non_ignore
-            # derivative of total_loss
-            X_block = dloss_ori + dloss_smooth + dz_loss
-
-        # chain rule softcapping
-        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
-        if HAS_SOFTCAPPING:
-            X_block = X_block * (1 - intermediate * intermediate)
-
-        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -254,6 +280,10 @@ def liger_cross_entropy_kernel(
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -272,8 +302,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
 
     BT, V = _input.shape
     n_rows = BT
@@ -283,6 +317,9 @@ def cross_entropy_forward(
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )
 
     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -319,6 +356,10 @@ def cross_entropy_forward(
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +370,11 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
@@ -340,18 +383,24 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
 
-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input
 
 
 def cross_entropy_backward(_input, grad_output):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         pass
-
+    # If reduction is 'none'
+    elif grad_output.ndim > 0:
+        _input = _input * grad_output.unsqueeze(dim=1)
+    # If reduction is ['mean', 'sum'], grad_output is just a scalar
     # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
     # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
     else:
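
With the 'mean'/'sum' reductions, the returned token accuracy is the fraction of non-ignored tokens whose argmax matches the target (ignored tokens store 0.0 and are excluded from the denominator, per the kernel changes above). A small eager-mode reference for that quantity, using illustrative names only (not code from the package):

import torch

def reference_token_accuracy(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Per-token correctness averaged over non-ignored tokens, matching the
    # fused path: sum of the 1.0/0.0 flags divided by n_non_ignore.
    mask = target != ignore_index
    correct = (logits.argmax(dim=-1) == target) & mask
    # the clamp only guards this sketch against an all-ignored batch
    return correct.sum().float() / mask.sum().clamp(min=1)

logits = torch.randn(8, 128)
target = torch.randint(0, 128, (8,))
target[0] = -100  # ignored token
print(reference_token_accuracy(logits, target))
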
@@ -389,6 +438,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -403,12 +453,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
         softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
 
         Returns:
-        tuple: A tuple with the
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
-        loss, z_loss, _input = cross_entropy_forward(
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
             weight,
@@ -418,29 +471,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             reduction,
             softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy
 
-        return loss, z_loss
+        return loss, z_loss, token_accuracy
 
     @staticmethod
-    def backward(ctx, grad_output,
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.
 
         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-        grad_output2 (
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
-            del
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
 
         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
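
Two smaller behavioural changes ride along in this hunk: `ctx.save_for_backward` is skipped entirely when the logits cannot require gradients, so evaluation-only calls do not keep the (BT, V) buffer alive, and `backward` now accepts a third incoming gradient for the new accuracy output. A stripped-down, generic sketch of the conditional-save pattern in a custom autograd Function (toy code, not taken from the package):

import torch

class SquareSum(torch.autograd.Function):
    # Toy Function using the same pattern: only stash tensors for backward
    # when the input can actually require a gradient.
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        if x.requires_grad:
            ctx.save_for_backward(x.detach())
        return (x * x).sum()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_output

SquareSum.apply(torch.randn(3, requires_grad=True)).backward()  # tensor saved, gradient flows
with torch.no_grad():
    SquareSum.apply(torch.randn(3))  # nothing saved; no graph is built
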
@@ -454,4 +513,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )