liger-kernel-nightly 0.6.2.dev20250919191028__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +18 -5
- liger_kernel/ops/cross_entropy.py +120 -63
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +5 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +43 -12
- liger_kernel/ops/geglu.py +2 -1
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +88 -70
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +7 -2
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +33 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/functional.py +29 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +19 -7
- liger_kernel/transformers/model/gemma2.py +22 -7
- liger_kernel/transformers/model/gemma3.py +52 -14
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +18 -5
- liger_kernel/transformers/model/glm4v_moe.py +25 -5
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +16 -6
- liger_kernel/transformers/model/llama4.py +18 -5
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +17 -7
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +14 -5
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +41 -5
- liger_kernel/transformers/model/phi3.py +16 -8
- liger_kernel/transformers/model/qwen2.py +18 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +17 -7
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +729 -4
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/utils.py +25 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +13 -6
- liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
- liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/RECORD +0 -105
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py
CHANGED

@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
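This NPU gating of the libdevice import recurs in several ops modules below (dyt, geglu, group_norm, fused_add_rms_norm). A minimal sketch of the full pattern, for orientation only: the hunk shows the gate and the try branch, so the except/else fallbacks here are assumptions about the surrounding code rather than the package's exact fallbacks.

import operator

from liger_kernel.ops.utils import compare_version
from liger_kernel.utils import is_npu_available

# Only attempt the libdevice import on Triton >= 3.0.0 and when not running on an
# Ascend NPU, where triton.language.extra.libdevice is not available.
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import tanh
    except ModuleNotFoundError:
        from triton.language.math import tanh  # assumption: fallback when the extra module is absent
else:
    from triton.language.math import tanh  # assumption: portable fallback for NPU / older Triton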
@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,
@@ -42,9 +45,11 @@
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +64,8 @@
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
+    token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+    token_accuracy_stride (int): The stride of the token accuracy tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (float): The number of non-ignored elements in the batch.
     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +75,12 @@
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-    RETURN_Z_LOSS (int): The boolean value to decide whether
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
     HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+    HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """

     # https://github.com/triton-lang/triton/issues/1058
@@ -90,11 +99,17 @@
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return

     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride

     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +120,7 @@
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
     ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +141,16 @@
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY and block_max > m:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            argmax_idx = tl.min(masked_offsets)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
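The argmax tracking above piggybacks on the online-softmax max pass: whenever a block's maximum exceeds the running maximum m, the kernel records the smallest in-block offset attaining it. A plain PyTorch sketch of the same idea (illustration only; the helper name and shapes are placeholders, not kernel code):

import torch

def blockwise_argmax(logits_row: torch.Tensor, block_size: int) -> int:
    """Track a running max and its first index while scanning fixed-size blocks."""
    m = float("-inf")   # running maximum, as in the online-softmax pass
    argmax_idx = 0      # first index at which the running maximum occurs
    n_cols = logits_row.numel()
    for start in range(0, n_cols, block_size):
        block = logits_row[start : start + block_size]
        block_max = block.max().item()
        if block_max > m:
            # first (smallest) offset inside the block where the max occurs
            argmax_idx = start + int((block == block_max).nonzero()[0])
            m = block_max
    return argmax_idx

row = torch.randn(32000)
assert blockwise_argmax(row, block_size=4096) == int(row.argmax())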
@@ -155,58 +181,58 @@
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
- ... (52 removed lines, old 158-209: the previous gradient-computation loop, now guarded by HAS_GRADIENTS; their content is not rendered in this diff view)
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
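Because this kernel overwrites X_ptr with the input gradient during the forward pass, the new HAS_GRADIENTS flag lets inference-only calls (for example, metric evaluation under torch.no_grad) skip the whole gradient loop. A rough eager-mode model of that contract, assuming the backward simply rescales the saved buffer as element_mul_kernel does; function names here are hypothetical:

import torch

# Conceptual model only, not the Triton kernel: when has_gradients is true, the
# forward overwrites the logits buffer with d(loss)/d(logits); backward then only
# scales that saved buffer by the incoming grad_output.
def conceptual_forward(logits: torch.Tensor, target: torch.Tensor, has_gradients: bool) -> torch.Tensor:
    log_probs = torch.log_softmax(logits.detach().float(), dim=-1)
    loss = torch.nn.functional.nll_loss(log_probs, target, reduction="mean")
    if has_gradients:
        grad = log_probs.exp()
        grad[torch.arange(target.numel(), device=logits.device), target] -= 1.0
        logits.data.copy_(grad / target.numel())  # reuse the logits storage for the gradient
    return loss

def conceptual_backward(saved_grad_buffer: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    return saved_grad_buffer * grad_output  # element-wise scale, as element_mul_kernel does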
@@ -254,6 +280,10 @@
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)


 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -272,8 +302,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )

     BT, V = _input.shape
     n_rows = BT
@@ -283,6 +317,9 @@
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -319,6 +356,10 @@
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +370,11 @@
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
@@ -340,11 +383,14 @@
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input


 def cross_entropy_backward(_input, grad_output):
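The reduced token_accuracy is therefore the fraction of non-ignored positions whose argmax logit equals the target, computed without materializing an argmax tensor on the Python side. An eager PyTorch reference of the same quantity, useful for checking; the helper name is illustrative, not part of the package:

import torch

def reference_token_accuracy(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Mean accuracy over non-ignored tokens; mirrors the reduced token_accuracy output."""
    mask = target != ignore_index
    preds = logits.argmax(dim=-1)
    correct = (preds == target) & mask
    return correct.sum().float() / mask.sum().clamp(min=1).float()

logits = torch.randn(8, 128)
target = torch.randint(0, 128, (8,))
target[0] = -100  # ignored position
print(reference_token_accuracy(logits, target))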
@@ -392,6 +438,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -406,12 +453,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
         softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

         Returns:
-        tuple: A tuple with the
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
-        loss, z_loss, _input = cross_entropy_forward(
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
             weight,
@@ -421,29 +471,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             reduction,
             softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy

-        return loss, z_loss
+        return loss, z_loss, token_accuracy

     @staticmethod
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.

         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-        grad_output2 (
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
-            del
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics

         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
@@ -457,4 +513,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
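Putting the new flag together, a minimal usage sketch of the autograd function: the argument order follows the forward signature shown above, but the concrete shapes and default values for the earlier positional arguments are assumptions, and the simpler wrappers in liger_kernel.transformers may be the more convenient entry point.

import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

logits = torch.randn(8, 32000, requires_grad=True)  # (BT, V)
target = torch.randint(0, 32000, (8,))

# forward now returns a third element; token_accuracy is None unless requested
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
    logits,   # _input
    target,   # target
    None,     # weight
    -100,     # ignore_index
    0.0,      # lse_square_scale
    0.0,      # label_smoothing
    "mean",   # reduction
    None,     # softcap
    False,    # return_z_loss
    True,     # return_token_accuracy
)
loss.backward()
print(float(token_accuracy))  # fraction of tokens where argmax(logits) == target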
liger_kernel/ops/dyt.py
CHANGED
@@ -7,8 +7,10 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import infer_device
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -125,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
         NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
     elif device == "xpu":
         NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-
+    elif device == "npu":
+        NUM_SMS = get_npu_multi_processor_count()
     da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
     dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
     db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
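The backward pass sizes its partial-reduction buffers by the number of streaming multiprocessors, so the NPU branch mirrors the existing CUDA/XPU dispatch. A condensed sketch of that shared dispatch pattern; the helper num_sms_for is hypothetical, while get_npu_multi_processor_count comes from liger_kernel.utils as imported above:

import torch

from liger_kernel.utils import get_npu_multi_processor_count

def num_sms_for(device: torch.device) -> int:
    # Hypothetical helper mirroring the per-backend dispatch used when sizing
    # per-SM reduction buffers (as in liger_dyt_bwd above).
    if device.type == "cuda":
        return torch.cuda.get_device_properties(device).multi_processor_count
    if device.type == "xpu":
        return torch.xpu.get_device_properties(device).gpu_subslice_count
    if device.type == "npu":
        return get_npu_multi_processor_count()  # new NPU branch in this release
    raise ValueError(f"unsupported device type: {device.type}")

# e.g. sm_count = num_sms_for(x.device)
#      da = torch.zeros(sm_count, n_cols, dtype=torch.float32, device=x.device)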
liger_kernel/ops/fused_add_rms_norm.py
CHANGED

@@ -9,8 +9,10 @@ from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -293,6 +295,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
         sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
     elif S.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
+    elif S.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()

     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED

@@ -27,10 +27,16 @@ def fused_linear_cross_entropy_forward(
     return_z_loss=False,
     accum_dtype=None,
     use_token_scaling=False,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
     device = _input.device

+    input_requires_grad = _input.requires_grad
+
     # inputs have shape: BT x H
     # materialized activations will have shape: BT x V
     # the increase in memory = BT x V
@@ -49,15 +55,20 @@
     grad_input = torch.zeros_like(_input, device=device)

     # we use fp32 for loss and gradients accumulator
-    if accum_dtype is None:
-        grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
     else:
-        grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
-        grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+        grad_weight = None
+        grad_bias = None

     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None

     # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
     target_mask = target != ignore_index
@@ -123,6 +134,7 @@
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+        token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None

         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -138,6 +150,10 @@
             loss_ptr=loss_1d_slice,
             z_loss_ptr=z_loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
+            token_accuracy_ptr=token_accuracy_1d_slice,
+            token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+            if return_token_accuracy
+            else 0,  # always 1 if accuracy is enabled
             n_cols=V,
             n_non_ignore=total_n_non_ignore,
             sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
@@ -148,8 +164,10 @@
             reduction=reduction,
             softcap=softcap,
             RETURN_Z_LOSS=return_z_loss,
+            RETURN_TOKEN_ACCURACY=return_token_accuracy,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
+            HAS_GRADIENTS=input_requires_grad,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
@@ -163,6 +181,8 @@
         loss_1d[start_idx:end_idx] = loss_1d_slice
         if return_z_loss:
             z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        if return_token_accuracy:
+            token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
         grad_logits_chunk = logits_chunk  # chunk_size x V

         # Apply token scaling to gradients if requested
@@ -171,12 +191,13 @@
             scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
             grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

-        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

-        if grad_weight is not None:
+        if grad_weight is not None and input_requires_grad:
             grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

-        if bias is not None:
+        if bias is not None and input_requires_grad:
             torch.add(
                 input=grad_bias,
                 other=grad_logits_chunk.sum(dim=0),
@@ -193,15 +214,18 @@
         # Return per-token losses
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None

     # Cast back to original dtype
     grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
     grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None

-    return loss, z_loss, grad_input, grad_weight, grad_bias
+    return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias


 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -269,6 +293,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         return_z_loss: bool = False,
         accum_dtype=None,
         use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -292,9 +317,10 @@
         use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
             When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
             Default: False.
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
         """

-        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+        loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
             _input=_input,
             weight=weight,
             target=target,
@@ -308,6 +334,7 @@
             return_z_loss=return_z_loss,
             accum_dtype=accum_dtype,
             use_token_scaling=use_token_scaling,
+            return_token_accuracy=return_token_accuracy,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -316,13 +343,16 @@
             grad_bias.detach() if bias is not None else None,
         )
         ctx.return_z_loss = return_z_loss
-        return loss, z_loss
+        ctx.return_token_accuracy = return_token_accuracy
+        return loss, z_loss, token_accuracy

     @staticmethod
     @amp_custom_bwd
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         if ctx.return_z_loss:
             del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
@@ -341,4 +371,5 @@
             None,
             None,
             None,  # use_token_scaling
+            None,  # return_token_accuracy
         )
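For the fused path, the accuracy comes back as a third output alongside loss and z_loss. A minimal sketch of calling the forward helper directly with keyword arguments; the keyword names follow the forward shown above, the hidden/vocab sizes are placeholders, and the remaining parameters are assumed to keep their library defaults:

import torch

from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward

hidden = torch.randn(16, 4096, requires_grad=True)               # BT x H
lm_head_weight = torch.randn(32000, 4096, requires_grad=True)    # V x H
target = torch.randint(0, 32000, (16,))

# kwargs mirror the forward above; only the flag relevant here is set explicitly
loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
    _input=hidden,
    weight=lm_head_weight,
    target=target,
    return_token_accuracy=True,
)
print(float(loss), float(token_accuracy))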
liger_kernel/ops/geglu.py
CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
liger_kernel/ops/group_norm.py
CHANGED
@@ -6,8 +6,9 @@ import triton.language as tl

 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
liger_kernel/ops/grpo_loss.py
CHANGED
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
     per_token_loss1 = coef_1 * advantage
     per_token_loss2 = coef_2 * advantage
     per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
- ... (removed line, old 131: the previous clipping indicator; its content is not rendered in this diff view)
+    is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+    is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+    is_clipped = is_low_clipped | is_high_clipped

     if BETA != 0.0:
         REF_LOGP += off_b * L + off_l
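The reworked clipping indicator distinguishes the two PPO-style clip directions: clipping from below only matters when the advantage is negative, and clipping from above only when it is positive. An eager PyTorch equivalent of the same masks, for illustration; the helper name and the 0.2 defaults are placeholders for the kernel's EPS_LOW/EPS_HIGH:

import torch

def clip_fraction(coef_1: torch.Tensor, advantage: torch.Tensor,
                  eps_low: float = 0.2, eps_high: float = 0.2) -> torch.Tensor:
    """Fraction of tokens whose importance ratio was actually clipped, per the kernel's definition."""
    is_low_clipped = (coef_1 < 1 - eps_low) & (advantage < 0)
    is_high_clipped = (coef_1 > 1 + eps_high) & (advantage > 0)
    is_clipped = is_low_clipped | is_high_clipped
    return is_clipped.float().mean()

ratio = torch.exp(torch.randn(4, 64) * 0.3)   # importance sampling ratio (coef_1)
advantage = torch.randn(4, 64)
print(float(clip_fraction(ratio, advantage)))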