liger-kernel-nightly 0.6.2.dev20251011154226__py3-none-any.whl → 0.6.2.dev20251013144132__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/cross_entropy.py +55 -52
- liger_kernel/ops/fused_linear_cross_entropy.py +3 -2
- liger_kernel/transformers/monkey_patch.py +5 -2
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251013144132.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251013144132.dist-info}/RECORD +9 -9
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251013144132.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251013144132.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251013144132.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154226.dist-info → liger_kernel_nightly-0.6.2.dev20251013144132.dist-info}/top_level.txt +0 -0
@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
|
|
45
45
|
BLOCK_SIZE: tl.constexpr,
|
46
46
|
HAS_WEIGHT: tl.constexpr,
|
47
47
|
HAS_SOFTCAPPING: tl.constexpr,
|
48
|
+
HAS_GRADIENTS: tl.constexpr,
|
48
49
|
):
|
49
50
|
"""
|
50
51
|
This kernel computes both cross entropy loss and the gradient of the input.
|
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
|
|
72
73
|
BLOCK_SIZE (int): The block size for Triton operations.
|
73
74
|
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
|
74
75
|
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
|
76
|
+
HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
|
75
77
|
"""
|
76
78
|
|
77
79
|
# https://github.com/triton-lang/triton/issues/1058
|
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
|
|
155
157
|
# For 'sum' reduction, no normalization is applied:
|
156
158
|
# dx_y = softmax(x_y) - 1
|
157
159
|
# dx_i = softmax(x_i), for i ≠ y
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
160
|
+
if HAS_GRADIENTS:
|
161
|
+
for i in range(0, n_cols, BLOCK_SIZE):
|
162
|
+
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
163
|
+
X_block = tl.load(
|
164
|
+
X_ptr + X_offsets,
|
165
|
+
mask=X_offsets < n_cols,
|
166
|
+
other=float("-inf"),
|
167
|
+
# Ensure float32 precision for softmax calculation
|
168
|
+
).cast(tl.float32)
|
169
|
+
if HAS_SOFTCAPPING:
|
170
|
+
intermediate = tanh(X_block / softcap)
|
171
|
+
X_block = softcap * intermediate
|
172
|
+
|
173
|
+
if not HAS_WEIGHT:
|
174
|
+
# softmax(x_i)
|
175
|
+
X_block = tl.exp(X_block - m) / d
|
176
|
+
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
|
177
|
+
X_block += 2 * lse_square_scale * lse * X_block
|
178
|
+
# smoothing term
|
179
|
+
X_block += -eps
|
180
|
+
# special handle dx_y
|
181
|
+
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
|
182
|
+
# reduction scale
|
183
|
+
if reduction == "mean":
|
184
|
+
X_block = X_block / n_non_ignore
|
185
|
+
else:
|
186
|
+
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
|
187
|
+
softmax_X = tl.exp(X_block - m) / d
|
188
|
+
# derivative of original_loss
|
189
|
+
dloss_ori = (1 - label_smoothing) * softmax_X
|
190
|
+
# specially handle dx_y
|
191
|
+
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
|
192
|
+
dloss_ori = dloss_ori * weight_y
|
193
|
+
# derivative of smooth_loss
|
194
|
+
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
|
195
|
+
# derivative of z-loss
|
196
|
+
dz_loss = 2 * lse_square_scale * lse * softmax_X
|
197
|
+
# reduction scale
|
198
|
+
if reduction == "mean":
|
199
|
+
dloss_ori = dloss_ori / sum_non_ignore_weight
|
200
|
+
dloss_smooth = dloss_smooth / sum_non_ignore_weight
|
201
|
+
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
|
202
|
+
dz_loss = dz_loss / n_non_ignore
|
203
|
+
# derivative of total_loss
|
204
|
+
X_block = dloss_ori + dloss_smooth + dz_loss
|
205
|
+
|
206
|
+
# chain rule softcapping
|
207
|
+
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
|
208
|
+
if HAS_SOFTCAPPING:
|
209
|
+
X_block = X_block * (1 - intermediate * intermediate)
|
210
|
+
|
211
|
+
tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
|
210
212
|
|
211
213
|
# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
|
212
214
|
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
|
@@ -332,6 +334,7 @@ def cross_entropy_forward(
|
|
332
334
|
BLOCK_SIZE=BLOCK_SIZE,
|
333
335
|
HAS_WEIGHT=True if weight is not None else False,
|
334
336
|
HAS_SOFTCAPPING=True if softcap is not None else False,
|
337
|
+
HAS_GRADIENTS=_input.requires_grad,
|
335
338
|
# TODO: 32 seems to give the best performance
|
336
339
|
# Performance is quite sensitive to num_warps
|
337
340
|
num_warps=32 if not is_hip() else 16,
|
@@ -150,6 +150,7 @@ def fused_linear_cross_entropy_forward(
|
|
150
150
|
RETURN_Z_LOSS=return_z_loss,
|
151
151
|
HAS_WEIGHT=True if ce_weight is not None else False,
|
152
152
|
HAS_SOFTCAPPING=True if softcap is not None else False,
|
153
|
+
HAS_GRADIENTS=_input.requires_grad,
|
153
154
|
BLOCK_SIZE=BLOCK_SIZE,
|
154
155
|
num_warps=32 if not is_hip() else 16,
|
155
156
|
)
|
@@ -173,10 +174,10 @@ def fused_linear_cross_entropy_forward(
|
|
173
174
|
|
174
175
|
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
|
175
176
|
|
176
|
-
if grad_weight is not None:
|
177
|
+
if grad_weight is not None and _input.requires_grad:
|
177
178
|
grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
|
178
179
|
|
179
|
-
if bias is not None:
|
180
|
+
if bias is not None and _input.requires_grad:
|
180
181
|
torch.add(
|
181
182
|
input=grad_bias,
|
182
183
|
other=grad_logits_chunk.sum(dim=0),
|
@@ -469,7 +469,7 @@ def apply_liger_kernel_to_llama4(
|
|
469
469
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
470
470
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
471
471
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
472
|
-
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is
|
472
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
473
473
|
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
474
474
|
loaded. Default is None.
|
475
475
|
"""
|
@@ -522,7 +522,10 @@ def apply_liger_kernel_to_llama4(
|
|
522
522
|
_patch_rms_norm_module(text_model.norm)
|
523
523
|
for decoder_layer in text_model.layers:
|
524
524
|
if swiglu:
|
525
|
-
|
525
|
+
if decoder_layer.is_moe_layer:
|
526
|
+
_patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
|
527
|
+
else:
|
528
|
+
_patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
|
526
529
|
if rms_norm:
|
527
530
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
528
531
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
@@ -17,10 +17,10 @@ liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsm
|
|
17
17
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
|
18
18
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
|
19
19
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
|
-
liger_kernel/ops/cross_entropy.py,sha256=
|
20
|
+
liger_kernel/ops/cross_entropy.py,sha256=OVkani9JEmCJ8IHN3UgJKzGW7zxJWDwy1EaWVcbShgQ,19517
|
21
21
|
liger_kernel/ops/dyt.py,sha256=gCLz4S8aul8SY9nvIGaoK67aGb7U9MJRQdo3ONqmQYs,5417
|
22
22
|
liger_kernel/ops/fused_add_rms_norm.py,sha256=UBqmlqFCmhSAIpkNKd8rrfXatX7Z4J9bp2dX9A0lrJQ,14017
|
23
|
-
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
|
23
|
+
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=PqIPHU8EjkHRJF6cNZViDucFVOgqo7eanJxB53Npke8,14388
|
24
24
|
liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
|
25
25
|
liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
|
26
26
|
liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
|
@@ -58,7 +58,7 @@ liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCc
|
|
58
58
|
liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
|
59
59
|
liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
|
60
60
|
liger_kernel/transformers/llama4_rope.py,sha256=kS6PSHEwf3dS7hD7C7p8S0geugx2EMCiP0h0F7LsUoY,3639
|
61
|
-
liger_kernel/transformers/monkey_patch.py,sha256=
|
61
|
+
liger_kernel/transformers/monkey_patch.py,sha256=TUmx8aY0lonyThcATirRBdSs7uItVvnBggohjBItBuQ,106060
|
62
62
|
liger_kernel/transformers/multi_token_attention.py,sha256=K3NIY9_5TPgZ4_Rahn0xnkMXxD_fmlJHK4CWGYvGQp0,1752
|
63
63
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
|
64
64
|
liger_kernel/transformers/rms_norm.py,sha256=vkekcvTeWY8vL4H6hg3t0XeY0Ew_3OFMPHuzqlxPPVw,2719
|
@@ -99,9 +99,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
99
99
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
100
100
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
101
101
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
102
|
-
liger_kernel_nightly-0.6.2.
|
103
|
-
liger_kernel_nightly-0.6.2.
|
104
|
-
liger_kernel_nightly-0.6.2.
|
105
|
-
liger_kernel_nightly-0.6.2.
|
106
|
-
liger_kernel_nightly-0.6.2.
|
107
|
-
liger_kernel_nightly-0.6.2.
|
102
|
+
liger_kernel_nightly-0.6.2.dev20251013144132.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
103
|
+
liger_kernel_nightly-0.6.2.dev20251013144132.dist-info/METADATA,sha256=3lZjwj_uIcS1aYE--_B3JuOh95x-txytvJPkdZGO_QA,24777
|
104
|
+
liger_kernel_nightly-0.6.2.dev20251013144132.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
105
|
+
liger_kernel_nightly-0.6.2.dev20251013144132.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
106
|
+
liger_kernel_nightly-0.6.2.dev20251013144132.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
107
|
+
liger_kernel_nightly-0.6.2.dev20251013144132.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|