liger-kernel 0.6.2__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
- liger_kernel/chunked_loss/grpo_loss.py +38 -4
- liger_kernel/chunked_loss/jsd_loss.py +5 -2
- liger_kernel/ops/cross_entropy.py +59 -53
- liger_kernel/ops/fused_linear_cross_entropy.py +68 -10
- liger_kernel/ops/layer_norm.py +4 -6
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/transformers/__init__.py +17 -0
- liger_kernel/transformers/functional.py +7 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +5 -1
- liger_kernel/transformers/model/falcon_h1.py +108 -0
- liger_kernel/transformers/model/gemma.py +2 -1
- liger_kernel/transformers/model/gemma2.py +8 -2
- liger_kernel/transformers/model/gemma3.py +27 -2
- liger_kernel/transformers/model/glm4.py +2 -1
- liger_kernel/transformers/model/glm4v.py +3 -2
- liger_kernel/transformers/model/glm4v_moe.py +153 -0
- liger_kernel/transformers/model/internvl.py +150 -0
- liger_kernel/transformers/model/llama.py +2 -1
- liger_kernel/transformers/model/llama4.py +2 -1
- liger_kernel/transformers/model/llava.py +6 -2
- liger_kernel/transformers/model/loss_utils.py +1 -0
- liger_kernel/transformers/model/mistral.py +2 -1
- liger_kernel/transformers/model/mixtral.py +8 -2
- liger_kernel/transformers/model/mllama.py +2 -1
- liger_kernel/transformers/model/olmo2.py +2 -1
- liger_kernel/transformers/model/paligemma.py +19 -0
- liger_kernel/transformers/model/phi3.py +2 -1
- liger_kernel/transformers/model/qwen2.py +2 -1
- liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
- liger_kernel/transformers/model/qwen2_vl.py +7 -2
- liger_kernel/transformers/model/qwen3.py +2 -1
- liger_kernel/transformers/model/qwen3_moe.py +8 -2
- liger_kernel/transformers/model/qwen3_next.py +134 -0
- liger_kernel/transformers/model/smollm3.py +2 -1
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +452 -3
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +13 -10
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +46 -39
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.2.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_ppo.py
CHANGED

```diff
@@ -34,6 +34,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         beta=0.04,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
@@ -92,6 +93,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,
@@ -261,6 +263,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         beta=0.04,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,
@@ -292,6 +295,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
         )

         return chunk_loss, chunk_metrics
```
liger_kernel/chunked_loss/grpo_loss.py
CHANGED

```diff
@@ -31,6 +31,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
         max_completion_length=None,  # Required for dr_grpo
+        importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -50,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):

         # Compute policy gradient loss with importance sampling ratio
         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-
+        log_ratio = per_token_logps - old_per_token_logps
+
+        if importance_sampling_level == "token":
+            log_importance_weights = log_ratio
+        elif importance_sampling_level == "sequence":
+            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            raise ValueError(
+                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                "and 'sequence'."
+            )
+
+        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        coef_1 = torch.exp(log_importance_weights)
         coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
@@ -85,9 +101,19 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         metrics = []
         if beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-
-
-
+
+        # Adjust clipping metric calculation based on importance sampling level
+        if importance_sampling_level == "token":
+            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+            )
+        else:  # sequence level
+            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+            )
+            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics

@@ -111,6 +137,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_high=0.2,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=True,
@@ -132,6 +159,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
            beta (float): Weight for the KL penalty
            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+           importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
            temperature (float): Temperature for the logits
            compiled (bool): Whether to use torch compile
            use_ref_model (bool): Whether to use a reference model
@@ -162,6 +190,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             compiled=compiled,
             use_ref_model=use_ref_model,
             chunk_size=chunk_size,
+            importance_sampling_level=importance_sampling_level,
         )

     @staticmethod
@@ -187,6 +216,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None,  # grad_epsilon_high
             None,  # grad_loss_type (string, not differentiable)
             None,  # grad_max_completion_length (int, not differentiable)
+            None,  # grad_importance_sampling_level (string, not differentiable)
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
@@ -207,6 +237,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         epsilon_high: float = 0.2,
         loss_type: str = "bnpo",
         max_completion_length: Optional[int] = None,
+        importance_sampling_level: str = "token",
         temperature: float = 1.0,
     ):
         """
@@ -219,6 +250,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
            epsilon_high (float): Upper bound for the importance sampling ratio.
            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+           importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
            temperature (float): Temperature for the logits.
        """
        super().__init__()
@@ -230,6 +262,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         self.epsilon_high = epsilon_high
         self.loss_type = loss_type
         self.max_completion_length = max_completion_length
+        self.importance_sampling_level = importance_sampling_level
         self.temperature = temperature

     def forward(
@@ -263,6 +296,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.epsilon_high,
             self.loss_type,
             self.max_completion_length,
+            self.importance_sampling_level,
             self.temperature,
             self.compiled,
             self.use_ref_model,
```
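The new `importance_sampling_level` option adds GSPO-style sequence-level importance sampling next to the existing per-token ratio. Below is a minimal standalone sketch of the two levels in plain PyTorch, mirroring the diff above rather than calling the library's internal API:

```python
import torch

B, T = 2, 5
per_token_logps = torch.randn(B, T)
old_per_token_logps = torch.randn(B, T)
attention_mask = torch.ones(B, T)

log_ratio = per_token_logps - old_per_token_logps

# "token": one importance weight per token, shape (B, T)
token_weights = torch.exp(log_ratio)

# "sequence": masked mean of the log-ratio per sequence, then unsqueezed to shape (B, 1)
seq_log_w = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
sequence_weights = torch.exp(seq_log_w.unsqueeze(-1))
```

At the module level the option is simply a constructor argument, e.g. `LigerFusedLinearGRPOLoss(importance_sampling_level="sequence")`; the default `"token"` keeps the 0.6.2 behavior.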
liger_kernel/chunked_loss/jsd_loss.py
CHANGED

```diff
@@ -1,3 +1,5 @@
+import math
+
 import torch
 import torch.nn.functional as F

@@ -25,8 +27,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
         else:
             # Compute probabilities (only required for mean calculation)
-
-
+            log_mean_probs = torch.logsumexp(
+                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+            )

             student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
             teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
```
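The new `log_mean_probs` forms the log of the beta-weighted mixture, log(beta * p_teacher + (1 - beta) * p_student), entirely in log space via `logsumexp`, which is the numerically stable way to compute it. A quick standalone check in plain PyTorch (not the library's API) that the two forms agree:

```python
import math

import torch

beta = 0.5
student_log_probs = torch.log_softmax(torch.randn(4, 16), dim=-1)
teacher_log_probs = torch.log_softmax(torch.randn(4, 16), dim=-1)

# Stable log-space mixture, exactly as in the diff above
log_mean_probs = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)

# Naive mixture in probability space for comparison
naive = torch.log((1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp())
assert torch.allclose(log_mean_probs, naive, atol=1e-6)
```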
liger_kernel/ops/cross_entropy.py
CHANGED

```diff
@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
         BLOCK_SIZE (int): The block size for Triton operations.
         HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+        HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """

     # https://github.com/triton-lang/triton/issues/1058
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
- ... (52 removed lines, not rendered in this diff view: the previous unguarded gradient-computation loop, re-added below under "if HAS_GRADIENTS:")
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -332,6 +334,7 @@ def cross_entropy_forward(
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
@@ -411,6 +414,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         Returns:
             tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
         """
+        input_requires_grad = _input.requires_grad
+
         loss, z_loss, _input = cross_entropy_forward(
             _input,
             target,
@@ -425,7 +430,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss

         return loss, z_loss
```
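With `HAS_GRADIENTS`, the kernel runs its in-place gradient pass only when the input actually requires gradients, and `ctx.save_for_backward` is skipped otherwise, which saves work and memory for forward-only calls. A hedged sketch of the effect, assuming the public `LigerCrossEntropyLoss` wrapper follows the `torch.nn.CrossEntropyLoss` calling convention:

```python
import torch

from liger_kernel.transformers import LigerCrossEntropyLoss  # assumed public wrapper

ce = LigerCrossEntropyLoss()
logits = torch.randn(8, 32000, device="cuda")           # evaluation logits, no requires_grad
target = torch.randint(0, 32000, (8,), device="cuda")

with torch.no_grad():
    eval_loss = ce(logits, target)  # takes the HAS_GRADIENTS=False path: loss only, no in-place gradient
```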
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED

```diff
@@ -26,10 +26,13 @@ def fused_linear_cross_entropy_forward(
     softcap=None,
     return_z_loss=False,
     accum_dtype=None,
+    use_token_scaling=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
     device = _input.device

+    input_requires_grad = _input.requires_grad
+
     # inputs have shape: BT x H
     # materialized activations will have shape: BT x V
     # the increase in memory = BT x V
@@ -48,12 +51,13 @@ def fused_linear_cross_entropy_forward(
     grad_input = torch.zeros_like(_input, device=device)

     # we use fp32 for loss and gradients accumulator
-    if
-
-
-
-
-
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None

     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
@@ -89,6 +93,36 @@ def fused_linear_cross_entropy_forward(

         n_rows = logits_chunk.shape[0]

+        # Compute predicted probabilities for token scaling if needed
+        if use_token_scaling:
+            # Compute softmax probabilities for scaling
+            # We need to compute this before the cross entropy kernel modifies logits_chunk
+            logits_for_softmax = logits_chunk.detach().clone()  # Detach to avoid gradient flow
+            if softcap is not None:
+                logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
+
+            # Compute softmax to get predicted probabilities
+            probs = torch.softmax(logits_for_softmax, dim=-1)
+
+            # Get predicted probabilities for token scaling, handling ignored targets
+            valid_target_mask = target_chunk != ignore_index
+            valid_targets = target_chunk[valid_target_mask]
+
+            if len(valid_targets) > 0:
+                # Gather probabilities only for valid targets
+                valid_probs = probs[valid_target_mask]
+                pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
+
+                # Create full tensor with zeros for ignored targets
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+                pred_probs[valid_target_mask] = pred_probs_valid
+            else:
+                # All targets are ignored
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+
+            # Store the scaling factors
+            scaling_factors = pred_probs.detach()  # Detach to ensure no gradient flow
+
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
@@ -119,24 +153,38 @@ def fused_linear_cross_entropy_forward(
             RETURN_Z_LOSS=return_z_loss,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
+            HAS_GRADIENTS=input_requires_grad,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )

+        # Apply token scaling if requested
+        if use_token_scaling:
+            loss_1d_slice = loss_1d_slice * scaling_factors
+            if return_z_loss:
+                z_loss_1d_slice = z_loss_1d_slice * scaling_factors
+
         loss_1d[start_idx:end_idx] = loss_1d_slice
         if return_z_loss:
             z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
         grad_logits_chunk = logits_chunk  # chunk_size x V

-
+        # Apply token scaling to gradients if requested
+        if use_token_scaling:
+            # Expand scaling factors to match gradient dimensions
+            scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
+            grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

-        if
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+
+        if grad_weight is not None and input_requires_grad:
             grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

-        if bias is not None:
+        if bias is not None and input_requires_grad:
             torch.add(
                 input=grad_bias,
-                other=
+                other=grad_logits_chunk.sum(dim=0),
                 out=grad_bias,
                 alpha=1.0,
             )
@@ -146,6 +194,10 @@ def fused_linear_cross_entropy_forward(
     # loss = loss_1d
     # z_loss = z_loss_1d if return_z_loss else None

+    if reduction == "none":
+        # Return per-token losses
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
@@ -221,6 +273,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         softcap=None,
         return_z_loss: bool = False,
         accum_dtype=None,
+        use_token_scaling: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -241,6 +294,9 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         reduction: reduction to apply
         accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
             Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
+        use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
+            When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
+            Default: False.
         """

         loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
@@ -256,6 +312,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             softcap=softcap,
             return_z_loss=return_z_loss,
             accum_dtype=accum_dtype,
+            use_token_scaling=use_token_scaling,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -288,4 +345,5 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,  # use_token_scaling
         )
```
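`use_token_scaling` multiplies each token's loss (and the corresponding gradient) by the model's detached predicted probability for that token's true class. A standalone sketch of the same idea in plain PyTorch, not the fused kernel itself:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)
target = torch.randint(0, 10, (4,))

per_token_loss = F.cross_entropy(logits, target, reduction="none")

# Detached probability of the true class for each token; these act as fixed scaling factors.
pred_probs = torch.softmax(logits.detach(), dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1)

scaled_loss = (per_token_loss * pred_probs).sum()
scaled_loss.backward()
```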
liger_kernel/ops/layer_norm.py
CHANGED
```diff
@@ -63,12 +63,11 @@ def _layer_norm_forward_kernel(
     X_f32 = X_row.to(tl.float32)

     # Compute statistics in fp32 for numerical stability
-
-    mean = tl.sum(X_f32, axis=0) / n_cols_f32
+    mean = tl.sum(X_f32, axis=0) / n_cols
     X_centered = X_f32 - mean
     # Apply mask to variance calculation to exclude contributions from masked elements
     X_centered_masked = tl.where(mask, X_centered, 0.0)
-    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
     rstd = rsqrt(var + eps)

     # Store statistics (convert back to original dtype only once)
@@ -113,7 +112,6 @@ def _layer_norm_backward_kernel(
     # Pre-load weights once (same optimization as forward pass)
     w = tl.load(W_ptr + cols, mask=mask, other=0.0)
     w_f32 = w.to(tl.float32)
-    n_cols_f32 = n_cols.to(tl.float32)

     # Calculate pointers for this specific row
     row_X_ptr = X_ptr + row_idx * stride_x
@@ -137,8 +135,8 @@ def _layer_norm_backward_kernel(
     # Compute backward pass for this row
     x_hat = (x_f32 - mean_f32) * rstd_f32
     wdy = w_f32 * dy_f32
-    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
-    c2 = tl.sum(wdy, axis=0) / n_cols_f32
+    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+    c2 = tl.sum(wdy, axis=0) / n_cols
     dx = (wdy - (x_hat * c1 + c2)) * rstd_f32

     # Store input gradient
```