liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
- liger_kernel/chunked_loss/grpo_loss.py +134 -60
- liger_kernel/chunked_loss/jsd_loss.py +12 -7
- liger_kernel/ops/cross_entropy.py +3 -2
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/ops/fused_linear_jsd.py +2 -1
- liger_kernel/ops/jsd.py +32 -12
- liger_kernel/ops/kl_div.py +15 -8
- liger_kernel/ops/layer_norm.py +14 -1
- liger_kernel/ops/rms_norm.py +12 -1
- liger_kernel/transformers/__init__.py +133 -15
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/gema3_rms.py +8 -0
- liger_kernel/transformers/model/gemma.py +17 -20
- liger_kernel/transformers/model/gemma2.py +17 -21
- liger_kernel/transformers/model/gemma3.py +335 -0
- liger_kernel/transformers/model/llama.py +17 -19
- liger_kernel/transformers/model/llava.py +369 -0
- liger_kernel/transformers/model/loss_utils.py +64 -0
- liger_kernel/transformers/model/mistral.py +28 -25
- liger_kernel/transformers/model/mixtral.py +20 -26
- liger_kernel/transformers/model/mllama.py +17 -19
- liger_kernel/transformers/model/olmo2.py +17 -20
- liger_kernel/transformers/model/paligemma.py +397 -0
- liger_kernel/transformers/model/phi3.py +17 -19
- liger_kernel/transformers/model/qwen2.py +17 -19
- liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
- liger_kernel/transformers/model/qwen2_vl.py +9 -10
- liger_kernel/transformers/monkey_patch.py +392 -13
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/grpo_loss.py CHANGED

```diff
@@ -1,66 +1,92 @@
 import torch
 
-from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase
+from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
 
 
-
+def k3_loss_fn(log_p, log_q):
+    # computes k3 estimate of KL[q, p]
+    # ref: http://joschu.net/blog/kl-approx.html
+    return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0
+
+
+def clip_coef_fn(coef, epsilon_low, epsilon_high):
+    return torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
+
+
+class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
     @staticmethod
-    def
+    def ppo_loss_fn(
         log_probs,
+        selected_token_ids,
         attention_mask,
-
-
-
+        advantages,
+        full_attention_mask,
+        ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
+        old_per_token_logps=None,
+        ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
+        loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
+        max_completion_length=None,  # Required for dr_grpo
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
-
-        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
-        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
+        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
             -1
         )  # (batch_size, seq_len)
 
         # Get reference model probabilities
-        if
-
-
-
-
-
-
-        std_grouped_rewards = rewards.std()  # [batch_size,]
-
-        # Calculate advantages using the same epsilon as in GRPOTrainer
-        eps = 1e-4
-        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
+        if ref_per_token_logps is None:
+            if ref_log_probs is not None:
+                with torch.no_grad():
+                    ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
+                        -1
+                    )
+            else:
+                ref_per_token_logps = per_token_logps.detach()
 
         # Compute policy gradient loss with importance sampling ratio
-
-
-
-
-
-
-
+        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        if beta != 0.0:
+            # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
+            kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
+            # Combine losses
+            per_token_loss = per_token_loss + beta * kl_div
 
-        #
-
-
-        #
-
-
-
-
+        # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
+        # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
+        # and TRL GRPO implementation
+        # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
+        if loss_type == "grpo":
+            # Average per-sequence loss
+            loss = (
+                (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+            ).sum() / full_attention_mask.shape[0]
+        elif loss_type == "bnpo":
+            # Batch Normalized Per-token loss (original implementation)
+            loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+        elif loss_type == "dr_grpo":
+            # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+            if max_completion_length is None:
+                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        else:
+            raise ValueError(f"Unknown loss type: {loss_type}")
 
         # Calculate metrics
-        metrics =
-
-
-
-        (
+        metrics = []
+        if beta != 0.0:
+            metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
+        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
         )
-
+        metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics
 
     @classmethod
@@ -69,16 +95,23 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
         ctx,
         _input,
         weight,
+        selected_token_ids,
         attention_mask,
-        rewards,
+        advantages,
         bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
-        beta=0.
+        beta=0.04,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        loss_type="bnpo",
+        max_completion_length=None,
+        temperature=1.0,
         compiled=True,
         use_ref_model=True,
-        num_generations=1,
         chunk_size=1,
     ):
         """
@@ -86,16 +119,20 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
         Args:
             _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
             attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
-
+            advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
             ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
-            num_generations (int): Number of generations per prompt
             chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
@@ -105,16 +142,23 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
             ctx=ctx,
             _input=_input,
             weight=weight,
+            selected_token_ids=selected_token_ids,
             attention_mask=attention_mask,
-            rewards=rewards,
+            advantages=advantages,
             bias=bias,
+            ref_per_token_logps=ref_per_token_logps,
+            old_per_token_logps=old_per_token_logps,
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
             beta=beta,
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
+            temperature=temperature,
            compiled=compiled,
            use_ref_model=use_ref_model,
-            num_generations=num_generations,
             chunk_size=chunk_size,
         )
@@ -126,16 +170,24 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
         grad_output: Gradient of the loss (scalar)
         grad_metrics: Gradients of the metrics (not used in backward computation)
         """
-        grads =
+        grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
         return (
-            *grads[
+            *grads[
+                :6
+            ],  # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
+            None,  # grad_ref_per_token_logps
+            None,  # grad_old_per_token_logps
             None,  # grad_ref_input
             None,  # grad_ref_weight
             None,  # grad_ref_bias
             None,  # grad_beta
+            None,  # grad_epsilon_low
+            None,  # grad_epsilon_high
+            None,  # grad_loss_type (string, not differentiable)
+            None,  # grad_max_completion_length (int, not differentiable)
+            None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
-            None,  # grad_num_generations
             None,  # grad_chunk_size
         )
@@ -145,34 +197,49 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 
     def __init__(
         self,
-        beta: float = 0.
+        beta: float = 0.04,
         compiled: bool = True,
         use_ref_model: bool = True,
-        num_generations: int = 1,
         chunk_size: int = 1,
+        epsilon_low: float = 0.2,
+        epsilon_high: float = 0.2,
+        loss_type: str = "bnpo",
+        max_completion_length: int | None = None,
+        temperature: float = 1.0,
     ):
         """
         Args:
             beta (float): Weight for the KL penalty.
             compiled (bool): Whether to use torch compile.
             use_ref_model (bool): Whether to use a reference model.
-            num_generations (int): Number of generations per prompt.
             chunk_size (int): Size of chunks for processing.
+            epsilon_low (float): Lower bound for the importance sampling ratio.
+            epsilon_high (float): Upper bound for the importance sampling ratio.
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            temperature (float): Temperature for the logits.
         """
         super().__init__()
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
-        self.num_generations = num_generations
         self.chunk_size = chunk_size
+        self.epsilon_low = epsilon_low
+        self.epsilon_high = epsilon_high
+        self.loss_type = loss_type
+        self.max_completion_length = max_completion_length
+        self.temperature = temperature
 
     def forward(
         self,
         _input,
         lin_weight,
+        selected_token_ids,
         attention_mask,
-        rewards,
+        advantages,
         bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
@@ -180,15 +247,22 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         return LigerFusedLinearGRPOFunction.apply(
             _input,
             lin_weight,
+            selected_token_ids,
             attention_mask,
-            rewards,
+            advantages,
             bias,
+            ref_per_token_logps,
+            old_per_token_logps,
             ref_input,
             ref_weight,
             ref_bias,
             self.beta,
+            self.epsilon_low,
+            self.epsilon_high,
+            self.loss_type,
+            self.max_completion_length,
+            self.temperature,
             self.compiled,
             self.use_ref_model,
-            self.num_generations,
             self.chunk_size,
         )
```
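The new `k3_loss_fn` is the k3 estimator from the blog post cited in the diff (http://joschu.net/blog/kl-approx.html); in expectation under samples from `q` it equals KL[q, p] exactly, because E_q[p/q] = 1. A standalone check of that identity, with illustrative distributions that are not from the package:

```python
import torch

p = torch.tensor([0.7, 0.2, 0.1])
q = torch.tensor([0.5, 0.3, 0.2])
log_p, log_q = p.log(), q.log()

# Exact KL[q, p] = sum_i q_i * (log q_i - log p_i)
exact_kl = (q * (log_q - log_p)).sum()

# k3 estimator per element: exp(log_p - log_q) - (log_p - log_q) - 1.
# Taking the full expectation under q instead of sampling:
k3 = torch.exp(log_p - log_q) - (log_p - log_q) - 1.0
k3_expectation = (q * k3).sum()

# E_q[p/q] = 1, so the two agree up to float error
assert torch.allclose(exact_kl, k3_expectation, atol=1e-6)
```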
liger_kernel/chunked_loss/jsd_loss.py CHANGED

```diff
@@ -19,15 +19,20 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
 
-
-
-
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()
 
-
-
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
 
-
-
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss
 
     @classmethod
```
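For reference, the general branch above computes the beta-skewed JSD as a weighted sum of KLs against the mixture. A minimal eager-mode check of that identity, with random logits that are illustrative only:

```python
import torch
import torch.nn.functional as F

# For 0 < beta < 1: JSD_beta = beta * KL(teacher || M) + (1 - beta) * KL(student || M),
# where M = beta * P_teacher + (1 - beta) * P_student.
torch.manual_seed(0)
student_log_probs = F.log_softmax(torch.randn(4, 10), dim=-1)
teacher_log_probs = F.log_softmax(torch.randn(4, 10), dim=-1)
beta = 0.5

mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
log_mean_probs = mean_probs.log()

# F.kl_div(input, target, log_target=True) = sum exp(target) * (target - input)
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
jsd = beta * teacher_kl + (1 - beta) * student_kl

# Same quantity from the definition KL(p || m) = sum p * (log p - log m)
direct = (
    beta * (teacher_log_probs.exp() * (teacher_log_probs - log_mean_probs)).sum()
    + (1 - beta) * (student_log_probs.exp() * (student_log_probs - log_mean_probs)).sum()
)
assert torch.allclose(jsd, direct, atol=1e-5)
```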
liger_kernel/ops/cross_entropy.py CHANGED

```diff
@@ -9,6 +9,7 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (
+    n_non_ignore (float): The number of non-ignored elements in the batch.
     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
     weight_sum (float): The sum of weight tensor.
     ignore_index (int): The index to ignore in the target.
@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning
 
 
 def cross_entropy_forward(
```
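This change, like the matching ones in `fused_linear_jsd.py` and `jsd.py` below, lowers the one-block row cap on XPU. A sketch of what the cap constrains, assuming the next-power-of-two padding that Triton row kernels (and `calculate_settings` in `liger_kernel.ops.utils`) typically apply; the exact helper logic may differ:

```python
import triton

def fits_in_one_block(n_cols: int, max_fused_size: int) -> bool:
    # Rows are padded to the next power of two before being processed
    # as a single block; wider rows need a different code path.
    return triton.next_power_of_2(n_cols) <= max_fused_size

print(fits_in_one_block(50000, 65536 // 2))  # False under the CUDA cap of 32768
print(fits_in_one_block(3000, 4096))         # True under the new XPU cap
```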
liger_kernel/ops/dyt.py ADDED

```diff
@@ -0,0 +1,225 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import infer_device
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+
+@triton.jit
+def _dyt_fwd_kernel(
+    x_ptr,
+    x_row_stride,
+    alpha_ptr,
+    gamma_ptr,
+    beta_ptr,
+    y_ptr,
+    y_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - beta: (C)
+    """
+    row_idx = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    x_ptr += row_idx * x_row_stride
+    y_ptr += row_idx * y_row_stride
+
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask)
+    beta = tl.load(beta_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
+    tl.store(y_ptr + offsets, y, mask=mask)
+
+
+@triton.jit
+def _dyt_bwd_kernel(
+    x_ptr,
+    x_row_stride,
+    dy_ptr,
+    dy_row_stride,
+    dx_ptr,
+    dx_row_stride,
+    alpha_ptr,
+    dalpha_ptr,
+    gamma_ptr,
+    dgamma_ptr,
+    dgamma_row_stride,
+    n_cols,
+    n_rows,
+    ROWS_PER_PROGRAM: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - dx: (BT, C)
+        - dy: (BT, C)
+        - dgamma: (sm_count, C)
+        - dalpha: (sm_count,)
+    """
+    # d(gamma * tanh(alpha * x) + beta) / dx
+    # = gamma * (1 - tanh^2(alpha * x)) * alpha
+    # d(gamma * tanh(alpha * x) + beta) / dalpha
+    # = gamma * (1 - tanh^2(alpha * x)) * x
+    # d(gamma * tanh(alpha * x) + beta) / dgamma
+    # = tanh(alpha * x)
+    # d(gamma * tanh(alpha * x)) / dbeta = 1
+    pid = tl.program_id(0)
+
+    row_start = pid * ROWS_PER_PROGRAM
+    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    dalpha = 0.0
+    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    x_ptr += row_start * x_row_stride
+    dx_ptr += row_start * dx_row_stride
+    dy_ptr += row_start * dy_row_stride
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
+
+    for _ in tl.range(row_start, row_end):
+        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
+        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
+        tanh_ax = tanh((alpha * x).cast(tl.float32))
+        sech2_ax = 1 - tanh_ax * tanh_ax
+
+        dx = dy * gamma * sech2_ax * alpha
+        dalpha += tl.sum(dy * gamma * sech2_ax * x)
+        dgamma += dy * tanh_ax
+        tl.store(dx_ptr + offsets, dx, mask=mask)
+
+        dy_ptr += dy_row_stride
+        x_ptr += x_row_stride
+        dx_ptr += dx_row_stride
+
+    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
+    tl.store(dalpha_ptr + pid, dalpha)
+
+    pass
+
+
+def liger_dyt_fwd(x, alpha, gamma, beta):
+    shape = x.shape
+    dim = shape[-1]
+    x = x.view(-1, dim)
+    n_rows, n_cols = x.shape
+    y = torch.empty_like(x)
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    _dyt_fwd_kernel[(n_rows,)](
+        x_ptr=x,
+        alpha_ptr=alpha,
+        gamma_ptr=gamma,
+        beta_ptr=beta,
+        y_ptr=y,
+        x_row_stride=x.stride(0),
+        y_row_stride=y.stride(0),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return y.view(*shape)
+
+
+def liger_dyt_bwd(dy, x, alpha, gamma):
+    shape = dy.shape
+    dtype = x.dtype
+    dim = shape[-1]
+    dy = dy.view(-1, dim)
+    x = x.view(-1, dim)
+    n_rows, n_cols = dy.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    sm_count = 1
+    device = infer_device()
+    if device == "cuda":
+        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+    elif device == "xpu":
+        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+    if n_cols > BLOCK_SIZE:
+        raise RuntimeError(
+            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+        )
+
+    dx = torch.empty_like(x, dtype=torch.float32)
+    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
+    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
+
+    grid = (sm_count,)
+    rows_per_program = triton.cdiv(n_rows, sm_count)
+    _dyt_bwd_kernel[grid](
+        x_ptr=x,
+        x_row_stride=x.stride(0),
+        dy_ptr=dy,
+        dy_row_stride=dy.stride(0),
+        dx_ptr=dx,
+        dx_row_stride=dx.stride(0),
+        alpha_ptr=alpha,
+        dalpha_ptr=_dalpha,
+        gamma_ptr=gamma,
+        dgamma_ptr=_dgamma,
+        dgamma_row_stride=_dgamma.stride(0),
+        n_cols=n_cols,
+        n_rows=n_rows,
+        ROWS_PER_PROGRAM=rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
+    dgamma = _dgamma.sum(dim=0).to(dtype)
+    dbeta = dy.sum(dim=0).to(dtype)
+    return dx.view(*shape), dalpha, dgamma, dbeta
+
+
+class LigerDyTFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, x, alpha, gamma, beta):
+        y = liger_dyt_fwd(x, alpha, gamma, beta)
+        ctx.save_for_backward(x, alpha, gamma)
+        return y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        x, alpha, gamma = ctx.saved_tensors
+        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
+            grad_output,
+            x,
+            alpha,
+            gamma,
+        )
+
+        return (dx, dalpha, dgamma, dbeta)
```
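Since the forward kernel implements `y = gamma * tanh(alpha * x) + beta` (the Dynamic Tanh transform from the paper linked in the docstrings), an eager-mode reference is a one-liner and can be used to spot-check the Triton path. The shapes, device guard, and tolerances below are illustrative assumptions, not part of the package's test suite:

```python
import torch

# DyT in plain PyTorch, following the kernel docstring shapes:
# x: (BT, C), alpha: (1,), gamma/beta: (C,)
def dyt_reference(x, alpha, gamma, beta):
    return gamma * torch.tanh(alpha * x) + beta

if torch.cuda.is_available():
    from liger_kernel.ops.dyt import LigerDyTFunction

    x = torch.randn(4, 64, device="cuda")
    alpha = torch.ones(1, device="cuda")
    gamma = torch.ones(64, device="cuda")
    beta = torch.zeros(64, device="cuda")

    y_kernel = LigerDyTFunction.apply(x, alpha, gamma, beta)
    y_ref = dyt_reference(x, alpha, gamma, beta)
    assert torch.allclose(y_kernel, y_ref, atol=1e-5, rtol=1e-5)
```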
liger_kernel/ops/fused_linear_jsd.py CHANGED

```diff
@@ -8,11 +8,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
 
 
 def fused_linear_jsd_forward(
```
liger_kernel/ops/jsd.py CHANGED

```diff
@@ -5,6 +5,7 @@ import triton
 import triton.language as tl
 
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import infer_device
 
 
 @triton.jit
@@ -51,29 +52,48 @@ def _jsd_kernel(
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
 
     if beta == 0.0:  # forward KL
-
+        Y_max = tl.max(Y, axis=0)
+        Y_shifted = Y - Y_max
+        Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
         loss = Y_prob * (Y - X)
         dX = -Y_prob
-    elif beta == 1.0:
-
+    elif beta == 1.0:  # reverse KL
+        X_max = tl.max(X, axis=0)
+        X_shifted = X - X_max
+        X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
         loss = X_prob * (X - Y)
         dX = loss + X_prob
     else:
-
-
-
-        log_M = tl.log(M)
+        max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+        X_shifted = X - max_val
+        Y_shifted = Y - max_val
 
-
-
+        # Pre-compute exp(max_val) since it's used twice
+        exp_max = tl.exp(max_val)
+
+        # Compute exp terms with compensation
+        Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+        P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+        # Pre-compute common terms
+        beta_P = beta * P
+        one_minus_beta_Q = (1 - beta) * Q
+        M = beta_P + one_minus_beta_Q
+        log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+        loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+        dX = one_minus_beta_Q * (X - log_M)
+
+    # Pre-compute scaling factor
+    scale = 1.0 / n_non_ignore
+    loss = loss * scale
+    dX = dX * scale
 
-    loss = loss / n_non_ignore
-    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
     tl.store(dX_ptr + offsets, dX, mask=mask)
 
 
-MAX_FUSED_SIZE = 65536
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
 
 
 def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
```
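The analytic gradient stored by the general branch, `dX = (1 - beta) * exp(X) * (X - log M)`, can be cross-checked against autograd on a single row of log-probs; the values below are illustrative:

```python
import torch

# X, Y play the roles of the kernel's student/teacher log-probs for one row.
torch.manual_seed(0)
X = torch.log_softmax(torch.randn(10), dim=-1).detach().requires_grad_(True)
Y = torch.log_softmax(torch.randn(10), dim=-1)
beta = 0.3

# Same per-element loss as the kernel: beta*P*Y + (1-beta)*Q*X - M*log(M),
# with Q = exp(X), P = exp(Y), M = beta*P + (1-beta)*Q.
Q, P = X.exp(), Y.exp()
M = beta * P + (1 - beta) * Q
loss = (beta * P * Y + (1 - beta) * Q * X - M * M.log()).sum()
loss.backward()

with torch.no_grad():
    expected = (1 - beta) * X.exp() * (X - M.log())
assert torch.allclose(X.grad, expected, atol=1e-5)
```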