liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +307 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +63 -0
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +383 -114
- liger_kernel/ops/dyt.py +160 -0
- liger_kernel/ops/experimental/embedding.py +141 -0
- liger_kernel/ops/experimental/mm_int8int2.py +349 -0
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
- liger_kernel/ops/fused_linear_jsd.py +228 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +66 -64
- liger_kernel/ops/group_norm.py +306 -0
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +201 -0
- liger_kernel/ops/kl_div.py +262 -0
- liger_kernel/ops/layer_norm.py +320 -0
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +484 -88
- liger_kernel/ops/rope.py +122 -117
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +68 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +82 -3
- liger_kernel/transformers/__init__.py +218 -6
- liger_kernel/transformers/auto_model.py +38 -0
- liger_kernel/transformers/cross_entropy.py +52 -7
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +26 -0
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +301 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
- liger_kernel/transformers/fused_linear_jsd.py +95 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +6 -7
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +70 -0
- liger_kernel/transformers/kl_div.py +12 -0
- liger_kernel/transformers/layer_norm.py +24 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +261 -0
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +332 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +221 -41
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +145 -0
- liger_kernel/transformers/model/mixtral.py +293 -0
- liger_kernel/transformers/model/mllama.py +269 -0
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +433 -0
- liger_kernel/transformers/model/phi3.py +120 -0
- liger_kernel/transformers/model/qwen2.py +259 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +159 -0
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2816 -21
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +75 -5
- liger_kernel/transformers/rope.py +47 -3
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +62 -6
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -5
- liger_kernel/utils.py +96 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
|
@@ -1,7 +1,27 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
1
5
|
import torch
|
|
2
6
|
import triton
|
|
3
7
|
import triton.language as tl
|
|
4
8
|
|
|
9
|
+
from liger_kernel.ops.utils import compare_version
|
|
10
|
+
from liger_kernel.ops.utils import element_mul_kernel
|
|
11
|
+
from liger_kernel.ops.utils import is_hip
|
|
12
|
+
from liger_kernel.utils import infer_device
|
|
13
|
+
from liger_kernel.utils import is_npu_available
|
|
14
|
+
|
|
15
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
16
|
+
try:
|
|
17
|
+
# typical import path with dispatch available
|
|
18
|
+
from triton.language.extra.libdevice import tanh
|
|
19
|
+
except ModuleNotFoundError:
|
|
20
|
+
# for working with NGC containers
|
|
21
|
+
from triton.language.extra.cuda.libdevice import tanh
|
|
22
|
+
else:
|
|
23
|
+
from triton.language.math import tanh
|
|
24
|
+
|
|
5
25
|
|
|
6
26
|
@triton.jit
|
|
7
27
|
def liger_cross_entropy_kernel(
|
|
@@ -9,12 +29,27 @@ def liger_cross_entropy_kernel(
|
|
|
9
29
|
X_stride,
|
|
10
30
|
Y_ptr,
|
|
11
31
|
Y_stride,
|
|
32
|
+
weight_ptr,
|
|
12
33
|
loss_ptr,
|
|
34
|
+
z_loss_ptr,
|
|
13
35
|
loss_stride,
|
|
36
|
+
token_accuracy_ptr,
|
|
37
|
+
token_accuracy_stride,
|
|
14
38
|
n_cols,
|
|
15
39
|
n_non_ignore,
|
|
40
|
+
sum_non_ignore_weight,
|
|
41
|
+
weight_sum,
|
|
16
42
|
ignore_index,
|
|
43
|
+
lse_square_scale: tl.constexpr,
|
|
44
|
+
label_smoothing: tl.constexpr,
|
|
45
|
+
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
|
|
46
|
+
softcap,
|
|
47
|
+
RETURN_Z_LOSS: tl.constexpr,
|
|
48
|
+
RETURN_TOKEN_ACCURACY: tl.constexpr,
|
|
17
49
|
BLOCK_SIZE: tl.constexpr,
|
|
50
|
+
HAS_WEIGHT: tl.constexpr,
|
|
51
|
+
HAS_SOFTCAPPING: tl.constexpr,
|
|
52
|
+
HAS_GRADIENTS: tl.constexpr,
|
|
18
53
|
):
|
|
19
54
|
"""
|
|
20
55
|
This kernel computes both cross entropy loss and the gradient of the input.
|
|
@@ -25,12 +60,27 @@ def liger_cross_entropy_kernel(
|
|
|
25
60
|
X_stride (int): The stride of the input tensor.
|
|
26
61
|
Y_ptr: Pointer to target tensor.
|
|
27
62
|
Y_stride (int): The stride of the target tensor.
|
|
63
|
+
weight_ptr: Pointer to weight tensor.
|
|
28
64
|
loss_ptr: Pointer to tensor to store the loss.
|
|
65
|
+
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
|
|
29
66
|
loss_stride (int): The stride of the loss tensor.
|
|
67
|
+
token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
|
|
68
|
+
token_accuracy_stride (int): The stride of the token accuracy tensor.
|
|
30
69
|
n_cols (int): The number of columns in the input tensor.
|
|
31
|
-
n_non_ignore (
|
|
70
|
+
n_non_ignore (float): The number of non-ignored elements in the batch.
|
|
71
|
+
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
|
|
72
|
+
weight_sum (float): The sum of weight tensor.
|
|
32
73
|
ignore_index (int): The index to ignore in the target.
|
|
74
|
+
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
|
75
|
+
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
|
|
76
|
+
reduction (str): The string for the reduction to apply
|
|
77
|
+
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
|
|
78
|
+
RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
|
|
79
|
+
RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
|
|
33
80
|
BLOCK_SIZE (int): The block size for Triton operations.
|
|
81
|
+
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
|
|
82
|
+
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
|
|
83
|
+
HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
|
|
34
84
|
"""
|
|
35
85
|
|
|
36
86
|
# https://github.com/triton-lang/triton/issues/1058
|
|
@@ -49,102 +99,325 @@ def liger_cross_entropy_kernel(
|
|
|
49
99
|
for i in range(0, n_cols, BLOCK_SIZE):
|
|
50
100
|
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
|
51
101
|
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
|
|
102
|
+
# For ignored tokens, set token accuracy to 0
|
|
103
|
+
if RETURN_TOKEN_ACCURACY:
|
|
104
|
+
token_accuracy_ptr += program_id * token_accuracy_stride
|
|
105
|
+
tl.store(token_accuracy_ptr, 0.0)
|
|
52
106
|
return
|
|
53
107
|
|
|
54
108
|
loss_ptr += program_id * loss_stride
|
|
109
|
+
if RETURN_Z_LOSS:
|
|
110
|
+
z_loss_ptr += program_id * loss_stride
|
|
111
|
+
if RETURN_TOKEN_ACCURACY:
|
|
112
|
+
token_accuracy_ptr += program_id * token_accuracy_stride
|
|
113
|
+
|
|
114
|
+
if HAS_WEIGHT:
|
|
115
|
+
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
|
|
55
116
|
|
|
56
117
|
# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
|
|
57
118
|
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
|
|
58
119
|
|
|
59
|
-
# 3. [
|
|
120
|
+
# 3. [Online softmax] first pass: find max + sum
|
|
60
121
|
m = float("-inf") # m is the max value. use the notation from the paper
|
|
61
122
|
d = 0.0 # d is the sum. use the notation from the paper
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
123
|
+
argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
|
|
124
|
+
ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
|
|
125
|
+
if HAS_SOFTCAPPING:
|
|
126
|
+
ori_X_y = softcap * tanh(ori_X_y / softcap)
|
|
127
|
+
|
|
128
|
+
# Label smoothing is a general case of normal cross entropy
|
|
129
|
+
# See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
|
|
130
|
+
scaled_x_sum = 0.0
|
|
131
|
+
eps = label_smoothing / n_cols
|
|
65
132
|
|
|
66
133
|
for i in range(0, n_cols, BLOCK_SIZE):
|
|
67
134
|
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
|
68
135
|
X_block = tl.load(
|
|
69
|
-
X_ptr + X_offsets,
|
|
70
|
-
|
|
136
|
+
X_ptr + X_offsets,
|
|
137
|
+
mask=X_offsets < n_cols,
|
|
138
|
+
other=float("-inf"),
|
|
139
|
+
# Ensure float32 precision for softmax calculation
|
|
140
|
+
).cast(tl.float32)
|
|
141
|
+
if HAS_SOFTCAPPING:
|
|
142
|
+
X_block = softcap * tanh(X_block / softcap)
|
|
71
143
|
block_max = tl.max(X_block)
|
|
144
|
+
|
|
145
|
+
# Track argmax for accuracy computation
|
|
146
|
+
if RETURN_TOKEN_ACCURACY and block_max > m:
|
|
147
|
+
# Find the index of the maximum value in this block
|
|
148
|
+
is_max_mask = X_block == block_max
|
|
149
|
+
# Mask out invalid indices with a value larger than n_cols
|
|
150
|
+
masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
|
|
151
|
+
# Get the first (smallest) index where max occurs
|
|
152
|
+
argmax_idx = tl.min(masked_offsets)
|
|
153
|
+
|
|
154
|
+
if label_smoothing > 0:
|
|
155
|
+
# scale X beforehand to avoid overflow
|
|
156
|
+
if HAS_WEIGHT:
|
|
157
|
+
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
|
|
158
|
+
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
|
|
159
|
+
else:
|
|
160
|
+
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
|
|
72
161
|
m_new = tl.maximum(m, block_max)
|
|
73
162
|
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
|
|
74
163
|
m = m_new
|
|
75
164
|
|
|
76
|
-
#
|
|
165
|
+
# log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
|
|
166
|
+
# = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
|
|
167
|
+
# = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
|
|
168
|
+
lse = m + tl.log(d)
|
|
169
|
+
|
|
170
|
+
# 4. [Online Softmax] Second pass: compute gradients
|
|
171
|
+
# For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
|
|
77
172
|
# dx_y = (softmax(x_y) - 1) / N
|
|
78
173
|
# dx_i = softmax(x_i) / N, i != y
|
|
79
|
-
#
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
174
|
+
# For label smoothing:
|
|
175
|
+
# dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
|
|
176
|
+
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
|
|
177
|
+
# = dx_i - (1 - label_smoothing) / N
|
|
178
|
+
# With Z loss:
|
|
179
|
+
# dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
|
|
180
|
+
# dx_y = dx_i - (1 - label_smoothing) / N
|
|
181
|
+
# For 'sum' reduction, no normalization is applied:
|
|
182
|
+
# dx_y = softmax(x_y) - 1
|
|
183
|
+
# dx_i = softmax(x_i), for i ≠ y
|
|
184
|
+
if HAS_GRADIENTS:
|
|
185
|
+
for i in range(0, n_cols, BLOCK_SIZE):
|
|
186
|
+
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
|
187
|
+
X_block = tl.load(
|
|
188
|
+
X_ptr + X_offsets,
|
|
189
|
+
mask=X_offsets < n_cols,
|
|
190
|
+
other=float("-inf"),
|
|
191
|
+
# Ensure float32 precision for softmax calculation
|
|
192
|
+
).cast(tl.float32)
|
|
193
|
+
if HAS_SOFTCAPPING:
|
|
194
|
+
intermediate = tanh(X_block / softcap)
|
|
195
|
+
X_block = softcap * intermediate
|
|
196
|
+
|
|
197
|
+
if not HAS_WEIGHT:
|
|
198
|
+
# softmax(x_i)
|
|
199
|
+
X_block = tl.exp(X_block - m) / d
|
|
200
|
+
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
|
|
201
|
+
X_block += 2 * lse_square_scale * lse * X_block
|
|
202
|
+
# smoothing term
|
|
203
|
+
X_block += -eps
|
|
204
|
+
# special handle dx_y
|
|
205
|
+
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
|
|
206
|
+
# reduction scale
|
|
207
|
+
if reduction == "mean":
|
|
208
|
+
X_block = X_block / n_non_ignore
|
|
209
|
+
else:
|
|
210
|
+
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
|
|
211
|
+
softmax_X = tl.exp(X_block - m) / d
|
|
212
|
+
# derivative of original_loss
|
|
213
|
+
dloss_ori = (1 - label_smoothing) * softmax_X
|
|
214
|
+
# specially handle dx_y
|
|
215
|
+
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
|
|
216
|
+
dloss_ori = dloss_ori * weight_y
|
|
217
|
+
# derivative of smooth_loss
|
|
218
|
+
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
|
|
219
|
+
# derivative of z-loss
|
|
220
|
+
dz_loss = 2 * lse_square_scale * lse * softmax_X
|
|
221
|
+
# reduction scale
|
|
222
|
+
if reduction == "mean":
|
|
223
|
+
dloss_ori = dloss_ori / sum_non_ignore_weight
|
|
224
|
+
dloss_smooth = dloss_smooth / sum_non_ignore_weight
|
|
225
|
+
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
|
|
226
|
+
dz_loss = dz_loss / n_non_ignore
|
|
227
|
+
# derivative of total_loss
|
|
228
|
+
X_block = dloss_ori + dloss_smooth + dz_loss
|
|
229
|
+
|
|
230
|
+
# chain rule softcapping
|
|
231
|
+
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
|
|
232
|
+
if HAS_SOFTCAPPING:
|
|
233
|
+
X_block = X_block * (1 - intermediate * intermediate)
|
|
234
|
+
|
|
235
|
+
tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
|
|
87
236
|
|
|
88
237
|
# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
|
|
89
|
-
#
|
|
238
|
+
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
|
|
90
239
|
tl.debug_barrier()
|
|
91
240
|
|
|
92
241
|
# 5. Calculate the loss
|
|
93
242
|
|
|
94
243
|
# loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
|
|
95
244
|
# = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
|
|
245
|
+
# = X_y - m - log d = X_y - lse
|
|
96
246
|
# sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
|
|
97
247
|
# So we can safely calculate log (softmax(X_y)) without overflow
|
|
98
|
-
loss =
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
248
|
+
loss = lse - ori_X_y
|
|
249
|
+
if HAS_WEIGHT:
|
|
250
|
+
loss = weight_y * loss
|
|
251
|
+
|
|
252
|
+
# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
|
|
253
|
+
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
|
|
254
|
+
# = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
|
|
255
|
+
# By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
|
|
256
|
+
# = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
|
|
257
|
+
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
|
|
258
|
+
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
|
|
259
|
+
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
|
|
260
|
+
if label_smoothing > 0:
|
|
261
|
+
if HAS_WEIGHT:
|
|
262
|
+
smooth_loss = scaled_x_sum + eps * lse * weight_sum
|
|
263
|
+
else:
|
|
264
|
+
smooth_loss = scaled_x_sum + label_smoothing * lse
|
|
265
|
+
loss = loss * (1 - label_smoothing) + smooth_loss
|
|
266
|
+
|
|
267
|
+
# An auxiliary loss, z_loss
|
|
268
|
+
# Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
|
|
269
|
+
z_loss = lse_square_scale * lse * lse
|
|
270
|
+
# Normalize the loss by the number of non-ignored elements if reduction is "mean"
|
|
271
|
+
if reduction == "mean":
|
|
272
|
+
if HAS_WEIGHT:
|
|
273
|
+
loss = loss / sum_non_ignore_weight
|
|
274
|
+
else:
|
|
275
|
+
loss = loss / n_non_ignore
|
|
276
|
+
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
|
|
277
|
+
z_loss = z_loss / n_non_ignore
|
|
278
|
+
loss += z_loss
|
|
103
279
|
|
|
104
280
|
tl.store(loss_ptr, loss)
|
|
105
|
-
|
|
281
|
+
if RETURN_Z_LOSS:
|
|
282
|
+
tl.store(z_loss_ptr, z_loss)
|
|
283
|
+
if RETURN_TOKEN_ACCURACY:
|
|
284
|
+
# Store 1.0 if prediction is correct, 0.0 otherwise
|
|
285
|
+
is_correct = 1.0 if argmax_idx == y else 0.0
|
|
286
|
+
tl.store(token_accuracy_ptr, is_correct)
|
|
106
287
|
|
|
107
288
|
|
|
108
289
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
|
109
290
|
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
|
110
291
|
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
|
111
|
-
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
|
|
292
|
+
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
|
|
112
293
|
|
|
113
294
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
295
|
+
def cross_entropy_forward(
|
|
296
|
+
_input,
|
|
297
|
+
target,
|
|
298
|
+
weight,
|
|
299
|
+
ignore_index,
|
|
300
|
+
lse_square_scale,
|
|
301
|
+
label_smoothing,
|
|
302
|
+
reduction,
|
|
303
|
+
softcap,
|
|
304
|
+
return_z_loss,
|
|
305
|
+
return_token_accuracy=False,
|
|
121
306
|
):
|
|
122
|
-
""
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
307
|
+
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
|
308
|
+
assert isinstance(return_token_accuracy, bool), (
|
|
309
|
+
f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
BT, V = _input.shape
|
|
313
|
+
n_rows = BT
|
|
314
|
+
|
|
315
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
316
|
+
|
|
317
|
+
# unreduced loss
|
|
318
|
+
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
|
|
319
|
+
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
|
|
320
|
+
token_accuracy_1d = (
|
|
321
|
+
torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
target_mask = target != ignore_index
|
|
325
|
+
n_non_ignore = target_mask.sum().item()
|
|
326
|
+
assert (target * target_mask).max() < _input.shape[-1], (
|
|
327
|
+
f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
|
|
328
|
+
)
|
|
329
|
+
assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
|
|
330
|
+
sum_non_ignore_weight = n_non_ignore
|
|
331
|
+
weight_sum = 0.0
|
|
332
|
+
if weight is not None:
|
|
333
|
+
assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
|
|
334
|
+
assert torch.is_floating_point(weight), (
|
|
335
|
+
f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
|
|
336
|
+
)
|
|
337
|
+
sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
|
|
338
|
+
weight_sum = weight.sum().item()
|
|
339
|
+
# ensure weight is contiguous
|
|
340
|
+
if weight.stride(-1) != 1:
|
|
341
|
+
weight = weight.contiguous()
|
|
342
|
+
|
|
343
|
+
# ensure _input and target are contiguous in the last dimension
|
|
344
|
+
if _input.stride(-1) != 1:
|
|
345
|
+
_input = _input.contiguous()
|
|
346
|
+
if target.stride(-1) != 1:
|
|
347
|
+
target = target.contiguous()
|
|
348
|
+
|
|
349
|
+
# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
|
|
350
|
+
liger_cross_entropy_kernel[(n_rows,)](
|
|
351
|
+
X_ptr=_input,
|
|
352
|
+
X_stride=_input.stride(-2),
|
|
353
|
+
Y_ptr=target,
|
|
354
|
+
Y_stride=target.stride(-1), # always 1
|
|
355
|
+
weight_ptr=weight, # dummy if None
|
|
356
|
+
loss_ptr=loss_1d,
|
|
357
|
+
z_loss_ptr=z_loss_1d,
|
|
358
|
+
loss_stride=loss_1d.stride(-1), # always 1
|
|
359
|
+
token_accuracy_ptr=token_accuracy_1d,
|
|
360
|
+
token_accuracy_stride=token_accuracy_1d.stride(-1)
|
|
361
|
+
if return_token_accuracy
|
|
362
|
+
else 0, # always 1 if accuracy is enabled
|
|
363
|
+
n_cols=V,
|
|
364
|
+
n_non_ignore=n_non_ignore,
|
|
365
|
+
sum_non_ignore_weight=sum_non_ignore_weight,
|
|
366
|
+
ignore_index=ignore_index,
|
|
367
|
+
weight_sum=weight_sum,
|
|
368
|
+
lse_square_scale=lse_square_scale,
|
|
369
|
+
label_smoothing=label_smoothing,
|
|
370
|
+
reduction=reduction,
|
|
371
|
+
softcap=softcap,
|
|
372
|
+
RETURN_Z_LOSS=return_z_loss,
|
|
373
|
+
RETURN_TOKEN_ACCURACY=return_token_accuracy,
|
|
374
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
|
375
|
+
HAS_WEIGHT=True if weight is not None else False,
|
|
376
|
+
HAS_SOFTCAPPING=True if softcap is not None else False,
|
|
377
|
+
HAS_GRADIENTS=_input.requires_grad,
|
|
378
|
+
# TODO: 32 seems to give the best performance
|
|
379
|
+
# Performance is quite sensitive to num_warps
|
|
380
|
+
num_warps=32 if not is_hip() else 16,
|
|
381
|
+
)
|
|
382
|
+
|
|
383
|
+
if reduction == "none":
|
|
384
|
+
loss = loss_1d
|
|
385
|
+
z_loss = z_loss_1d if return_z_loss else None
|
|
386
|
+
token_accuracy = token_accuracy_1d if return_token_accuracy else None
|
|
387
|
+
else:
|
|
388
|
+
loss = torch.sum(loss_1d)
|
|
389
|
+
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
|
390
|
+
# For accuracy, we compute the mean across all non-ignored tokens
|
|
391
|
+
token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
|
|
392
|
+
|
|
393
|
+
return loss, z_loss, token_accuracy, _input
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def cross_entropy_backward(_input, grad_output):
|
|
397
|
+
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
|
|
398
|
+
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
|
399
|
+
pass
|
|
400
|
+
# If reduction is 'none'
|
|
401
|
+
elif grad_output.ndim > 0:
|
|
402
|
+
_input = _input * grad_output.unsqueeze(dim=1)
|
|
403
|
+
# If reduction is ['mean', 'sum'], grad_output is just a scalar
|
|
404
|
+
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
|
|
405
|
+
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
|
|
406
|
+
else:
|
|
407
|
+
BT, V = _input.shape
|
|
408
|
+
n_rows = BT
|
|
409
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
139
410
|
|
|
140
|
-
|
|
141
|
-
|
|
411
|
+
element_mul_kernel[(n_rows,)](
|
|
412
|
+
_input,
|
|
413
|
+
_input.stride(-2),
|
|
414
|
+
grad_output,
|
|
415
|
+
V,
|
|
416
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
|
417
|
+
num_warps=32 if not is_hip() else 16,
|
|
418
|
+
)
|
|
142
419
|
|
|
143
|
-
|
|
144
|
-
for i in range(0, n_cols, BLOCK_SIZE):
|
|
145
|
-
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
|
146
|
-
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
|
|
147
|
-
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
|
|
420
|
+
return _input
|
|
148
421
|
|
|
149
422
|
|
|
150
423
|
class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
@@ -154,7 +427,19 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
|
154
427
|
"""
|
|
155
428
|
|
|
156
429
|
@staticmethod
|
|
157
|
-
def forward(
|
|
430
|
+
def forward(
|
|
431
|
+
ctx,
|
|
432
|
+
_input: torch.Tensor,
|
|
433
|
+
target: torch.Tensor,
|
|
434
|
+
weight: Optional[torch.FloatTensor],
|
|
435
|
+
ignore_index: int = -100,
|
|
436
|
+
lse_square_scale: float = 0.0,
|
|
437
|
+
label_smoothing: float = 0.0,
|
|
438
|
+
reduction: str = "mean",
|
|
439
|
+
softcap: Optional[float] = None,
|
|
440
|
+
return_z_loss: bool = False,
|
|
441
|
+
return_token_accuracy: bool = False,
|
|
442
|
+
):
|
|
158
443
|
"""
|
|
159
444
|
The forward pass of the Liger Cross Entropy loss.
|
|
160
445
|
|
|
@@ -162,87 +447,71 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
|
162
447
|
ctx : The context object.
|
|
163
448
|
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
|
|
164
449
|
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
|
|
450
|
+
weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
|
|
165
451
|
ignore_index (int): The index to ignore in the target.
|
|
452
|
+
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
|
|
453
|
+
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
|
454
|
+
reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
|
|
455
|
+
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
|
|
456
|
+
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
|
|
457
|
+
return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
|
|
166
458
|
|
|
167
459
|
Returns:
|
|
168
|
-
|
|
460
|
+
tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
|
|
169
461
|
"""
|
|
170
|
-
|
|
171
|
-
n_rows = BT
|
|
172
|
-
|
|
173
|
-
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
462
|
+
input_requires_grad = _input.requires_grad
|
|
174
463
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
|
|
187
|
-
liger_cross_entropy_kernel[(n_rows,)](
|
|
188
|
-
X_ptr=_input,
|
|
189
|
-
X_stride=_input.stride(-2),
|
|
190
|
-
Y_ptr=target,
|
|
191
|
-
Y_stride=target.stride(-1), # always 1
|
|
192
|
-
loss_ptr=loss_1d,
|
|
193
|
-
loss_stride=loss_1d.stride(-1), # always 1
|
|
194
|
-
n_cols=V,
|
|
195
|
-
n_non_ignore=n_non_ignore,
|
|
196
|
-
ignore_index=ignore_index,
|
|
197
|
-
BLOCK_SIZE=BLOCK_SIZE,
|
|
198
|
-
# TODO: 32 seems to give the best performance
|
|
199
|
-
# Performance is quite sentitive to num_warps
|
|
200
|
-
num_warps=32,
|
|
464
|
+
loss, z_loss, token_accuracy, _input = cross_entropy_forward(
|
|
465
|
+
_input,
|
|
466
|
+
target,
|
|
467
|
+
weight,
|
|
468
|
+
ignore_index,
|
|
469
|
+
lse_square_scale,
|
|
470
|
+
label_smoothing,
|
|
471
|
+
reduction,
|
|
472
|
+
softcap,
|
|
473
|
+
return_z_loss,
|
|
474
|
+
return_token_accuracy,
|
|
201
475
|
)
|
|
202
|
-
|
|
203
|
-
loss = torch.sum(loss_1d) / n_non_ignore
|
|
204
|
-
|
|
205
476
|
# TODO: investigation
|
|
206
477
|
# If we don't detach the _input tensor, the memory will double
|
|
207
478
|
# Not sure why but seems that there will be a time both grad and value exist but in different location
|
|
208
|
-
|
|
209
|
-
|
|
479
|
+
if input_requires_grad:
|
|
480
|
+
ctx.save_for_backward(_input.detach())
|
|
481
|
+
ctx.return_z_loss = return_z_loss
|
|
482
|
+
ctx.return_token_accuracy = return_token_accuracy
|
|
483
|
+
|
|
484
|
+
return loss, z_loss, token_accuracy
|
|
210
485
|
|
|
211
486
|
@staticmethod
|
|
212
|
-
def backward(ctx, grad_output):
|
|
487
|
+
def backward(ctx, grad_output, grad_output2, grad_output3):
|
|
213
488
|
"""
|
|
214
489
|
The backward pass of the Liger Cross Entropy loss.
|
|
215
490
|
|
|
216
491
|
Parameters:
|
|
217
492
|
ctx : The context object with saved tensors.
|
|
218
493
|
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
|
|
219
|
-
|
|
494
|
+
grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
|
|
495
|
+
grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
|
|
220
496
|
Returns:
|
|
221
497
|
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
|
|
222
498
|
"""
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
if
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
|
|
229
|
-
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
|
|
230
|
-
else:
|
|
231
|
-
BT, V = _input.shape
|
|
232
|
-
n_rows = BT
|
|
233
|
-
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
234
|
-
|
|
235
|
-
element_mul[(n_rows,)](
|
|
236
|
-
_input,
|
|
237
|
-
_input.stride(-2),
|
|
238
|
-
grad_output,
|
|
239
|
-
V,
|
|
240
|
-
BLOCK_SIZE=BLOCK_SIZE,
|
|
241
|
-
num_warps=32,
|
|
242
|
-
)
|
|
499
|
+
if ctx.return_z_loss:
|
|
500
|
+
del grad_output2 # z_loss is only for logging
|
|
501
|
+
if ctx.return_token_accuracy:
|
|
502
|
+
del grad_output3 # token_accuracy is only for metrics
|
|
243
503
|
|
|
504
|
+
(_input,) = ctx.saved_tensors
|
|
505
|
+
_input = cross_entropy_backward(_input, grad_output)
|
|
244
506
|
return (
|
|
245
507
|
_input,
|
|
246
508
|
None,
|
|
247
509
|
None,
|
|
510
|
+
None,
|
|
511
|
+
None,
|
|
512
|
+
None,
|
|
513
|
+
None,
|
|
514
|
+
None,
|
|
515
|
+
None,
|
|
516
|
+
None,
|
|
248
517
|
)
|