liger-kernel 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
- liger_kernel/chunked_loss/grpo_loss.py +38 -4
- liger_kernel/chunked_loss/jsd_loss.py +5 -2
- liger_kernel/ops/cross_entropy.py +59 -53
- liger_kernel/ops/fused_linear_cross_entropy.py +83 -17
- liger_kernel/ops/layer_norm.py +4 -6
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/transformers/__init__.py +32 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/functional.py +9 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +108 -0
- liger_kernel/transformers/model/gemma.py +2 -1
- liger_kernel/transformers/model/gemma2.py +8 -2
- liger_kernel/transformers/model/gemma3.py +27 -2
- liger_kernel/transformers/model/glm4.py +2 -1
- liger_kernel/transformers/model/glm4v.py +151 -0
- liger_kernel/transformers/model/glm4v_moe.py +153 -0
- liger_kernel/transformers/model/internvl.py +150 -0
- liger_kernel/transformers/model/llama.py +2 -1
- liger_kernel/transformers/model/llama4.py +2 -1
- liger_kernel/transformers/model/llava.py +6 -2
- liger_kernel/transformers/model/loss_utils.py +3 -0
- liger_kernel/transformers/model/mistral.py +2 -1
- liger_kernel/transformers/model/mixtral.py +8 -2
- liger_kernel/transformers/model/mllama.py +6 -3
- liger_kernel/transformers/model/olmo2.py +2 -1
- liger_kernel/transformers/model/paligemma.py +19 -0
- liger_kernel/transformers/model/phi3.py +10 -160
- liger_kernel/transformers/model/qwen2.py +2 -1
- liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
- liger_kernel/transformers/model/qwen2_vl.py +7 -2
- liger_kernel/transformers/model/qwen3.py +2 -1
- liger_kernel/transformers/model/qwen3_moe.py +8 -2
- liger_kernel/transformers/model/qwen3_next.py +134 -0
- liger_kernel/transformers/model/smollm3.py +2 -1
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +552 -23
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +14 -11
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +50 -39
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0

liger_kernel/chunked_loss/dpo_loss.py
CHANGED

@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         ref_chosen_logps=None,
         ref_rejected_logps=None,
         beta=0.1,
+        loss_type="sigmoid",
     ):
         """
         Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         chosen_rewards = beta * chosen_logratios
         rejected_rewards = beta * rejected_logratios

-
-
+        if loss_type == "sigmoid":
+            logits_diff = beta * (chosen_logratios - rejected_logratios)
+            loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_zero":
+            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are better than your model's default output
+            losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+            losses_rejected = F.sigmoid(beta * rejected_logratios)
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_down":
+            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are worse than your model's default output.
+            # Decrease chosen likelihood and decrease rejected likelihood more
+            losses_chosen = F.sigmoid(beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "sppo_hard":
+            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+            # set to 1 for the winner and 0 for the loser.
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "nca_pair":
+            losses = (
+                -F.logsigmoid(chosen_rewards)
+                - 0.5 * F.logsigmoid(-chosen_rewards)
+                - 0.5 * F.logsigmoid(-rejected_rewards)
+            )
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        else:
+            raise ValueError(
+                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+            )
+
         return loss, chosen_rewards, rejected_rewards

     @classmethod
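
The new `loss_type` branches above all reduce to simple arithmetic on the chosen/rejected log-ratios (policy log-prob minus reference log-prob). A minimal standalone sketch of the per-pair formulas with toy tensors, not the package API; `sppo_hard` is omitted because it operates on the raw policy and reference log-probs rather than on the ratios:

```python
import torch
import torch.nn.functional as F


def preference_loss(chosen_logratios, rejected_logratios, beta=0.1, loss_type="sigmoid"):
    """Per-pair DPO-style losses on precomputed log-ratios (illustrative sketch)."""
    chosen_rewards = beta * chosen_logratios
    rejected_rewards = beta * rejected_logratios
    if loss_type == "sigmoid":
        return -F.logsigmoid(beta * (chosen_logratios - rejected_logratios))
    if loss_type == "apo_zero":
        # push chosen likelihood up, rejected likelihood down
        return (1 - torch.sigmoid(chosen_rewards)) + torch.sigmoid(rejected_rewards)
    if loss_type == "apo_down":
        # push chosen down, and rejected down even more
        return torch.sigmoid(chosen_rewards) + (
            1 - torch.sigmoid(beta * (chosen_logratios - rejected_logratios))
        )
    if loss_type == "nca_pair":
        return (
            -F.logsigmoid(chosen_rewards)
            - 0.5 * F.logsigmoid(-chosen_rewards)
            - 0.5 * F.logsigmoid(-rejected_rewards)
        )
    raise ValueError(f"Unsupported loss_type: {loss_type}")


chosen = torch.randn(4)    # log pi(chosen) - log pi_ref(chosen), one value per pair
rejected = torch.randn(4)  # log pi(rejected) - log pi_ref(rejected)
for lt in ["sigmoid", "apo_zero", "apo_down", "nca_pair"]:
    print(lt, preference_loss(chosen, rejected, loss_type=lt).mean().item())
```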

@@ -70,6 +113,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         use_ref_model=True,
         average_log_prob=False,
         chunk_size=1,
+        loss_type="sigmoid",
     ):
         """
         Fused linear layer with DPO loss.

@@ -108,12 +152,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_bias=ref_bias,
             average_log_prob=average_log_prob,
             chunk_size=chunk_size,
+            loss_type=loss_type,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None, None


 class LigerFusedLinearDPOLoss(torch.nn.Module):

@@ -130,6 +175,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         use_ref_model: bool = True,
         average_log_prob: bool = False,
         chunk_size: int = 1,
+        loss_type: str = "sigmoid",
     ):
         """
         Args:

@@ -149,6 +195,10 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.use_ref_model = use_ref_model
         self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
+        self.loss_type = loss_type
+        supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+        if self.loss_type not in supported_loss_types:
+            raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")

     def forward(
         self,

@@ -175,4 +225,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.use_ref_model,
             self.average_log_prob,
             self.chunk_size,
+            self.loss_type,
         )

liger_kernel/chunked_loss/fused_linear_ppo.py
CHANGED

@@ -34,6 +34,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         beta=0.04,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=False,

@@ -92,6 +93,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,

@@ -261,6 +263,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         beta=0.04,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,

@@ -292,6 +295,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
         )

         return chunk_loss, chunk_metrics

liger_kernel/chunked_loss/grpo_loss.py
CHANGED

@@ -31,6 +31,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
         max_completion_length=None,  # Required for dr_grpo
+        importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""

@@ -50,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):

         # Compute policy gradient loss with importance sampling ratio
         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-
+        log_ratio = per_token_logps - old_per_token_logps
+
+        if importance_sampling_level == "token":
+            log_importance_weights = log_ratio
+        elif importance_sampling_level == "sequence":
+            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            raise ValueError(
+                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                "and 'sequence'."
+            )
+
+        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        coef_1 = torch.exp(log_importance_weights)
         coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
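
The `"sequence"` branch above is what enables GSPO-style updates: rather than one importance ratio per token, the per-token log-ratios are averaged over each sequence's valid tokens, giving a single ratio per sequence that is then broadcast against the token dimension. A minimal sketch of just that step, assuming `(B, T)` toy tensors (illustrative, not the package API):

```python
import torch

B, T = 2, 5
per_token_logps = torch.randn(B, T)
old_per_token_logps = torch.randn(B, T)
attention_mask = torch.ones(B, T)

log_ratio = per_token_logps - old_per_token_logps

# "token": one importance weight per token, shape (B, T)
token_weights = torch.exp(log_ratio)

# "sequence": average the log-ratio over valid tokens, one weight per sequence, shape (B, 1)
seq_log_ratio = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
sequence_weights = torch.exp(seq_log_ratio).unsqueeze(-1)

print(token_weights.shape, sequence_weights.shape)  # torch.Size([2, 5]) torch.Size([2, 1])
```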

@@ -85,9 +101,19 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         metrics = []
         if beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-
-
-
+
+        # Adjust clipping metric calculation based on importance sampling level
+        if importance_sampling_level == "token":
+            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+            )
+        else:  # sequence level
+            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+            )
+            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics

@@ -111,6 +137,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_high=0.2,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=True,

@@ -132,6 +159,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             beta (float): Weight for the KL penalty
             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model

@@ -162,6 +190,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             compiled=compiled,
             use_ref_model=use_ref_model,
             chunk_size=chunk_size,
+            importance_sampling_level=importance_sampling_level,
         )

     @staticmethod

@@ -187,6 +216,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None,  # grad_epsilon_high
             None,  # grad_loss_type (string, not differentiable)
             None,  # grad_max_completion_length (int, not differentiable)
+            None,  # grad_importance_sampling_level (string, not differentiable)
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model

@@ -207,6 +237,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         epsilon_high: float = 0.2,
         loss_type: str = "bnpo",
         max_completion_length: Optional[int] = None,
+        importance_sampling_level: str = "token",
         temperature: float = 1.0,
     ):
         """

@@ -219,6 +250,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             epsilon_high (float): Upper bound for the importance sampling ratio.
             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
         """
         super().__init__()

@@ -230,6 +262,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         self.epsilon_high = epsilon_high
         self.loss_type = loss_type
         self.max_completion_length = max_completion_length
+        self.importance_sampling_level = importance_sampling_level
         self.temperature = temperature

     def forward(

@@ -263,6 +296,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.epsilon_high,
             self.loss_type,
             self.max_completion_length,
+            self.importance_sampling_level,
             self.temperature,
             self.compiled,
             self.use_ref_model,

liger_kernel/chunked_loss/jsd_loss.py
CHANGED

@@ -1,3 +1,5 @@
+import math
+
 import torch
 import torch.nn.functional as F

@@ -25,8 +27,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
         else:
             # Compute probabilities (only required for mean calculation)
-
-
+            log_mean_probs = torch.logsumexp(
+                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+            )

             student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
             teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
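
Computing the β-mixture of the student and teacher distributions in log space via `logsumexp` (log m = logsumexp(log p_s + log(1 − β), log p_t + log β)) avoids exponentiating to probabilities and is numerically safer than `log((1 − β)·p_s + β·p_t)`. A small sketch with toy log-probabilities; the final (1 − β)/β weighting of the two KL terms is a common convention and an assumption here, not something shown in this hunk:

```python
import math

import torch
import torch.nn.functional as F

beta = 0.5
student_log_probs = torch.log_softmax(torch.randn(4, 10), dim=-1)
teacher_log_probs = torch.log_softmax(torch.randn(4, 10), dim=-1)

# log of the mixture m = (1 - beta) * p_student + beta * p_teacher, without leaving log space
log_mean_probs = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)

# KL(student || m) and KL(teacher || m); F.kl_div(input, target, log_target=True) computes KL(target || input)
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)

# assumed weighting for the generalized JSD
jsd = (1 - beta) * student_kl + beta * teacher_kl
print(jsd.item())
```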

liger_kernel/ops/cross_entropy.py
CHANGED

@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.

@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
         BLOCK_SIZE (int): The block size for Triton operations.
         HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+        HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
    """

    # https://github.com/triton-lang/triton/issues/1058

@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-    (52 removed lines; their contents are not shown in this diff view)
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
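
The gradient that this loop writes back into `X_ptr` is the usual cross-entropy gradient, `dx_i = softmax(x_i)` for i ≠ y and `dx_y = softmax(x_y) − 1` (before label smoothing, z-loss, and reduction scaling are folded in); `HAS_GRADIENTS` simply skips the whole loop when the input does not require grad. A quick PyTorch check of that identity for the plain 'sum'-reduction case (illustrative only):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(3, 7, requires_grad=True)
target = torch.randint(0, 7, (3,))

F.cross_entropy(logits, target, reduction="sum").backward()

# dx_i = softmax(x_i) for i != y, dx_y = softmax(x_y) - 1
expected = torch.softmax(logits.detach(), dim=-1)
expected[torch.arange(3), target] -= 1.0
print(torch.allclose(logits.grad, expected, atol=1e-6))  # True
```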

@@ -332,6 +334,7 @@ def cross_entropy_forward(
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,

@@ -411,6 +414,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         Returns:
             tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
         """
+        input_requires_grad = _input.requires_grad
+
         loss, z_loss, _input = cross_entropy_forward(
             _input,
             target,

@@ -425,7 +430,8 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss

         return loss, z_loss

liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED

@@ -25,10 +25,14 @@ def fused_linear_cross_entropy_forward(
     reduction="mean",
     softcap=None,
     return_z_loss=False,
+    accum_dtype=None,
+    use_token_scaling=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
     device = _input.device

+    input_requires_grad = _input.requires_grad
+
     # inputs have shape: BT x H
     # materialized activations will have shape: BT x V
     # the increase in memory = BT x V

@@ -44,10 +48,17 @@ def fused_linear_cross_entropy_forward(
     chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

-    grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
     grad_input = torch.zeros_like(_input, device=device)
-
-    # we use fp32 for loss accumulator
+
+    # we use fp32 for loss and gradients accumulator
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None

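
`accum_dtype` only changes the dtype of the `grad_weight`/`grad_bias` accumulation buffers: per-chunk gradients are summed into those buffers and cast back to the parameter dtype once at the end (see the later hunk in this file). A toy illustration of why accumulating many small low-precision updates directly in bf16 can drift, which is the failure mode a float32 accumulator avoids (not the kernel code):

```python
import torch

n, step = 10_000, 1e-4  # true sum = 1.0

acc_bf16 = torch.zeros((), dtype=torch.bfloat16)
acc_fp32 = torch.zeros((), dtype=torch.float32)
update = torch.tensor(step, dtype=torch.bfloat16)
for _ in range(n):
    acc_bf16 += update          # accumulate directly in bf16
    acc_fp32 += update.float()  # accumulate in fp32, analogous to accum_dtype=torch.float32

print(acc_bf16.item())                     # far below 1.0: small updates get rounded away
print(acc_fp32.to(torch.bfloat16).item())  # close to 1.0 after a single final downcast
```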

@@ -82,6 +93,36 @@ def fused_linear_cross_entropy_forward(

         n_rows = logits_chunk.shape[0]

+        # Compute predicted probabilities for token scaling if needed
+        if use_token_scaling:
+            # Compute softmax probabilities for scaling
+            # We need to compute this before the cross entropy kernel modifies logits_chunk
+            logits_for_softmax = logits_chunk.detach().clone()  # Detach to avoid gradient flow
+            if softcap is not None:
+                logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
+
+            # Compute softmax to get predicted probabilities
+            probs = torch.softmax(logits_for_softmax, dim=-1)
+
+            # Get predicted probabilities for token scaling, handling ignored targets
+            valid_target_mask = target_chunk != ignore_index
+            valid_targets = target_chunk[valid_target_mask]
+
+            if len(valid_targets) > 0:
+                # Gather probabilities only for valid targets
+                valid_probs = probs[valid_target_mask]
+                pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
+
+                # Create full tensor with zeros for ignored targets
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+                pred_probs[valid_target_mask] = pred_probs_valid
+            else:
+                # All targets are ignored
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+
+            # Store the scaling factors
+            scaling_factors = pred_probs.detach()  # Detach to ensure no gradient flow
+
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
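
`use_token_scaling` weights every token's loss, and the matching logits gradient, by the model's detached predicted probability of that token's true class, so the scaling factor contributes no gradient of its own and ignored targets get weight zero. A plain PyTorch sketch of the same weighting on a single batch (illustrative; the fused path does this chunk by chunk):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 100, requires_grad=True)
target = torch.randint(0, 100, (8,))
ignore_index = -100

per_token_loss = F.cross_entropy(logits, target, reduction="none", ignore_index=ignore_index)

# p(true class), detached so the scaling factor itself gets no gradient
with torch.no_grad():
    probs = torch.softmax(logits, dim=-1)
    scaling = torch.zeros_like(per_token_loss)
    valid = target != ignore_index
    scaling[valid] = probs[valid].gather(-1, target[valid].unsqueeze(-1)).squeeze(-1)

scaled_loss = (per_token_loss * scaling).sum()
scaled_loss.backward()
```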

@@ -112,33 +153,38 @@ def fused_linear_cross_entropy_forward(
             RETURN_Z_LOSS=return_z_loss,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
+            HAS_GRADIENTS=input_requires_grad,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )

+        # Apply token scaling if requested
+        if use_token_scaling:
+            loss_1d_slice = loss_1d_slice * scaling_factors
+            if return_z_loss:
+                z_loss_1d_slice = z_loss_1d_slice * scaling_factors
+
         loss_1d[start_idx:end_idx] = loss_1d_slice
         if return_z_loss:
             z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
         grad_logits_chunk = logits_chunk  # chunk_size x V

-
+        # Apply token scaling to gradients if requested
+        if use_token_scaling:
+            # Expand scaling factors to match gradient dimensions
+            scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
+            grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

-        if
-
-                input=grad_weight,
-                mat1=logits_chunk.t().to(
-                    _input_chunk.dtype
-                ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
-                mat2=_input_chunk,
-                out=grad_weight,
-                alpha=1.0,
-                beta=1.0,
-            )
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

-        if
+        if grad_weight is not None and input_requires_grad:
+            grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
+
+        if bias is not None and input_requires_grad:
             torch.add(
                 input=grad_bias,
-                other=
+                other=grad_logits_chunk.sum(dim=0),
                 out=grad_bias,
                 alpha=1.0,
             )

@@ -148,9 +194,18 @@ def fused_linear_cross_entropy_forward(
     # loss = loss_1d
     # z_loss = z_loss_1d if return_z_loss else None

+    if reduction == "none":
+        # Return per-token losses
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+
+    # Cast back to original dtype
+    grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
+    grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
+
     return loss, z_loss, grad_input, grad_weight, grad_bias


@@ -217,6 +272,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         reduction="mean",
         softcap=None,
         return_z_loss: bool = False,
+        accum_dtype=None,
+        use_token_scaling: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss

@@ -235,6 +292,11 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         ignore_index: the index to ignore in the target
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction: reduction to apply
+        accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
+            Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
+        use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
+            When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
+            Default: False.
         """

         loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(

@@ -249,6 +311,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             reduction=reduction,
             softcap=softcap,
             return_z_loss=return_z_loss,
+            accum_dtype=accum_dtype,
+            use_token_scaling=use_token_scaling,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(

@@ -280,4 +344,6 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,  # use_token_scaling
         )
liger_kernel/ops/layer_norm.py
CHANGED

@@ -63,12 +63,11 @@ def _layer_norm_forward_kernel(
     X_f32 = X_row.to(tl.float32)

     # Compute statistics in fp32 for numerical stability
-
-    mean = tl.sum(X_f32, axis=0) / n_cols_f32
+    mean = tl.sum(X_f32, axis=0) / n_cols
     X_centered = X_f32 - mean
     # Apply mask to variance calculation to exclude contributions from masked elements
     X_centered_masked = tl.where(mask, X_centered, 0.0)
-    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) /
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
     rstd = rsqrt(var + eps)

     # Store statistics (convert back to original dtype only once)

@@ -113,7 +112,6 @@ def _layer_norm_backward_kernel(
     # Pre-load weights once (same optimization as forward pass)
     w = tl.load(W_ptr + cols, mask=mask, other=0.0)
     w_f32 = w.to(tl.float32)
-    n_cols_f32 = n_cols.to(tl.float32)

     # Calculate pointers for this specific row
     row_X_ptr = X_ptr + row_idx * stride_x

@@ -137,8 +135,8 @@ def _layer_norm_backward_kernel(
         # Compute backward pass for this row
         x_hat = (x_f32 - mean_f32) * rstd_f32
         wdy = w_f32 * dy_f32
-        c1 = tl.sum(x_hat * wdy, axis=0) /
-        c2 = tl.sum(wdy, axis=0) /
+        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+        c2 = tl.sum(wdy, axis=0) / n_cols
         dx = (wdy - (x_hat * c1 + c2)) * rstd_f32

         # Store input gradient
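
The expressions above are the standard LayerNorm input gradient: with x̂ the normalized input, w·dy the weighted upstream gradient, and c1, c2 the per-row means of x̂·(w·dy) and w·dy, dx = (w·dy − (x̂·c1 + c2))·rstd. A quick check of that formula against autograd in float64 (illustrative only):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, eps = 16, 1e-6
x = torch.randn(4, N, dtype=torch.float64, requires_grad=True)
w = torch.randn(N, dtype=torch.float64)
b = torch.randn(N, dtype=torch.float64)
dy = torch.randn(4, N, dtype=torch.float64)

F.layer_norm(x, (N,), w, b, eps=eps).backward(dy)

# Manual backward, mirroring the kernel: c1 = mean(x_hat * w*dy), c2 = mean(w*dy)
xd = x.detach()
mean = xd.mean(-1, keepdim=True)
rstd = (xd.var(-1, unbiased=False, keepdim=True) + eps).rsqrt()
x_hat = (xd - mean) * rstd
wdy = w * dy
c1 = (x_hat * wdy).mean(-1, keepdim=True)
c2 = wdy.mean(-1, keepdim=True)
dx = (wdy - (x_hat * c1 + c2)) * rstd

print(torch.allclose(dx, x.grad, atol=1e-10))  # True
```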
|