liger-kernel 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cpo_loss.py +51 -11
- liger_kernel/chunked_loss/dpo_loss.py +30 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +3 -3
- liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
- liger_kernel/chunked_loss/fused_linear_rlhf.py +240 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
- liger_kernel/chunked_loss/grpo_loss.py +194 -0
- liger_kernel/chunked_loss/jsd_loss.py +31 -6
- liger_kernel/chunked_loss/kto_loss.py +53 -15
- liger_kernel/chunked_loss/orpo_loss.py +37 -5
- liger_kernel/chunked_loss/simpo_loss.py +47 -11
- liger_kernel/ops/cross_entropy.py +7 -3
- liger_kernel/ops/fused_linear_cross_entropy.py +3 -3
- liger_kernel/ops/fused_linear_jsd.py +3 -3
- liger_kernel/ops/jsd.py +3 -3
- liger_kernel/ops/layer_norm.py +20 -7
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +1 -2
- liger_kernel/transformers/__init__.py +4 -0
- liger_kernel/transformers/cross_entropy.py +3 -3
- liger_kernel/transformers/functional.py +17 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +3 -3
- liger_kernel/transformers/group_norm.py +6 -6
- liger_kernel/transformers/model/olmo2.py +124 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +205 -0
- liger_kernel/transformers/monkey_patch.py +239 -27
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/utils.py +48 -1
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/METADATA +19 -4
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/RECORD +35 -29
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/WHEEL +1 -1
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.5.dist-info}/top_level.txt +0 -0

liger_kernel/chunked_loss/fused_linear_unpaired_preference.py

```diff
@@ -16,13 +16,13 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
 
     @staticmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
         target,
         preference_labels,
         bias=None,
-        loss_fn=None,
         chunk_size=1,
         ignore_index=-100,
         compiled=True,
@@ -30,6 +30,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
+        average_log_prob=False,
         **loss_kwargs,
     ):
         """
@@ -59,6 +60,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
                 Shape: (batch_size,).
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
            loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -72,14 +74,22 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         # Loss to be accumulated
         loss_acc = torch.zeros((), device=_input.device)
 
+        # Metrics to be recorded
+        chosen_logps_sum = torch.zeros((), device=_input.device)
+        rejected_logps_sum = torch.zeros((), device=_input.device)
+        chosen_logits_sum = torch.zeros((), device=_input.device)
+        rejected_logits_sum = torch.zeros((), device=_input.device)
+        aggregated_aux_outputs = []
+
         compute_loss = partial(
             LigerFusedLinearUnpairedPreferenceBase._compute_loss,
-            preference_loss_fn=loss_fn,
+            preference_loss_fn=cls.preference_loss_fn,
             full_target=target,
             ignore_index=ignore_index,
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             **loss_kwargs,
         )
 
@@ -88,7 +98,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             Fused forward and backward pass for a chunk of input and target.
             """
             argnums = (0, 1, 4) if bias is not None else (0, 1)
-            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=False)(
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
                 input_chunk,
                 weight,
                 target_chunk,
@@ -103,9 +113,19 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             preference_labels_chunk=None,
             ref_input_chunk=None,
         ):
-            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd(
-                input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk
-            )
+            (
+                (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias),
+                (
+                    chunk_loss,
+                    (
+                        chunk_chosen_logps_sum,
+                        chunk_rejected_logps_sum,
+                        chunk_chosen_logits_sum,
+                        chunk_rejected_logits_sum,
+                        *aux_outputs,
+                    ),
+                ),
+            ) = fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
             if bias is not None:
                 grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient
 
@@ -116,6 +136,23 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             # Accumulate loss
             loss_acc.add_(chunk_loss)
 
+            # Accumulate metrics
+            chosen_logps_sum.add_(chunk_chosen_logps_sum)
+            rejected_logps_sum.add_(chunk_rejected_logps_sum)
+            chosen_logits_sum.add_(chunk_chosen_logits_sum)
+            rejected_logits_sum.add_(chunk_rejected_logits_sum)
+
+            # aux_outputs
+            # Initialize storage for aux_outputs
+            if len(aggregated_aux_outputs) == 0:
+                for aux in aux_outputs:
+                    aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
+
+            # Process each aux_output
+            for i, aux in enumerate(aux_outputs):
+                if aux.ndim == 0:
+                    aggregated_aux_outputs[i].add_(aux)
+
         if compiled:
             fused_fwd_bwd = torch.compile(fused_fwd_bwd)
 
@@ -151,12 +188,25 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             # accumulate loss, gradients, and metrics
             accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
 
+        # Aggregate aux outputs lists into tensors
+        for i, aux in enumerate(aggregated_aux_outputs):
+            if isinstance(aux, list):
+                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)
+
         ctx.save_for_backward(
             torch.cat(grad_inputs, dim=0),
             grad_weight,
             grad_bias,
         )
-        return loss_acc
+
+        return_vars = (
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
+        )
+
+        return loss_acc, (*return_vars, *aggregated_aux_outputs)
 
     @staticmethod
     def backward(ctx, *grad_output):
@@ -173,21 +223,37 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         input_chunk,
         weight,
         target_chunk,
+        preference_labels_chunk,
         bias=None,
         ignore_index=-100,
+        average_log_prob=False,
     ):
         logits_chunk = input_chunk @ weight.t()
         if bias is not None:
             logits_chunk = logits_chunk + bias
         log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
-
         loss_mask_chunk = target_chunk != ignore_index
         label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
 
         per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
-        average_log_prob_chunk = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
-
-        return average_log_prob_chunk
+        if average_log_prob:
+            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
+        else:
+            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1)
+
+        chosen_logps_sum = (log_probs * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_logps_sum = (log_probs * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        chosen_logits_sum = (logits_chunk * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_logits_sum = (logits_chunk * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        return (
+            log_probs,
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
+        )
 
     @staticmethod
     def _compute_loss(
@@ -203,6 +269,7 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
         ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
+        average_log_prob=False,
         **loss_kwargs,
     ):
         """
@@ -218,29 +285,57 @@ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
-        average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+        (
+            log_prob_chunk,
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
+        ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
             input_chunk,
             weight,
             target_chunk,
+            preference_labels_chunk,
             bias=bias,
             ignore_index=ignore_index,
+            average_log_prob=average_log_prob,
         )
 
         if use_ref_model:
             with torch.no_grad():
-                ref_average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+                (
+                    ref_log_prob_chunk,
+                    _,
+                    _,
+                    _,
+                    _,
+                ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
                     ref_input_chunk,
                     ref_weight,
                     target_chunk,
+                    preference_labels_chunk,
                     ref_bias,
                     ignore_index=ignore_index,
+                    average_log_prob=average_log_prob,
                 )
-            loss_kwargs["ref_average_log_prob_chunk"] = ref_average_log_prob_chunk
+            loss_kwargs["ref_log_prob_chunk"] = ref_log_prob_chunk
 
-        preference_loss_chunk = preference_loss_fn(
-            average_log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+        preference_loss_outputs = preference_loss_fn(
+            log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+        )
+        if isinstance(preference_loss_outputs, tuple):
+            preference_loss_chunk, *aux_outputs = preference_loss_outputs
+        else:
+            preference_loss_chunk, aux_outputs = preference_loss_outputs, []
+
+        return_vars = (
+            chosen_logps_sum,
+            rejected_logps_sum,
+            chosen_logits_sum,
+            rejected_logits_sum,
         )
 
-        return preference_loss_chunk
+        return preference_loss_chunk, (*return_vars, *aux_outputs)
```
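
The net effect of the `fused_linear_unpaired_preference.py` changes above is that `forward` now returns `(loss, (chosen_logps_sum, rejected_logps_sum, chosen_logits_sum, rejected_logits_sum, *aux_outputs))` instead of a bare loss, and `chunk_forward` can either sum or average the per-token log-probabilities via the new `average_log_prob` flag. A minimal plain-PyTorch sketch of that per-chunk bookkeeping (shapes and values are invented for illustration, boolean indexing stands in for the masked sums, and nothing here is fused or chunked):

```python
import torch
import torch.nn.functional as F

# Invented chunk: 4 sequences, 6 tokens, vocab of 10; last two tokens ignored.
logits = torch.randn(4, 6, 10)
target = torch.randint(0, 10, (4, 6))
target[:, -2:] = -100
preference_labels = torch.tensor([True, False, True, False])
average_log_prob = True

log_probs = F.log_softmax(logits.float(), dim=-1)
loss_mask = target != -100
labels = torch.where(loss_mask, target, 0)
per_token_logps = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

# Per-sequence log-prob: averaged over non-masked tokens, or a plain sum.
if average_log_prob:
    seq_logps = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
    seq_logps = (per_token_logps * loss_mask).sum(-1)

# The new metrics split the per-sequence values by the unpaired preference label.
chosen_logps_sum = seq_logps[preference_labels].sum()
rejected_logps_sum = seq_logps[~preference_labels].sum()
print(chosen_logps_sum.item(), rejected_logps_sum.item())
```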

liger_kernel/chunked_loss/grpo_loss.py (new file)

```diff
@@ -0,0 +1,194 @@
+import torch
+
+from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase
+
+
+class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
+    @staticmethod
+    def rlhf_loss_fn(
+        log_probs,
+        attention_mask,
+        rewards,
+        ref_log_probs=None,
+        beta=0.1,
+        **kwargs,
+    ):
+        """GRPO Loss Function matching GRPOTrainer implementation."""
+        # Get chosen token probabilities
+        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
+        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
+            -1
+        )  # (batch_size, seq_len)
+
+        # Get reference model probabilities
+        if ref_log_probs is not None:
+            with torch.no_grad():
+                ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
+        else:
+            ref_token_logprobs = chosen_token_logprobs.detach()
+
+        # Compute advantages per batch entry in a grouped fashion
+        mean_grouped_rewards = rewards.mean()  # [batch_size,]
+        std_grouped_rewards = rewards.std()  # [batch_size,]
+
+        # Calculate advantages using the same epsilon as in GRPOTrainer
+        eps = 1e-4
+        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
+
+        # Compute policy gradient loss with importance sampling ratio
+        ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
+        policy_loss = -ratio * advantages.unsqueeze(1)
+
+        # Compute KL penalty
+        kl_div = (
+            torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
+        )
+
+        # Combine losses
+        per_token_loss = policy_loss + beta * kl_div
+
+        # Apply masking and normalize
+        masked_loss = per_token_loss * attention_mask
+        seq_lengths = attention_mask.sum()
+        seq_lengths = torch.clamp(seq_lengths, min=1.0)
+        loss = masked_loss.sum() / seq_lengths
+
+        # Calculate metrics
+        metrics = (
+            chosen_token_logprobs.mean(),  # mean log prob
+            chosen_token_logprobs.std(),  # std log prob
+            log_probs.mean(),  # mean all log probs
+            ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(),  # mean KL div
+        )
+
+        return loss, metrics
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        attention_mask,
+        rewards,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        beta=0.1,
+        compiled=True,
+        use_ref_model=True,
+        num_generations=1,
+        chunk_size=1,
+    ):
+        """
+        Fused linear layer with GRPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
+            rewards (torch.Tensor): Rewards tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            beta (float): Weight for the KL penalty
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            num_generations (int): Number of generations per prompt
+            chunk_size (int): Size of chunks for processing.
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            attention_mask=attention_mask,
+            rewards=rewards,
+            bias=bias,
+            ref_input=ref_input,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            beta=beta,
+            compiled=compiled,
+            use_ref_model=use_ref_model,
+            num_generations=num_generations,
+            chunk_size=chunk_size,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *grad_metrics):
+        """Backward pass for GRPO loss.
+
+        Args:
+            grad_output: Gradient of the loss (scalar)
+            grad_metrics: Gradients of the metrics (not used in backward computation)
+        """
+        grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
+        return (
+            *grads[:5],  # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias
+            None,  # grad_ref_input
+            None,  # grad_ref_weight
+            None,  # grad_ref_bias
+            None,  # grad_beta
+            None,  # grad_compiled
+            None,  # grad_use_ref_model
+            None,  # grad_num_generations
+            None,  # grad_chunk_size
+        )
+
+
+class LigerFusedLinearGRPOLoss(torch.nn.Module):
+    """Fused linear layer with GRPO loss."""
+
+    def __init__(
+        self,
+        beta: float = 0.1,
+        compiled: bool = True,
+        use_ref_model: bool = True,
+        num_generations: int = 1,
+        chunk_size: int = 1,
+    ):
+        """
+        Args:
+            beta (float): Weight for the KL penalty.
+            compiled (bool): Whether to use torch compile.
+            use_ref_model (bool): Whether to use a reference model.
+            num_generations (int): Number of generations per prompt.
+            chunk_size (int): Size of chunks for processing.
+        """
+        super().__init__()
+        self.beta = beta
+        self.compiled = compiled
+        self.use_ref_model = use_ref_model
+        self.num_generations = num_generations
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        _input,
+        lin_weight,
+        attention_mask,
+        rewards,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+    ):
+        return LigerFusedLinearGRPOFunction.apply(
+            _input,
+            lin_weight,
+            attention_mask,
+            rewards,
+            bias,
+            ref_input,
+            ref_weight,
+            ref_bias,
+            self.beta,
+            self.compiled,
+            self.use_ref_model,
+            self.num_generations,
+            self.chunk_size,
+        )
```
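
The `rlhf_loss_fn` above is self-contained tensor math, so it can be sanity-checked outside the fused path. Below is a small stand-alone sketch of the same computation on dummy tensors (the shapes, values, and single-group setup are invented for illustration; the real class runs this chunked through `LigerFusedLinearRLHFBase`):

```python
import torch

# One group of 4 generations, 6 tokens each, vocab of 10 (invented sizes).
log_probs = torch.log_softmax(torch.randn(4, 6, 10), dim=-1)
ref_log_probs = torch.log_softmax(torch.randn(4, 6, 10), dim=-1)
attention_mask = torch.ones(4, 6)
rewards = torch.tensor([0.2, 1.0, 0.5, 0.0])
beta = 0.1

chosen_tokens = log_probs.argmax(dim=-1)
chosen_token_logprobs = log_probs.gather(-1, chosen_tokens.unsqueeze(-1)).squeeze(-1)
ref_token_logprobs = ref_log_probs.gather(-1, chosen_tokens.unsqueeze(-1)).squeeze(-1)

# Group-relative advantages: rewards standardized within the group.
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)

# Policy-gradient term (the importance ratio is exactly 1 in this one-step form)
# plus the KL penalty against the reference model, as in rlhf_loss_fn above.
ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
policy_loss = -ratio * advantages.unsqueeze(1)
kl_div = torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0

per_token_loss = policy_loss + beta * kl_div
loss = (per_token_loss * attention_mask).sum() / attention_mask.sum().clamp(min=1.0)
print(loss.item())
```

In use, `LigerFusedLinearGRPOLoss(beta=0.1, use_ref_model=True)` would typically be called with the model's hidden states as `_input`, the `lm_head` weight as `lin_weight`, the attention mask and per-generation rewards, plus the reference-model inputs when `use_ref_model` is enabled, following the `forward` signature above.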

liger_kernel/chunked_loss/jsd_loss.py

```diff
@@ -30,20 +30,24 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss
 
-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         student_input: torch.Tensor,
         student_weight: torch.Tensor,
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
         beta: float = 0.5,
         ignore_index: int = -100,
         temperature: float = 1.0,
         compiled: bool = True,
+        chunk_size: int = 1024,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -59,18 +63,21 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             ignore_index (int): Index to ignore in loss computation
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
         """
-        return LigerFusedLinearDistillationBase.forward(
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             student_input=student_input,
             student_weight=student_weight,
             teacher_input=teacher_input,
             teacher_weight=teacher_weight,
             target=true_labels,
-
-
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             beta=beta,
@@ -81,9 +88,19 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
 
     @staticmethod
     def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:5]
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
 
-        return (*grads, None, None, None, None, None, None)
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+        )
 
 
 class LigerFusedLinearJSDLoss(torch.nn.Module):
@@ -99,6 +116,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         ignore_index: int = -100,
         temperature: float = 1.0,
         compiled: bool = True,
+        chunk_size: int = 1024,
     ):
         """
         Args:
@@ -108,6 +126,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             temperature (float): Temperature for softening distributions
            compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -117,6 +136,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         self.temperature = temperature
         self.compiled = compiled
         self.beta = beta
+        self.chunk_size = chunk_size
 
     def forward(
         self,
@@ -125,6 +145,8 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute the JSD distillation loss.
@@ -145,10 +167,13 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             teacher_input,
             teacher_weight,
             true_labels,
+            student_bias,
+            teacher_bias,
             self.weight_hard_loss,
             self.weight_soft_loss,
             self.beta,
             self.ignore_index,
             self.temperature,
             self.compiled,
+            self.chunk_size,
         )
```