liger-kernel-nightly 0.6.2.dev20251011154226__py3-none-any.whl → 0.6.2.dev20251011154427__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/cross_entropy.py +55 -52
- liger_kernel/ops/fused_linear_cross_entropy.py +3 -2
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251011154427.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251011154427.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251011154427.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251011154427.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251011154427.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251011154427.dist-info}/top_level.txt +0 -0
@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
|
|
45
45
|
BLOCK_SIZE: tl.constexpr,
|
46
46
|
HAS_WEIGHT: tl.constexpr,
|
47
47
|
HAS_SOFTCAPPING: tl.constexpr,
|
48
|
+
HAS_GRADIENTS: tl.constexpr,
|
48
49
|
):
|
49
50
|
"""
|
50
51
|
This kernel computes both cross entropy loss and the gradient of the input.
|
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
|
|
72
73
|
BLOCK_SIZE (int): The block size for Triton operations.
|
73
74
|
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
|
74
75
|
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
|
76
|
+
HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
|
75
77
|
"""
|
76
78
|
|
77
79
|
# https://github.com/triton-lang/triton/issues/1058
|
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
|
|
155
157
|
# For 'sum' reduction, no normalization is applied:
|
156
158
|
# dx_y = softmax(x_y) - 1
|
157
159
|
# dx_i = softmax(x_i), for i ≠ y
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
160
|
+
if HAS_GRADIENTS:
|
161
|
+
for i in range(0, n_cols, BLOCK_SIZE):
|
162
|
+
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
163
|
+
X_block = tl.load(
|
164
|
+
X_ptr + X_offsets,
|
165
|
+
mask=X_offsets < n_cols,
|
166
|
+
other=float("-inf"),
|
167
|
+
# Ensure float32 precision for softmax calculation
|
168
|
+
).cast(tl.float32)
|
169
|
+
if HAS_SOFTCAPPING:
|
170
|
+
intermediate = tanh(X_block / softcap)
|
171
|
+
X_block = softcap * intermediate
|
172
|
+
|
173
|
+
if not HAS_WEIGHT:
|
174
|
+
# softmax(x_i)
|
175
|
+
X_block = tl.exp(X_block - m) / d
|
176
|
+
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
|
177
|
+
X_block += 2 * lse_square_scale * lse * X_block
|
178
|
+
# smoothing term
|
179
|
+
X_block += -eps
|
180
|
+
# special handle dx_y
|
181
|
+
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
|
182
|
+
# reduction scale
|
183
|
+
if reduction == "mean":
|
184
|
+
X_block = X_block / n_non_ignore
|
185
|
+
else:
|
186
|
+
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
|
187
|
+
softmax_X = tl.exp(X_block - m) / d
|
188
|
+
# derivative of original_loss
|
189
|
+
dloss_ori = (1 - label_smoothing) * softmax_X
|
190
|
+
# specially handle dx_y
|
191
|
+
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
|
192
|
+
dloss_ori = dloss_ori * weight_y
|
193
|
+
# derivative of smooth_loss
|
194
|
+
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
|
195
|
+
# derivative of z-loss
|
196
|
+
dz_loss = 2 * lse_square_scale * lse * softmax_X
|
197
|
+
# reduction scale
|
198
|
+
if reduction == "mean":
|
199
|
+
dloss_ori = dloss_ori / sum_non_ignore_weight
|
200
|
+
dloss_smooth = dloss_smooth / sum_non_ignore_weight
|
201
|
+
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
|
202
|
+
dz_loss = dz_loss / n_non_ignore
|
203
|
+
# derivative of total_loss
|
204
|
+
X_block = dloss_ori + dloss_smooth + dz_loss
|
205
|
+
|
206
|
+
# chain rule softcapping
|
207
|
+
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
|
208
|
+
if HAS_SOFTCAPPING:
|
209
|
+
X_block = X_block * (1 - intermediate * intermediate)
|
210
|
+
|
211
|
+
tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
|
210
212
|
|
211
213
|
# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
|
212
214
|
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
|
@@ -332,6 +334,7 @@ def cross_entropy_forward(
|
|
332
334
|
BLOCK_SIZE=BLOCK_SIZE,
|
333
335
|
HAS_WEIGHT=True if weight is not None else False,
|
334
336
|
HAS_SOFTCAPPING=True if softcap is not None else False,
|
337
|
+
HAS_GRADIENTS=_input.requires_grad,
|
335
338
|
# TODO: 32 seems to give the best performance
|
336
339
|
# Performance is quite sensitive to num_warps
|
337
340
|
num_warps=32 if not is_hip() else 16,
|
@@ -150,6 +150,7 @@ def fused_linear_cross_entropy_forward(
|
|
150
150
|
RETURN_Z_LOSS=return_z_loss,
|
151
151
|
HAS_WEIGHT=True if ce_weight is not None else False,
|
152
152
|
HAS_SOFTCAPPING=True if softcap is not None else False,
|
153
|
+
HAS_GRADIENTS=_input.requires_grad,
|
153
154
|
BLOCK_SIZE=BLOCK_SIZE,
|
154
155
|
num_warps=32 if not is_hip() else 16,
|
155
156
|
)
|
@@ -173,10 +174,10 @@ def fused_linear_cross_entropy_forward(
|
|
173
174
|
|
174
175
|
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
|
175
176
|
|
176
|
-
if grad_weight is not None:
|
177
|
+
if grad_weight is not None and _input.requires_grad:
|
177
178
|
grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
|
178
179
|
|
179
|
-
if bias is not None:
|
180
|
+
if bias is not None and _input.requires_grad:
|
180
181
|
torch.add(
|
181
182
|
input=grad_bias,
|
182
183
|
other=grad_logits_chunk.sum(dim=0),
|
@@ -17,10 +17,10 @@ liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsm
|
|
17
17
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
|
18
18
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
|
19
19
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
|
-
liger_kernel/ops/cross_entropy.py,sha256=
|
20
|
+
liger_kernel/ops/cross_entropy.py,sha256=OVkani9JEmCJ8IHN3UgJKzGW7zxJWDwy1EaWVcbShgQ,19517
|
21
21
|
liger_kernel/ops/dyt.py,sha256=gCLz4S8aul8SY9nvIGaoK67aGb7U9MJRQdo3ONqmQYs,5417
|
22
22
|
liger_kernel/ops/fused_add_rms_norm.py,sha256=UBqmlqFCmhSAIpkNKd8rrfXatX7Z4J9bp2dX9A0lrJQ,14017
|
23
|
-
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
|
23
|
+
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=PqIPHU8EjkHRJF6cNZViDucFVOgqo7eanJxB53Npke8,14388
|
24
24
|
liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
|
25
25
|
liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
|
26
26
|
liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
|
@@ -99,9 +99,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
99
99
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
100
100
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
101
101
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
102
|
-
liger_kernel_nightly-0.6.2.
|
103
|
-
liger_kernel_nightly-0.6.2.
|
104
|
-
liger_kernel_nightly-0.6.2.
|
105
|
-
liger_kernel_nightly-0.6.2.
|
106
|
-
liger_kernel_nightly-0.6.2.
|
107
|
-
liger_kernel_nightly-0.6.2.
|
102
|
+
liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
103
|
+
liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/METADATA,sha256=3CtD4mdR4zhG-Dj4OQESjqTdQrC1_w-gVsOuzIosGW8,24777
|
104
|
+
liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
105
|
+
liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
106
|
+
liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
107
|
+
liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|