liger-kernel-nightly 0.6.2.dev20250822000312__py3-none-any.whl → 0.6.2.dev20250822031344__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/fused_linear_cross_entropy.py +41 -1
- liger_kernel/transformers/functional.py +2 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +3 -0
- {liger_kernel_nightly-0.6.2.dev20250822000312.dist-info → liger_kernel_nightly-0.6.2.dev20250822031344.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.2.dev20250822000312.dist-info → liger_kernel_nightly-0.6.2.dev20250822031344.dist-info}/RECORD +9 -9
- {liger_kernel_nightly-0.6.2.dev20250822000312.dist-info → liger_kernel_nightly-0.6.2.dev20250822031344.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250822000312.dist-info → liger_kernel_nightly-0.6.2.dev20250822031344.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250822000312.dist-info → liger_kernel_nightly-0.6.2.dev20250822031344.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20250822000312.dist-info → liger_kernel_nightly-0.6.2.dev20250822031344.dist-info}/top_level.txt +0 -0
@@ -26,6 +26,7 @@ def fused_linear_cross_entropy_forward(
|
|
26
26
|
softcap=None,
|
27
27
|
return_z_loss=False,
|
28
28
|
accum_dtype=None,
|
29
|
+
use_token_scaling=False,
|
29
30
|
):
|
30
31
|
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
31
32
|
device = _input.device
|
@@ -89,6 +90,23 @@ def fused_linear_cross_entropy_forward(
|
|
89
90
|
|
90
91
|
n_rows = logits_chunk.shape[0]
|
91
92
|
|
93
|
+
# Compute predicted probabilities for token scaling if needed
|
94
|
+
if use_token_scaling:
|
95
|
+
# Compute softmax probabilities for scaling
|
96
|
+
# We need to compute this before the cross entropy kernel modifies logits_chunk
|
97
|
+
logits_for_softmax = logits_chunk.detach().clone() # Detach to avoid gradient flow
|
98
|
+
if softcap is not None:
|
99
|
+
logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
|
100
|
+
|
101
|
+
# Compute softmax to get predicted probabilities
|
102
|
+
probs = torch.softmax(logits_for_softmax, dim=-1)
|
103
|
+
|
104
|
+
# Get the predicted probability for each target token
|
105
|
+
pred_probs = torch.gather(probs, -1, target_chunk.unsqueeze(-1)).squeeze(-1)
|
106
|
+
|
107
|
+
# Store the scaling factors
|
108
|
+
scaling_factors = pred_probs.detach() # Detach to ensure no gradient flow
|
109
|
+
|
92
110
|
# unreduced loss
|
93
111
|
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
|
94
112
|
z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
|
@@ -123,11 +141,23 @@ def fused_linear_cross_entropy_forward(
|
|
123
141
|
num_warps=32 if not is_hip() else 16,
|
124
142
|
)
|
125
143
|
|
144
|
+
# Apply token scaling if requested
|
145
|
+
if use_token_scaling:
|
146
|
+
loss_1d_slice = loss_1d_slice * scaling_factors
|
147
|
+
if return_z_loss:
|
148
|
+
z_loss_1d_slice = z_loss_1d_slice * scaling_factors
|
149
|
+
|
126
150
|
loss_1d[start_idx:end_idx] = loss_1d_slice
|
127
151
|
if return_z_loss:
|
128
152
|
z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
|
129
153
|
grad_logits_chunk = logits_chunk # chunk_size x V
|
130
154
|
|
155
|
+
# Apply token scaling to gradients if requested
|
156
|
+
if use_token_scaling:
|
157
|
+
# Expand scaling factors to match gradient dimensions
|
158
|
+
scaling_factors_expanded = scaling_factors.unsqueeze(-1) # chunk_size x 1
|
159
|
+
grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
|
160
|
+
|
131
161
|
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
|
132
162
|
|
133
163
|
if grad_weight is not None:
|
@@ -136,7 +166,7 @@ def fused_linear_cross_entropy_forward(
|
|
136
166
|
if bias is not None:
|
137
167
|
torch.add(
|
138
168
|
input=grad_bias,
|
139
|
-
other=
|
169
|
+
other=grad_logits_chunk.sum(dim=0),
|
140
170
|
out=grad_bias,
|
141
171
|
alpha=1.0,
|
142
172
|
)
|
@@ -146,6 +176,10 @@ def fused_linear_cross_entropy_forward(
|
|
146
176
|
# loss = loss_1d
|
147
177
|
# z_loss = z_loss_1d if return_z_loss else None
|
148
178
|
|
179
|
+
if reduction == "none":
|
180
|
+
# Return per-token losses
|
181
|
+
loss = loss_1d
|
182
|
+
z_loss = z_loss_1d if return_z_loss else None
|
149
183
|
else:
|
150
184
|
loss = torch.sum(loss_1d)
|
151
185
|
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
@@ -221,6 +255,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
221
255
|
softcap=None,
|
222
256
|
return_z_loss: bool = False,
|
223
257
|
accum_dtype=None,
|
258
|
+
use_token_scaling: bool = False,
|
224
259
|
):
|
225
260
|
"""
|
226
261
|
Fusing the last linear layer with cross-entropy loss
|
@@ -241,6 +276,9 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
241
276
|
reduction: reduction to apply
|
242
277
|
accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
|
243
278
|
Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
|
279
|
+
use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
|
280
|
+
When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
|
281
|
+
Default: False.
|
244
282
|
"""
|
245
283
|
|
246
284
|
loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
|
@@ -256,6 +294,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
256
294
|
softcap=softcap,
|
257
295
|
return_z_loss=return_z_loss,
|
258
296
|
accum_dtype=accum_dtype,
|
297
|
+
use_token_scaling=use_token_scaling,
|
259
298
|
)
|
260
299
|
# downcast to dtype and store for backward
|
261
300
|
ctx.save_for_backward(
|
@@ -288,4 +327,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
288
327
|
None,
|
289
328
|
None,
|
290
329
|
None,
|
330
|
+
None, # use_token_scaling
|
291
331
|
)
|
@@ -65,6 +65,7 @@ def liger_fused_linear_cross_entropy(
|
|
65
65
|
softcap: Optional[float] = None,
|
66
66
|
return_z_loss: bool = False,
|
67
67
|
accum_dtype=None,
|
68
|
+
use_token_scaling: bool = False,
|
68
69
|
):
|
69
70
|
loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
|
70
71
|
input,
|
@@ -79,6 +80,7 @@ def liger_fused_linear_cross_entropy(
|
|
79
80
|
softcap,
|
80
81
|
return_z_loss,
|
81
82
|
accum_dtype,
|
83
|
+
use_token_scaling,
|
82
84
|
)
|
83
85
|
if not return_z_loss:
|
84
86
|
return loss
|
@@ -16,6 +16,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
16
16
|
softcap: Optional[float] = None,
|
17
17
|
return_z_loss: bool = False,
|
18
18
|
accum_dtype: Optional[torch.dtype] = None,
|
19
|
+
use_token_scaling: bool = False,
|
19
20
|
):
|
20
21
|
super().__init__()
|
21
22
|
assert (label_smoothing >= 0) and (label_smoothing <= 1), (
|
@@ -34,6 +35,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
34
35
|
self.softcap = softcap
|
35
36
|
self.return_z_loss = return_z_loss
|
36
37
|
self.accum_dtype = accum_dtype
|
38
|
+
self.use_token_scaling = use_token_scaling
|
37
39
|
|
38
40
|
def forward(self, lin_weight, _input, target, bias=None):
|
39
41
|
loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
|
@@ -49,6 +51,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
|
|
49
51
|
self.softcap,
|
50
52
|
self.return_z_loss,
|
51
53
|
self.accum_dtype,
|
54
|
+
self.use_token_scaling,
|
52
55
|
)
|
53
56
|
if not self.return_z_loss:
|
54
57
|
return loss
|
@@ -20,7 +20,7 @@ liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
|
|
20
20
|
liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCRpQKM,19123
|
21
21
|
liger_kernel/ops/dyt.py,sha256=gCLz4S8aul8SY9nvIGaoK67aGb7U9MJRQdo3ONqmQYs,5417
|
22
22
|
liger_kernel/ops/fused_add_rms_norm.py,sha256=UBqmlqFCmhSAIpkNKd8rrfXatX7Z4J9bp2dX9A0lrJQ,14017
|
23
|
-
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
|
23
|
+
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=AIlKMOnM3J7ZeAgPP1uvA3T4OIeRkz6TTr_Lg9XgZGY,13581
|
24
24
|
liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
|
25
25
|
liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
|
26
26
|
liger_kernel/ops/geglu.py,sha256=r0WSq9E93zzynL44Wh8femzOWK07_SseBM_pJUyxT3s,4144
|
@@ -46,9 +46,9 @@ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawX
|
|
46
46
|
liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
|
47
47
|
liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
|
48
48
|
liger_kernel/transformers/fsdp.py,sha256=CUiyjTmjkjY7pLXQv8ly9rnzgXw6529csd9pvtJNMYc,3096
|
49
|
-
liger_kernel/transformers/functional.py,sha256
|
49
|
+
liger_kernel/transformers/functional.py,sha256=-vpz95wbv5wLpInjSG06KNHETsEgKnRIiV-lMYHVs68,7841
|
50
50
|
liger_kernel/transformers/fused_add_rms_norm.py,sha256=7_Bzg-x6lLe6W1qG2DtjDALhEpNZlC6N5GppEs9cTYY,1199
|
51
|
-
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=
|
51
|
+
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=ZMxkiJzGz1KtqgAdsqPODq3bugHBx_80kPYcd5z-xmM,1990
|
52
52
|
liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
|
53
53
|
liger_kernel/transformers/fused_neighborhood_attention.py,sha256=TxYDUAt9B6WSP14aJP66C_2Mbds2sSIPGnamhUSTrC8,7957
|
54
54
|
liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
|
@@ -96,9 +96,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
96
96
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
97
97
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
98
98
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
99
|
-
liger_kernel_nightly-0.6.2.
|
100
|
-
liger_kernel_nightly-0.6.2.
|
101
|
-
liger_kernel_nightly-0.6.2.
|
102
|
-
liger_kernel_nightly-0.6.2.
|
103
|
-
liger_kernel_nightly-0.6.2.
|
104
|
-
liger_kernel_nightly-0.6.2.
|
99
|
+
liger_kernel_nightly-0.6.2.dev20250822031344.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
100
|
+
liger_kernel_nightly-0.6.2.dev20250822031344.dist-info/METADATA,sha256=XSw3SXL9PGPj5eGacLKkUfGpT7I7_QcYmrFdC75Wuck,24504
|
101
|
+
liger_kernel_nightly-0.6.2.dev20250822031344.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
102
|
+
liger_kernel_nightly-0.6.2.dev20250822031344.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
103
|
+
liger_kernel_nightly-0.6.2.dev20250822031344.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
104
|
+
liger_kernel_nightly-0.6.2.dev20250822031344.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|