liger-kernel 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl
- liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
- liger_kernel/chunked_loss/grpo_loss.py +46 -9
- liger_kernel/chunked_loss/jsd_loss.py +23 -7
- liger_kernel/ops/cross_entropy.py +118 -62
- liger_kernel/ops/fused_linear_cross_entropy.py +97 -13
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +86 -69
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/transformers/__init__.py +36 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/functional.py +31 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +13 -4
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +19 -7
- liger_kernel/transformers/model/gemma2.py +22 -7
- liger_kernel/transformers/model/gemma3.py +52 -14
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +19 -6
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +16 -6
- liger_kernel/transformers/model/llama4.py +18 -5
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +32 -3
- liger_kernel/transformers/model/mistral.py +17 -7
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +14 -5
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +41 -5
- liger_kernel/transformers/model/phi3.py +16 -8
- liger_kernel/transformers/model/qwen2.py +18 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +17 -7
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +830 -3
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/METADATA +16 -10
- liger_kernel-0.6.4.dist-info/RECORD +118 -0
- liger_kernel-0.6.2.dist-info/RECORD +0 -104
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/WHEEL +0 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.4.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py
CHANGED
@@ -32,6 +32,8 @@ def liger_cross_entropy_kernel(
    loss_ptr,
    z_loss_ptr,
    loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
    n_cols,
    n_non_ignore,
    sum_non_ignore_weight,
@@ -42,9 +44,11 @@ def liger_cross_entropy_kernel(
    reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
    softcap,
    RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
):
    """
    This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +63,8 @@ def liger_cross_entropy_kernel(
    loss_ptr: Pointer to tensor to store the loss.
    z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
    loss_stride (int): The stride of the loss tensor.
+    token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+    token_accuracy_stride (int): The stride of the token accuracy tensor.
    n_cols (int): The number of columns in the input tensor.
    n_non_ignore (float): The number of non-ignored elements in the batch.
    sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +74,12 @@ def liger_cross_entropy_kernel(
    lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
    reduction (str): The string for the reduction to apply
    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-    RETURN_Z_LOSS (int): The boolean value to decide whether
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
    BLOCK_SIZE (int): The block size for Triton operations.
    HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
    HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+    HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
    """

    # https://github.com/triton-lang/triton/issues/1058
@@ -90,11 +98,17 @@ def liger_cross_entropy_kernel(
        for i in range(0, n_cols, BLOCK_SIZE):
            X_offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
        return

    loss_ptr += program_id * loss_stride
    if RETURN_Z_LOSS:
        z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride

    if HAS_WEIGHT:
        weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +119,7 @@ def liger_cross_entropy_kernel(
    # 3. [Online softmax] first pass: find max + sum
    m = float("-inf")  # m is the max value. use the notation from the paper
    d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
    ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
    if HAS_SOFTCAPPING:
        ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +140,16 @@ def liger_cross_entropy_kernel(
        if HAS_SOFTCAPPING:
            X_block = softcap * tanh(X_block / softcap)
        block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY and block_max > m:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            argmax_idx = tl.min(masked_offsets)
+
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
            if HAS_WEIGHT:
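
The blockwise argmax tracking above piggybacks on the online-softmax max scan, so the predicted class index falls out of the loss pass at no extra memory cost. A minimal eager-mode sketch of the same bookkeeping (a hypothetical helper written for illustration, not part of the package):

import torch

# Sketch of the blockwise argmax bookkeeping in the kernel: scan the row in blocks,
# and whenever a block's max strictly exceeds the running max, record the smallest
# offset attaining it (mirrors `argmax_idx`, `block_max`, and `masked_offsets`).
def blockwise_argmax(row: torch.Tensor, block_size: int = 128) -> int:
    m = float("-inf")
    argmax_idx = 0
    for start in range(0, row.numel(), block_size):
        x_block = row[start:start + block_size]
        block_max = x_block.max().item()
        if block_max > m:
            m = block_max
            # smallest index where the block max occurs, offset into the full row
            argmax_idx = start + int((x_block == block_max).nonzero()[0].item())
    return argmax_idx
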
@@ -155,58 +180,58 @@ def liger_cross_entropy_kernel(
    # For 'sum' reduction, no normalization is applied:
    # dx_y = softmax(x_y) - 1
    # dx_i = softmax(x_i), for i ≠ y
-   (old lines 158-209: 52 removed lines whose contents are collapsed in the source diff view)
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

    # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -254,6 +279,10 @@ def liger_cross_entropy_kernel(
    tl.store(loss_ptr, loss)
    if RETURN_Z_LOSS:
        tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
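
Taken together, the kernel now writes 1.0 for a token whose argmax matches its label and 0.0 otherwise, with ignored tokens forced to 0.0. A plain-PyTorch reference of the quantity being tracked (a hypothetical helper for comparison, not the library API):

import torch

# Reference (eager) computation of the per-token accuracy the kernel stores:
# 1.0 where argmax(logits) == target, 0.0 otherwise; ignored tokens contribute 0
# and are excluded from the denominator when the caller divides by n_non_ignore.
def token_accuracy_reference(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100):
    mask = target != ignore_index
    per_token = ((logits.argmax(dim=-1) == target) & mask).float()
    mean_accuracy = per_token.sum() / mask.sum().clamp(min=1)
    return per_token, mean_accuracy
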
@@ -272,8 +301,12 @@ def cross_entropy_forward(
    reduction,
    softcap,
    return_z_loss,
+    return_token_accuracy=False,
):
    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )

    BT, V = _input.shape
    n_rows = BT
@@ -283,6 +316,9 @@ def cross_entropy_forward(
    # unreduced loss
    loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
    z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )

    target_mask = target != ignore_index
    n_non_ignore = target_mask.sum().item()
@@ -319,6 +355,10 @@ def cross_entropy_forward(
        loss_ptr=loss_1d,
        z_loss_ptr=z_loss_1d,
        loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
        n_cols=V,
        n_non_ignore=n_non_ignore,
        sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +369,11 @@ def cross_entropy_forward(
        reduction=reduction,
        softcap=softcap,
        RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_WEIGHT=True if weight is not None else False,
        HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
        # TODO: 32 seems to give the best performance
        # Performance is quite sensitive to num_warps
        num_warps=32 if not is_hip() else 16,
@@ -340,11 +382,14 @@ def cross_entropy_forward(
    if reduction == "none":
        loss = loss_1d
        z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
    else:
        loss = torch.sum(loss_1d)
        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input


def cross_entropy_backward(_input, grad_output):
@@ -392,6 +437,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
        reduction: str = "mean",
        softcap: Optional[float] = None,
        return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
    ):
        """
        The forward pass of the Liger Cross Entropy loss.
@@ -406,12 +452,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

        Returns:
-        tuple: A tuple with the
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
        """
-
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
            _input,
            target,
            weight,
@@ -421,29 +470,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
            reduction,
            softcap,
            return_z_loss,
+            return_token_accuracy,
        )
        # TODO: investigation
        # If we don't detach the _input tensor, the memory will double
        # Not sure why but seems that there will be a time both grad and value exist but in different location
-
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
        ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy

-        return loss, z_loss
+        return loss, z_loss, token_accuracy

    @staticmethod
-    def backward(ctx, grad_output,
+    def backward(ctx, grad_output, grad_output2, grad_output3):
        """
        The backward pass of the Liger Cross Entropy loss.

        Parameters:
        ctx : The context object with saved tensors.
        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-        grad_output2 (
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
        Returns:
        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
        if ctx.return_z_loss:
-            del
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics

        (_input,) = ctx.saved_tensors
        _input = cross_entropy_backward(_input, grad_output)
@@ -457,4 +512,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
            None,
            None,
            None,
+            None,
        )
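
With the extra output, `LigerCrossEntropyFunction` now returns a `(loss, z_loss, token_accuracy)` triple and its `backward` accepts one gradient per output. A hedged usage sketch follows; the positional arguments between `weight` and `reduction` (ignore_index, lse_square_scale, label_smoothing) are assumed from the surrounding signature rather than shown in this diff:

import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

# Sketch only: the argument order ahead of `reduction` is an assumption.
logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # requires_grad=True -> HAS_GRADIENTS=True
target = torch.randint(0, 32000, (8,), device="cuda")

loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
    logits,  # _input
    target,  # target
    None,    # weight
    -100,    # ignore_index (assumed position)
    0.0,     # lse_square_scale (assumed position)
    0.0,     # label_smoothing (assumed position)
    "mean",  # reduction
    None,    # softcap
    False,   # return_z_loss
    True,    # return_token_accuracy
)
loss.backward()  # gradients for z_loss / token_accuracy (grad_output2/3) are discarded
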
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED
@@ -26,10 +26,17 @@ def fused_linear_cross_entropy_forward(
    softcap=None,
    return_z_loss=False,
    accum_dtype=None,
+    use_token_scaling=False,
+    return_token_accuracy=False,
):
    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
    device = _input.device

+    input_requires_grad = _input.requires_grad
+
    # inputs have shape: BT x H
    # materialized activations will have shape: BT x V
    # the increase in memory = BT x V
@@ -48,15 +55,20 @@ def fused_linear_cross_entropy_forward(
    grad_input = torch.zeros_like(_input, device=device)

    # we use fp32 for loss and gradients accumulator
-    if
-
-
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
    else:
-        grad_weight =
-        grad_bias =
+        grad_weight = None
+        grad_bias = None

    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
    z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None

    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
    target_mask = target != ignore_index
@@ -89,9 +101,40 @@ def fused_linear_cross_entropy_forward(

        n_rows = logits_chunk.shape[0]

+        # Compute predicted probabilities for token scaling if needed
+        if use_token_scaling:
+            # Compute softmax probabilities for scaling
+            # We need to compute this before the cross entropy kernel modifies logits_chunk
+            logits_for_softmax = logits_chunk.detach().clone()  # Detach to avoid gradient flow
+            if softcap is not None:
+                logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
+
+            # Compute softmax to get predicted probabilities
+            probs = torch.softmax(logits_for_softmax, dim=-1)
+
+            # Get predicted probabilities for token scaling, handling ignored targets
+            valid_target_mask = target_chunk != ignore_index
+            valid_targets = target_chunk[valid_target_mask]
+
+            if len(valid_targets) > 0:
+                # Gather probabilities only for valid targets
+                valid_probs = probs[valid_target_mask]
+                pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
+
+                # Create full tensor with zeros for ignored targets
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+                pred_probs[valid_target_mask] = pred_probs_valid
+            else:
+                # All targets are ignored
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+
+            # Store the scaling factors
+            scaling_factors = pred_probs.detach()  # Detach to ensure no gradient flow
+
        # unreduced loss
        loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
        z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+        token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None

        # ensure _input and target are contiguous
        logits_chunk = logits_chunk.contiguous()
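
The `use_token_scaling` branch above derives a per-token weight from the model's own, detached, predicted probability of the true class. A compact eager-mode reference of the same scaling (a hypothetical helper for illustration, not the library API):

import torch
import torch.nn.functional as F

# Eager reference for `use_token_scaling`: each token's CE loss is multiplied by the
# detached predicted probability of its true class; ignored tokens get weight 0.
def token_scaled_ce(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100):
    per_token_loss = F.cross_entropy(logits, target, ignore_index=ignore_index, reduction="none")
    probs = torch.softmax(logits.detach(), dim=-1)
    mask = target != ignore_index
    safe_target = torch.where(mask, target, torch.zeros_like(target))
    p_true = probs.gather(-1, safe_target.unsqueeze(-1)).squeeze(-1)
    scaling = torch.where(mask, p_true, torch.zeros_like(p_true))  # detached, so no gradient flows through it
    return per_token_loss * scaling
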
@@ -107,6 +150,10 @@ def fused_linear_cross_entropy_forward(
            loss_ptr=loss_1d_slice,
            z_loss_ptr=z_loss_1d_slice,
            loss_stride=loss_1d_slice.stride(-1),  # always 1
+            token_accuracy_ptr=token_accuracy_1d_slice,
+            token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+            if return_token_accuracy
+            else 0,  # always 1 if accuracy is enabled
            n_cols=V,
            n_non_ignore=total_n_non_ignore,
            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -117,26 +164,43 @@ def fused_linear_cross_entropy_forward(
            reduction=reduction,
            softcap=softcap,
            RETURN_Z_LOSS=return_z_loss,
+            RETURN_TOKEN_ACCURACY=return_token_accuracy,
            HAS_WEIGHT=True if ce_weight is not None else False,
            HAS_SOFTCAPPING=True if softcap is not None else False,
+            HAS_GRADIENTS=input_requires_grad,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )

+        # Apply token scaling if requested
+        if use_token_scaling:
+            loss_1d_slice = loss_1d_slice * scaling_factors
+            if return_z_loss:
+                z_loss_1d_slice = z_loss_1d_slice * scaling_factors
+
        loss_1d[start_idx:end_idx] = loss_1d_slice
        if return_z_loss:
            z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        if return_token_accuracy:
+            token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
        grad_logits_chunk = logits_chunk  # chunk_size x V

-
+        # Apply token scaling to gradients if requested
+        if use_token_scaling:
+            # Expand scaling factors to match gradient dimensions
+            scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
+            grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

-        if
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+
+        if grad_weight is not None and input_requires_grad:
            grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

-        if bias is not None:
+        if bias is not None and input_requires_grad:
            torch.add(
                input=grad_bias,
-                other=
+                other=grad_logits_chunk.sum(dim=0),
                out=grad_bias,
                alpha=1.0,
            )
@@ -146,15 +210,22 @@ def fused_linear_cross_entropy_forward(
    # loss = loss_1d
    # z_loss = z_loss_1d if return_z_loss else None

+    if reduction == "none":
+        # Return per-token losses
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
    else:
        loss = torch.sum(loss_1d)
        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None

    # Cast back to original dtype
    grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
    grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None

-    return loss, z_loss, grad_input, grad_weight, grad_bias
+    return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias


def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -221,6 +292,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
        softcap=None,
        return_z_loss: bool = False,
        accum_dtype=None,
+        use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
    ):
        """
        Fusing the last linear layer with cross-entropy loss
@@ -241,9 +314,13 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
        reduction: reduction to apply
        accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
            Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
+        use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
+            When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
+            Default: False.
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
        """

-        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+        loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
            _input=_input,
            weight=weight,
            target=target,
@@ -256,6 +333,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
            softcap=softcap,
            return_z_loss=return_z_loss,
            accum_dtype=accum_dtype,
+            use_token_scaling=use_token_scaling,
+            return_token_accuracy=return_token_accuracy,
        )
        # downcast to dtype and store for backward
        ctx.save_for_backward(
@@ -264,13 +343,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
            grad_bias.detach() if bias is not None else None,
        )
        ctx.return_z_loss = return_z_loss
-
+        ctx.return_token_accuracy = return_token_accuracy
+        return loss, z_loss, token_accuracy

    @staticmethod
    @amp_custom_bwd
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
        if ctx.return_z_loss:
            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
            grad_output, grad_input, grad_weight, grad_bias
@@ -288,4 +370,6 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
            None,
            None,
            None,
+            None,  # use_token_scaling
+            None,  # return_token_accuracy
        )
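
Both autograd Functions above follow the same pattern: adding metric-only outputs (z_loss, token_accuracy) means `backward` must accept one gradient argument per output and simply drop the ones that carry no gradient. A generic, self-contained illustration of that pattern (not liger code; names are invented for the example):

import torch

# A custom Function that returns a metric next to the loss: backward receives one
# grad per forward output and ignores the metric's gradient (it is logging-only).
class MSEWithMatchRate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        loss = (x - y).pow(2).mean()
        match_rate = (x.sign() == y.sign()).float().mean()  # metric only
        return loss, match_rate.detach()

    @staticmethod
    def backward(ctx, grad_loss, grad_match_rate):  # grad_match_rate is unused
        x, y = ctx.saved_tensors
        grad_x = grad_loss * 2.0 * (x - y) / x.numel()
        return grad_x, None  # no gradient for y

x = torch.randn(16, requires_grad=True)
y = torch.randn(16)
loss, match_rate = MSEWithMatchRate.apply(x, y)
loss.backward()  # works even though match_rate never receives a useful gradient
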
liger_kernel/ops/grpo_loss.py
CHANGED
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
-
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped

    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l