liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
- liger_kernel/chunked_loss/grpo_loss.py +46 -9
- liger_kernel/chunked_loss/jsd_loss.py +44 -13
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +130 -64
- liger_kernel/ops/dyt.py +5 -4
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +135 -80
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +148 -71
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +65 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +56 -24
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +34 -11
- liger_kernel/transformers/model/gemma3.py +102 -112
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +42 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1423 -100
- liger_kernel/transformers/multi_token_attention.py +2 -2
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +15 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +52 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
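The new `is_npu_available()` guard comes from `liger_kernel/utils.py` (also extended in this release, +52 lines), which is not shown in this diff. A minimal sketch of what such a helper typically looks like, assuming the standard `torch_npu` plugin; the actual implementation may differ:

```python
def is_npu_available() -> bool:
    # Hedged sketch only -- the real helper lives in liger_kernel/utils.py.
    try:
        import torch_npu  # noqa: F401  # Ascend plugin that registers the torch.npu namespace
        import torch

        return torch.npu.is_available()
    except ImportError:
        return False
```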
@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,
@@ -42,9 +45,11 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -59,6 +64,8 @@ def liger_cross_entropy_kernel(
         loss_ptr: Pointer to tensor to store the loss.
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
+        token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+        token_accuracy_stride (int): The stride of the token accuracy tensor.
         n_cols (int): The number of columns in the input tensor.
         n_non_ignore (float): The number of non-ignored elements in the batch.
         sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -68,10 +75,12 @@ def liger_cross_entropy_kernel(
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         reduction (str): The string for the reduction to apply
         softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        RETURN_Z_LOSS (int): The boolean value to decide whether
+        RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+        RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
         BLOCK_SIZE (int): The block size for Triton operations.
         HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+        HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """
 
     # https://github.com/triton-lang/triton/issues/1058
@@ -90,11 +99,17 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return
 
     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride
 
     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -105,6 +120,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
     ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -125,6 +141,19 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            current_block_argmax_idx = tl.min(masked_offsets)
+
+            is_new_max = block_max > m
+            argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
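In eager PyTorch terms, the running argmax above lets the kernel report per-token accuracy without materializing a softmax or running a separate argmax pass over the logits. A plain-PyTorch reference for the quantity the kernel stores per row (illustrative helper, not part of the package):

```python
import torch

def token_accuracy_reference(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # 1.0 where argmax(logits) == target, 0.0 otherwise; ignored tokens count as 0.0.
    pred = logits.argmax(dim=-1)
    correct = (pred == target) & (target != ignore_index)
    return correct.float()
```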
@@ -155,58 +184,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
[52 removed lines (old 158-209) are not captured in this diff view: the previous, unguarded gradient-computation loop, now wrapped in the HAS_GRADIENTS guard below]
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
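The loop above is the same fused backward-in-forward computation as before; the only behavioral change is the `HAS_GRADIENTS` gate, which skips the in-place gradient write entirely when the logits do not require gradients. For the unweighted, no-smoothing, no-z-loss case, the quantity written back to `X_ptr` reduces to the textbook cross-entropy gradient; a hedged eager-mode reference (illustrative helper, not part of the package):

```python
import torch
import torch.nn.functional as F

def ce_grad_reference(logits: torch.Tensor, target: torch.Tensor, n_non_ignore: int) -> torch.Tensor:
    # dX = softmax(X); dX[range, y] -= 1; scaled by 1/n_non_ignore under "mean" reduction.
    dX = F.softmax(logits.float(), dim=-1)
    dX[torch.arange(logits.shape[0]), target] -= 1.0
    return dX / n_non_ignore
```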
@@ -254,12 +283,22 @@ def liger_cross_entropy_kernel(
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+# the best size we found by manually tuning on xpu and npu.
+if infer_device() == "xpu":
+    MAX_FUSED_SIZE = 4096
+elif infer_device() == "npu":
+    MAX_FUSED_SIZE = 2048
+else:
+    MAX_FUSED_SIZE = 65536 // 2
 
 
 def cross_entropy_forward(
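MAX_FUSED_SIZE caps the Triton block size along the vocabulary dimension and is now tuned per device (4096 on XPU, 2048 on NPU, 32768 elsewhere). The expression that consumes it sits outside this hunk; in Liger's kernels it typically looks like the sketch below (assumed, not copied from this diff):

```python
import triton

def choose_block_size(n_cols: int, max_fused_size: int) -> int:
    # Power-of-two block size over the vocab dimension, capped by the per-device limit.
    return min(max_fused_size, triton.next_power_of_2(n_cols))

# e.g. choose_block_size(32000, 2048) -> 2048 on NPU; 32768 with the default cap of 65536 // 2
```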
@@ -272,8 +311,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
 
     BT, V = _input.shape
     n_rows = BT
@@ -283,6 +326,9 @@ def cross_entropy_forward(
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )
 
     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -319,6 +365,10 @@ def cross_entropy_forward(
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,
@@ -329,9 +379,11 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
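`HAS_GRADIENTS=_input.requires_grad` is what switches the gradient loop on and off at compile time: with detached logits, or under `torch.no_grad()`, the kernel becomes a pure forward/metric pass. A quick illustration of the value that flag receives (plain PyTorch, hypothetical tensors):

```python
import torch

logits = torch.randn(4, 128, requires_grad=True)
with torch.no_grad():
    eval_logits = logits * 1.0  # produced under no_grad -> does not require grad

print(logits.requires_grad, eval_logits.requires_grad)  # True False
```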
@@ -340,11 +392,14 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
 
-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input
 
 
 def cross_entropy_backward(_input, grad_output):
@@ -392,6 +447,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -406,12 +462,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
             reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
             softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-            return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+            return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+            return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
 
         Returns:
-            tuple: A tuple with the
+            tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
-        loss, z_loss, _input = cross_entropy_forward(
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
             weight,
@@ -421,29 +480,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             reduction,
             softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy
 
-        return loss, z_loss
+        return loss, z_loss, token_accuracy
 
     @staticmethod
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.
 
         Parameters:
             ctx : The context object with saved tensors.
             grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-            grad_output2 (
+            grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+            grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
             tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
-            del grad_output2
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
 
         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
@@ -457,4 +522,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
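Because `forward` now returns a third value, `LigerCrossEntropyFunction.apply` yields `(loss, z_loss, token_accuracy)` and `backward` accepts (and discards) a matching third gradient. A hedged usage sketch; the positional parameters before `reduction` (weight, ignore_index, lse_square_scale, label_smoothing) follow the usual upstream order, which is assumed here rather than shown in this diff:

```python
import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")

# Assumed positional order: (_input, target, weight, ignore_index, lse_square_scale,
# label_smoothing, reduction, softcap, return_z_loss, return_token_accuracy)
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
    logits, target, None, -100, 0.0, 0.0, "mean", None, False, True
)
loss.backward()  # the extra grads for z_loss / token_accuracy are ignored
print(token_accuracy.item())  # fraction of non-ignored tokens where argmax(logits) == target
```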
liger_kernel/ops/dyt.py
CHANGED
@@ -4,13 +4,13 @@ import torch
 import triton
 import triton.language as tl
 
-from triton.language.extra.libdevice import tanh
-
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import infer_device
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -127,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
         NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
     elif device == "xpu":
         NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-
+    elif device == "npu":
+        NUM_SMS = get_npu_multi_processor_count()
     da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
     dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
     db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
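`get_npu_multi_processor_count()` is a new helper in `liger_kernel/utils.py` and is not shown in this diff. A hedged sketch of what such a helper could look like; the `torch_npu` device-property attribute used below is an assumption, not taken from the package:

```python
def get_npu_multi_processor_count(device=None) -> int:
    # Hedged sketch only; the real implementation lives in liger_kernel/utils.py.
    import torch
    import torch_npu  # noqa: F401  # registers the torch.npu namespace

    props = torch.npu.get_device_properties(device if device is not None else torch.npu.current_device())
    # The attribute name is an assumption; fall back to a conservative value if absent.
    return getattr(props, "multi_processor_count", 1)
```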