liger-kernel 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cpo_loss.py +51 -11
- liger_kernel/chunked_loss/dpo_loss.py +30 -4
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +20 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
- liger_kernel/chunked_loss/grpo_loss.py +137 -61
- liger_kernel/chunked_loss/jsd_loss.py +43 -13
- liger_kernel/chunked_loss/kto_loss.py +50 -12
- liger_kernel/chunked_loss/orpo_loss.py +37 -5
- liger_kernel/chunked_loss/simpo_loss.py +47 -11
- liger_kernel/ops/cross_entropy.py +7 -2
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/ops/fused_linear_jsd.py +2 -1
- liger_kernel/ops/jsd.py +30 -11
- liger_kernel/ops/kl_div.py +2 -2
- liger_kernel/transformers/__init__.py +4 -0
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/model/gemma.py +8 -16
- liger_kernel/transformers/model/gemma2.py +7 -16
- liger_kernel/transformers/model/llama.py +8 -15
- liger_kernel/transformers/model/llava.py +369 -0
- liger_kernel/transformers/model/loss_utils.py +57 -0
- liger_kernel/transformers/model/mistral.py +9 -10
- liger_kernel/transformers/model/mixtral.py +8 -15
- liger_kernel/transformers/model/mllama.py +8 -15
- liger_kernel/transformers/model/olmo2.py +8 -16
- liger_kernel/transformers/model/paligemma.py +397 -0
- liger_kernel/transformers/model/phi3.py +8 -15
- liger_kernel/transformers/model/qwen2.py +8 -15
- liger_kernel/transformers/model/qwen2_5_vl.py +204 -0
- liger_kernel/transformers/model/qwen2_vl.py +9 -10
- liger_kernel/transformers/monkey_patch.py +286 -12
- liger_kernel/utils.py +1 -3
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +11 -7
- liger_kernel-0.5.6.dist-info/RECORD +80 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -213
- liger_kernel-0.5.4.dist-info/RECORD +0 -74
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_unpaired_preference.py

@@ -16,13 +16,13 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):

     @staticmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
         target,
         preference_labels,
         bias=None,
-        loss_fn=None,
         chunk_size=1,
         ignore_index=-100,
         compiled=True,
@@ -30,6 +30,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
+        average_log_prob=False,
         **loss_kwargs,
     ):
         """
@@ -59,6 +60,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
                 Shape: (batch_size,).
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -72,14 +74,22 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         # Loss to be accumulated
         loss_acc = torch.zeros((), device=_input.device)

+        # Metrics to be recorded
+        chosen_logps_sum = torch.zeros((), device=_input.device)
+        rejected_logps_sum = torch.zeros((), device=_input.device)
+        chosen_logits_sum = torch.zeros((), device=_input.device)
+        rejected_logits_sum = torch.zeros((), device=_input.device)
+        aggregated_aux_outputs = []
+
         compute_loss = partial(
             LigerFusedLinearUnpairedPreferenceBase._compute_loss,
-            preference_loss_fn=
+            preference_loss_fn=cls.preference_loss_fn,
             full_target=target,
             ignore_index=ignore_index,
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             **loss_kwargs,
         )

@@ -88,7 +98,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             Fused forward and backward pass for a chunk of input and target.
             """
             argnums = (0, 1, 4) if bias is not None else (0, 1)
-            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
                 input_chunk,
                 weight,
                 target_chunk,
@@ -103,9 +113,19 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             preference_labels_chunk=None,
             ref_input_chunk=None,
         ):
-            (
-
-
+            (
+                (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias),
+                (
+                    chunk_loss,
+                    (
+                        chunk_chosen_logps_sum,
+                        chunk_rejected_logps_sum,
+                        chunk_chosen_logits_sum,
+                        chunk_rejected_logits_sum,
+                        *aux_outputs,
+                    ),
+                ),
+            ) = fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
             if bias is not None:
                 grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient

@@ -116,6 +136,23 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             # Accumulate loss
             loss_acc.add_(chunk_loss)

+            # Accumulate metrics
+            chosen_logps_sum.add_(chunk_chosen_logps_sum)
+            rejected_logps_sum.add_(chunk_rejected_logps_sum)
+            chosen_logits_sum.add_(chunk_chosen_logits_sum)
+            rejected_logits_sum.add_(chunk_rejected_logits_sum)
+
+            # aux_outputs
+            # Initialize storage for aux_outputs
+            if len(aggregated_aux_outputs) == 0:
+                for aux in aux_outputs:
+                    aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
+
+            # Process each aux_output
+            for i, aux in enumerate(aux_outputs):
+                if aux.ndim == 0:
+                    aggregated_aux_outputs[i].add_(aux)
+
         if compiled:
             fused_fwd_bwd = torch.compile(fused_fwd_bwd)

@@ -151,12 +188,25 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             # accumulate loss, gradients, and metrics
             accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)

+        # Aggregate aux outputs lists into tensors
+        for i, aux in enumerate(aggregated_aux_outputs):
+            if isinstance(aux, list):
+                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)
+
         ctx.save_for_backward(
             torch.cat(grad_inputs, dim=0),
             grad_weight,
             grad_bias,
         )
-
+
+        return_vars = (
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
+        )
+
+        return loss_acc, (*return_vars, *aggregated_aux_outputs)

     @staticmethod
     def backward(ctx, *grad_output):
@@ -173,21 +223,37 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         input_chunk,
         weight,
         target_chunk,
+        preference_labels_chunk,
         bias=None,
         ignore_index=-100,
+        average_log_prob=False,
     ):
         logits_chunk = input_chunk @ weight.t()
         if bias is not None:
             logits_chunk = logits_chunk + bias
         log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
-
         loss_mask_chunk = target_chunk != ignore_index
         label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)

         per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
-
-
-
+        if average_log_prob:
+            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
+        else:
+            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1)
+
+        chosen_logps_sum = (log_probs * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_logps_sum = (log_probs * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        chosen_logits_sum = (logits_chunk * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_logits_sum = (logits_chunk * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        return (
+            log_probs,
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
+        )

     @staticmethod
     def _compute_loss(
@@ -203,6 +269,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
+        average_log_prob=False,
         **loss_kwargs,
     ):
         """
@@ -218,29 +285,57 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
-
+        (
+            log_prob_chunk,
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
+        ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
             input_chunk,
             weight,
             target_chunk,
+            preference_labels_chunk,
             bias=bias,
             ignore_index=ignore_index,
+            average_log_prob=average_log_prob,
         )

         if use_ref_model:
             with torch.no_grad():
-
+                (
+                    ref_log_prob_chunk,
+                    _,
+                    _,
+                    _,
+                    _,
+                ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
                     ref_input_chunk,
                     ref_weight,
                     target_chunk,
+                    preference_labels_chunk,
                     ref_bias,
                     ignore_index=ignore_index,
+                    average_log_prob=average_log_prob,
                 )
-            loss_kwargs["
+            loss_kwargs["ref_log_prob_chunk"] = ref_log_prob_chunk

-
-
+        preference_loss_outputs = preference_loss_fn(
+            log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+        )
+        if isinstance(preference_loss_outputs, tuple):
+            preference_loss_chunk, *aux_outputs = preference_loss_outputs
+        else:
+            preference_loss_chunk, aux_outputs = preference_loss_outputs, []
+
+        return_vars = (
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
         )

-        return preference_loss_chunk
+        return preference_loss_chunk, (*return_vars, *aux_outputs)
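The new `average_log_prob` flag changes how per-sequence log-probabilities are reduced before `preference_loss_fn` sees them. A minimal standalone sketch of the two reductions on toy tensors (this mirrors the masked sum in `chunk_forward` above, not the fused kernel itself):

```python
import torch

# Toy per-token log-probs for 2 sequences of 4 tokens; False marks ignored (padded) positions.
per_token_logps = torch.tensor([[-0.5, -1.0, -0.2, -0.3],
                                [-0.1, -0.4, -0.6, -0.9]])
loss_mask = torch.tensor([[True, True, True, False],
                          [True, True, False, False]])

# average_log_prob=False: summed log-prob of the non-masked tokens, per sequence
summed_logps = (per_token_logps * loss_mask).sum(-1)

# average_log_prob=True: mean log-prob per non-masked token, per sequence
averaged_logps = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

print(summed_logps)    # tensor([-1.7000, -0.5000])
print(averaged_logps)  # tensor([-0.5667, -0.2500])
```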
liger_kernel/chunked_loss/grpo_loss.py

@@ -1,99 +1,143 @@
 import torch

-from liger_kernel.chunked_loss.
+from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase


-
+def k3_loss_fn(log_p, log_q):
+    # computes k3 estimate of KL[q, p]
+    # ref: http://joschu.net/blog/kl-approx.html
+    return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0
+
+
+def clip_coef_fn(coef, epsilon_low, epsilon_high):
+    return torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
+
+
+class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
     @staticmethod
-    def
+    def ppo_loss_fn(
         log_probs,
+        selected_token_ids,
         attention_mask,
-
-
-
+        advantages,
+        full_attention_mask,
+        ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
+        old_per_token_logps=None,
+        ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
-
-        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
-        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
+        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
             -1
         )  # (batch_size, seq_len)

         # Get reference model probabilities
-        if
-
-
-
-
-
-
-
-        std_grouped_rewards = rewards.std()  # [batch_size,]
-
-        # Calculate advantages using the same epsilon as in GRPOTrainer
-        eps = 1e-4
-        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
+        if ref_per_token_logps is None:
+            if ref_log_probs is not None:
+                with torch.no_grad():
+                    ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
+                        -1
+                    )
+            else:
+                ref_per_token_logps = per_token_logps.detach()

         # Compute policy gradient loss with importance sampling ratio
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        if beta != 0.0:
+            # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
+            kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
+            # Combine losses
+            per_token_loss = per_token_loss + beta * kl_div
+
+        # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
+        # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
+        # and TRL GRPO implementation
+        # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
+        loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)

         # Calculate metrics
-        metrics =
-
-
-
-            (
+        metrics = []
+        if beta != 0.0:
+            metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
+        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
         )
-
+        metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics

-    @
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
+        selected_token_ids,
         attention_mask,
-
+        advantages,
         bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
-        beta=0.
+        beta=0.04,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        temperature=1.0,
         compiled=True,
         use_ref_model=True,
-
+        chunk_size=1,
     ):
-
+        """
+        Fused linear layer with GRPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
+            attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
+            advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            beta (float): Weight for the KL penalty
+            temperature (float): Temperature for the logits
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            chunk_size (int): Size of chunks for processing.
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
+            selected_token_ids=selected_token_ids,
             attention_mask=attention_mask,
-
-            rewards=rewards,
+            advantages=advantages,
             bias=bias,
+            ref_per_token_logps=ref_per_token_logps,
+            old_per_token_logps=old_per_token_logps,
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
             beta=beta,
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            temperature=temperature,
             compiled=compiled,
             use_ref_model=use_ref_model,
-
+            chunk_size=chunk_size,
         )

     @staticmethod
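For reference, the per-token objective the new `ppo_loss_fn` builds can be restated on plain tensors. A minimal sketch of the clipped importance-sampling surrogate plus the k3 KL penalty (toy shapes only; the fused path additionally sums and normalizes by `full_attention_mask` across chunks):

```python
import torch

per_token_logps = torch.randn(2, 5)                       # current policy log-probs of the sampled tokens
old_per_token_logps = per_token_logps + 0.05 * torch.randn(2, 5)
ref_per_token_logps = per_token_logps - 0.02 * torch.randn(2, 5)
advantages = torch.tensor([0.7, -0.3])                     # one advantage per sequence
attention_mask = torch.ones(2, 5)
epsilon_low = epsilon_high = 0.2
beta = 0.04

coef_1 = torch.exp(per_token_logps - old_per_token_logps)         # importance ratio
coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)   # clipped ratio
per_token_loss = -torch.min(coef_1 * advantages.unsqueeze(1),
                            coef_2 * advantages.unsqueeze(1))
if beta != 0.0:
    # k3 estimator of KL[policy || ref], as in k3_loss_fn above
    kl_div = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
    per_token_loss = per_token_loss + beta * kl_div

loss = (per_token_loss * attention_mask).sum() / attention_mask.sum().clamp(min=1.0)
```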
@@ -104,16 +148,23 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
             grad_output: Gradient of the loss (scalar)
             grad_metrics: Gradients of the metrics (not used in backward computation)
         """
-        grads =
+        grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
         return (
-            *grads[
+            *grads[
+                :6
+            ],  # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
+            None,  # grad_ref_per_token_logps
+            None,  # grad_old_per_token_logps
             None,  # grad_ref_input
             None,  # grad_ref_weight
             None,  # grad_ref_bias
             None,  # grad_beta
+            None,  # grad_epsilon_low
+            None,  # grad_epsilon_high
+            None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
-            None,  #
+            None,  # grad_chunk_size
         )


@@ -122,24 +173,43 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):

     def __init__(
         self,
-        beta: float = 0.
+        beta: float = 0.04,
         compiled: bool = True,
         use_ref_model: bool = True,
-
+        chunk_size: int = 1,
+        epsilon_low: float = 0.2,
+        epsilon_high: float = 0.2,
+        temperature: float = 1.0,
     ):
+        """
+        Args:
+            beta (float): Weight for the KL penalty.
+            compiled (bool): Whether to use torch compile.
+            use_ref_model (bool): Whether to use a reference model.
+            chunk_size (int): Size of chunks for processing.
+            epsilon_low (float): Lower bound for the importance sampling ratio.
+            epsilon_high (float): Upper bound for the importance sampling ratio.
+            temperature (float): Temperature for the logits.
+        """
         super().__init__()
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
-        self.
+        self.chunk_size = chunk_size
+        self.epsilon_low = epsilon_low
+        self.epsilon_high = epsilon_high
+        self.temperature = temperature

     def forward(
         self,
         _input,
         lin_weight,
+        selected_token_ids,
         attention_mask,
-
+        advantages,
         bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
@@ -147,14 +217,20 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         return LigerFusedLinearGRPOFunction.apply(
             _input,
             lin_weight,
+            selected_token_ids,
             attention_mask,
-
+            advantages,
             bias,
+            ref_per_token_logps,
+            old_per_token_logps,
             ref_input,
             ref_weight,
             ref_bias,
             self.beta,
+            self.epsilon_low,
+            self.epsilon_high,
+            self.temperature,
             self.compiled,
             self.use_ref_model,
-            self.
+            self.chunk_size,
         )
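Note that `LigerFusedLinearGRPOLoss` no longer accepts raw `rewards`; the reward normalization that 0.5.4 did inside the loss (the removed `mean_grouped_rewards` / `std_grouped_rewards` / `eps = 1e-4` lines) is now the caller's job, and the module takes precomputed `advantages`. A hedged sketch of one way to produce group-normalized advantages, assuming completions for the same prompt are stored contiguously (the helper name is illustrative, not part of the library):

```python
import torch

def group_normalized_advantages(rewards: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    # rewards: (batch_size,) with `num_generations` consecutive completions per prompt
    grouped = rewards.view(-1, num_generations)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + eps)).view(-1)

advantages = group_normalized_advantages(torch.randn(8), num_generations=4)  # shape (8,)
```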
liger_kernel/chunked_loss/jsd_loss.py

@@ -19,31 +19,40 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

-
-
-
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()

-
-
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)

-
-
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss

-    @
+    @classmethod
     def forward(
+        cls,
         ctx,
         student_input: torch.Tensor,
         student_weight: torch.Tensor,
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
         beta: float = 0.5,
         ignore_index: int = -100,
         temperature: float = 1.0,
         compiled: bool = True,
+        chunk_size: int = 1024,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -59,18 +68,21 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             ignore_index (int): Index to ignore in loss computation
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
         """
-        return
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             student_input=student_input,
             student_weight=student_weight,
             teacher_input=teacher_input,
             teacher_weight=teacher_weight,
             target=true_labels,
-
-
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             beta=beta,
@@ -81,9 +93,19 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):

     @staticmethod
     def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]

-        return (
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+        )


 class LigerFusedLinearJSDLoss(torch.nn.Module):
@@ -99,6 +121,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         ignore_index: int = -100,
         temperature: float = 1.0,
         compiled: bool = True,
+        chunk_size: int = 1024,
     ):
         """
         Args:
@@ -108,6 +131,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             temperature (float): Temperature for softening distributions
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -117,6 +141,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         self.temperature = temperature
         self.compiled = compiled
         self.beta = beta
+        self.chunk_size = chunk_size

     def forward(
         self,
@@ -125,6 +150,8 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute the JSD distillation loss.
@@ -145,10 +172,13 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             teacher_input,
             teacher_weight,
             true_labels,
+            student_bias,
+            teacher_bias,
             self.weight_hard_loss,
             self.weight_soft_loss,
             self.beta,
             self.ignore_index,
             self.temperature,
             self.compiled,
+            self.chunk_size,
         )
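The rewritten `distillation_loss_fn` interpolates between the two KL directions via `beta` and, for intermediate values, measures each distribution against the beta-weighted mixture. A minimal standalone restatement of the `0 < beta < 1` branch on toy logits (eager tensors, not the fused chunked path):

```python
import torch
import torch.nn.functional as F

student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
beta = 0.5

student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

# Log of the beta-weighted mixture distribution
mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
log_mean_probs = mean_probs.log()

# F.kl_div(input, target, log_target=True) computes KL(target || input), summed over elements
student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)

jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
```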