liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py

@@ -1,11 +1,25 @@
+import operator
+
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device

-
-
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh


 @triton.jit
@@ -14,17 +28,27 @@ def liger_cross_entropy_kernel(
     X_stride,
     Y_ptr,
     Y_stride,
+    weight_ptr,
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
+    sum_non_ignore_weight,
+    weight_sum,
     ignore_index,
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
+    softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
+    HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
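The flags added to the signature (RETURN_TOKEN_ACCURACY, HAS_WEIGHT, HAS_SOFTCAPPING, HAS_GRADIENTS) are declared tl.constexpr, so Triton specializes a separately compiled kernel for each flag combination and the disabled branches are removed at compile time. A toy sketch of that mechanism (hypothetical kernel, not from this package; needs a CUDA device):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def toy_kernel(x_ptr, DOUBLE: tl.constexpr):
        # Because DOUBLE is tl.constexpr, this branch is resolved when the kernel
        # is compiled; launching with DOUBLE=False produces code with no multiply.
        x = tl.load(x_ptr)
        if DOUBLE:
            x = x * 2.0
        tl.store(x_ptr, x)

    x = torch.ones(1, device="cuda")
    toy_kernel[(1,)](x, DOUBLE=True)  # x is now 2.0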
@@ -35,17 +59,27 @@ def liger_cross_entropy_kernel(
     X_stride (int): The stride of the input tensor.
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
+    weight_ptr: Pointer to weight tensor.
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
+    token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+    token_accuracy_stride (int): The stride of the token accuracy tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (
+    n_non_ignore (float): The number of non-ignored elements in the batch.
+    sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
+    weight_sum (float): The sum of weight tensor.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
-    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
+    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
+    HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+    HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """

     # https://github.com/triton-lang/triton/issues/1058
@@ -64,10 +98,20 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return

     loss_ptr += program_id * loss_stride
-
+    if RETURN_Z_LOSS:
+        z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride
+
+    if HAS_WEIGHT:
+        weight_y = tl.load(weight_ptr + y).cast(tl.float32)

     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -75,9 +119,10 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-
-
-
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
+    ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
+    if HAS_SOFTCAPPING:
+        ori_X_y = softcap * tanh(ori_X_y / softcap)

     # Label smoothing is a general case of normal cross entropy
     # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
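For reference, the soft-capping applied to the target logit here (and to every logit in the loops below) is, with cap c = softcap,

$$\mathrm{cap}_c(x) = c\,\tanh\!\left(\frac{x}{c}\right), \qquad \frac{d}{dx}\,\mathrm{cap}_c(x) = 1 - \tanh^2\!\left(\frac{x}{c}\right),$$

which keeps the capped logits in (-c, c); the derivative is the (1 - intermediate * intermediate) factor used for the chain rule in the gradient hunk further down.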
@@ -87,12 +132,31 @@ def liger_cross_entropy_kernel(
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets,
-
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
+        if HAS_SOFTCAPPING:
+            X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY and block_max > m:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            argmax_idx = tl.min(masked_offsets)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
-
+            if HAS_WEIGHT:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
+            else:
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
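The running max m and running sum d updated at the end of this loop follow the online-softmax recurrence from the paper cited above (arXiv:1805.02867): after visiting block b with entries x_i^{(b)},

$$m_b = \max\bigl(m_{b-1},\ \max_i x_i^{(b)}\bigr), \qquad d_b = d_{b-1}\, e^{\,m_{b-1} - m_b} + \sum_i e^{\,x_i^{(b)} - m_b},$$

so after the last block lse = m + log d and softmax(X)_i = e^{x_i - m} / d, which is exactly what the loss and gradient code below consume.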
@@ -116,23 +180,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
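As a sanity check of the unweighted branch above (no label smoothing, no class weight, no z-loss), the closed form it implements for mean reduction is dx_i = softmax(x)_i / n with an extra -1/n at the target column, which matches autograd; a small eager PyTorch sketch with made-up shapes:

    import torch

    BT, V = 4, 10  # toy sizes, assumptions for illustration only
    x = torch.randn(BT, V, requires_grad=True)
    y = torch.randint(0, V, (BT,))

    torch.nn.functional.cross_entropy(x, y, reduction="mean").backward()

    with torch.no_grad():
        dx = torch.softmax(x, dim=-1)    # dx_i = softmax(x_i)
        dx[torch.arange(BT), y] -= 1.0   # dx_y = softmax(x_y) - 1
        dx /= BT                         # reduction == "mean": divide by n_non_ignore

    print(torch.allclose(x.grad, dx, atol=1e-6))  # True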
@@ -146,71 +245,68 @@ def liger_cross_entropy_kernel(
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
     loss = lse - ori_X_y
+    if HAS_WEIGHT:
+        loss = weight_y * loss

     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
     #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
     # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
-    #          = (1 - label_smoothing) * H(q, p) + (
+    #          = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
     # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-
+        if HAS_WEIGHT:
+            smooth_loss = scaled_x_sum + eps * lse * weight_sum
+        else:
+            smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss

     # An auxiliary loss, z_loss
     # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
     z_loss = lse_square_scale * lse * lse
-    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        if HAS_WEIGHT:
+            loss = loss / sum_non_ignore_weight
+        else:
+            loss = loss / n_non_ignore
+        # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
         z_loss = z_loss / n_non_ignore
-
-
-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    X_y = tl.load(X_ptr + y)
-    if reduction == "mean":
-        X_y += -(1 - label_smoothing) / (n_non_ignore)
-    else:
-        X_y += -(1 - label_smoothing)
+    loss += z_loss

     tl.store(loss_ptr, loss)
-    if RETURN_Z_LOSS
+    if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
-
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)


 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
-
-
-_bool_to_return_z_loss = {
-    True: _TRUE.value,
-    False: _FALSE.value,
-}
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning


 def cross_entropy_forward(
     _input,
     target,
+    weight,
     ignore_index,
     lse_square_scale,
     label_smoothing,
     reduction,
+    softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
-
-
-
-
-        return_z_loss = _bool_to_return_z_loss[return_z_loss]
-    else:
-        assert (
-            return_z_loss in _bool_to_return_z_loss
-        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )

     BT, V = _input.shape
     n_rows = BT
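Writing out the label-smoothing identity restored in the comment earlier in this hunk, with smoothing α, ε = α/V, and lse = m + log d:

$$H(q', p) = (1-\alpha)\,H(q, p) + \alpha\,H(u, p) = (1-\alpha)\,(\mathrm{lse} - x_y) + \sum_i(-\varepsilon\, x_i) + \alpha\,\mathrm{lse},$$

since H(u, p) = (1/V) Σ_i (lse - x_i); the Σ_i(-ε x_i) term is the scaled_x_sum accumulated during the first pass, and α·lse is the label_smoothing * lse term added to it.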
@@ -219,12 +315,29 @@ def cross_entropy_forward(

     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
-    if return_z_loss
-
-
-
+    z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )

-
+    target_mask = target != ignore_index
+    n_non_ignore = target_mask.sum().item()
+    assert (target * target_mask).max() < _input.shape[-1], (
+        f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
+    )
+    assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
+    sum_non_ignore_weight = n_non_ignore
+    weight_sum = 0.0
+    if weight is not None:
+        assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+        assert torch.is_floating_point(weight), (
+            f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        )
+        sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        weight_sum = weight.sum().item()
+        # ensure weight is contiguous
+        if weight.stride(-1) != 1:
+            weight = weight.contiguous()

     # ensure _input and target are contiguous in the last dimension
     if _input.stride(-1) != 1:
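A small eager sketch of the bookkeeping added above, with made-up numbers: when a class weight is supplied and reduction is "mean", the denominator becomes the summed weights of the non-ignored targets (mirroring torch.nn.CrossEntropyLoss) rather than the plain token count:

    import torch

    ignore_index = -100
    target = torch.tensor([0, 2, -100, 4])            # one ignored token
    weight = torch.tensor([1.0, 2.0, 0.5, 1.0, 3.0])  # per-class weights, V = 5

    target_mask = target != ignore_index
    n_non_ignore = target_mask.sum().item()           # 3
    sum_non_ignore_weight = (
        torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
    )                                                 # 1.0 + 0.5 + 3.0 = 4.5
    print(n_non_ignore, sum_non_ignore_weight)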
@@ -238,36 +351,55 @@ def cross_entropy_forward(
         X_stride=_input.stride(-2),
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
+        weight_ptr=weight,  # dummy if None
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
+        sum_non_ignore_weight=sum_non_ignore_weight,
         ignore_index=ignore_index,
+        weight_sum=weight_sum,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
+        softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_WEIGHT=True if weight is not None else False,
+        HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
     )

-
-
-    z_loss =
+    if reduction == "none":
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
-
+        loss = torch.sum(loss_1d)
+        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input


 def cross_entropy_backward(_input, grad_output):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         pass
-
+    # If reduction is 'none'
+    elif grad_output.ndim > 0:
+        _input = _input * grad_output.unsqueeze(dim=1)
+    # If reduction is ['mean', 'sum'], grad_output is just a scalar
     # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
     # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
     else:
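The new elif branch in cross_entropy_backward covers reduction="none": the stored gradient has one row of size V per token, shape (BT, V), while grad_output then carries one upstream scalar per token, shape (BT,), so it has to be unsqueezed to broadcast across the vocabulary dimension. A minimal shape sketch (toy sizes are assumptions):

    import torch

    BT, V = 4, 10
    dX = torch.randn(BT, V)        # per-token, per-class gradient stored in _input
    grad_output = torch.randn(BT)  # one upstream gradient per token (reduction="none")

    scaled = dX * grad_output.unsqueeze(dim=1)  # (BT, 1) broadcasts across V
    print(scaled.shape)                         # torch.Size([4, 10])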
@@ -296,13 +428,16 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        _input,
-        target,
-
-
-
-
-
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -311,46 +446,59 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx : The context object.
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+        weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index (int): The index to ignore in the target.
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
-
+        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

         Returns:
-        tuple: A tuple with the
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
-
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
+            weight,
             ignore_index,
             lse_square_scale,
             label_smoothing,
             reduction,
+            softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy

-        return loss, z_loss
+        return loss, z_loss, token_accuracy

     @staticmethod
-    def backward(ctx, grad_output,
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.

         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-        grad_output2 (
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
-            del
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics

         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
@@ -362,4 +510,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
+            None,
         )