liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +3 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/grpo_loss.py +160 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +14 -32
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +5 -9
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +23 -12
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +3 -2
- liger_kernel/transformers/__init__.py +19 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +7 -9
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +28 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +9 -15
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/olmo2.py +124 -0
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +214 -144
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +49 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
- liger_kernel-0.5.4.dist-info/RECORD +74 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.2.dist-info/RECORD +0 -65
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_unpaired_preference.py (new file)

@@ -0,0 +1,246 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+
+from torch.nn import functional as F
+
+
+class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
+    @abstractmethod
+    def preference_loss_fn(*args, **kwargs):
+        """
+        To be extended by subclasses.
+        """
+        raise NotImplementedError("Preference loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        preference_labels,
+        bias=None,
+        loss_fn=None,
+        chunk_size=1,
+        ignore_index=-100,
+        compiled=True,
+        use_ref_model=False,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Base class for fused linear layer with unpaired preference loss like KTO
+        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
+
+        The mental model is:
+
+        forward()
+        ├── Loop over chunks
+            └── compute_loss()
+                ├── chunk_forward()  # Compute logits and log probs
+                └── prefer_loss()  # Calculate preference loss
+
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the preference loss.
+            compiled (bool): Whether to use torch compile for chunk accumulation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples.
+                Shape: (batch_size,).
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Other possible arguments that a loss function might need
+        """
+        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
+        CHUNK_SIZE = chunk_size
+
+        # Gradients to be accumulated
+        grad_inputs = []
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+        # Loss to be accumulated
+        loss_acc = torch.zeros((), device=_input.device)
+
+        compute_loss = partial(
+            LigerFusedLinearUnpairedPreferenceBase._compute_loss,
+            preference_loss_fn=loss_fn,
+            full_target=target,
+            ignore_index=ignore_index,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            **loss_kwargs,
+        )
+
+        def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk):
+            """
+            Fused forward and backward pass for a chunk of input and target.
+            """
+            argnums = (0, 1, 4) if bias is not None else (0, 1)
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=False)(
+                input_chunk,
+                weight,
+                target_chunk,
+                preference_labels_chunk,
+                bias,
+                ref_input_chunk=ref_input_chunk,
+            )
+
+        def accumulate_chunk(
+            input_chunk,
+            target_chunk,
+            preference_labels_chunk=None,
+            ref_input_chunk=None,
+        ):
+            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd(
+                input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk
+            )
+            if bias is not None:
+                grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient
+
+            # Accumulate gradients
+            grad_weight.add_(chunk_grad_weight)
+            grad_inputs.append(chunk_grad_input)
+
+            # Accumulate loss
+            loss_acc.add_(chunk_loss)
+
+        if compiled:
+            fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+        # When not paired, use labels to separate chosen and rejected
+        assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss"
+
+        chunks = max(1, _input.shape[0] // CHUNK_SIZE)
+        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+        _target_chunks = torch.chunk(target, chunks=chunks, dim=0)
+        _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0)
+
+        if use_ref_model:
+            _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0)
+
+        for (
+            input_chunk,
+            target_chunk,
+            ref_input_chunk,
+            preference_labels_chunk,
+        ) in zip(
+            _input_chunks,
+            _target_chunks,
+            (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)),
+            _preference_labels_chunks,
+        ):
+            # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(target_chunk, 1)
+            torch._dynamo.mark_dynamic(target, 1)
+            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+            torch._dynamo.mark_dynamic(preference_labels_chunk, 1)
+
+            # accumulate loss, gradients, and metrics
+            accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
+
+        ctx.save_for_backward(
+            torch.cat(grad_inputs, dim=0),
+            grad_weight,
+            grad_bias,
+        )
+        return loss_acc
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
+            grad_input = grad_input * grad_output[0][0]
+            grad_weight = grad_weight * grad_output[0][0]
+            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
+
+        return grad_input, grad_weight, None, None, grad_bias
+
+    @staticmethod
+    def chunk_forward(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        ignore_index=-100,
+    ):
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        loss_mask_chunk = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
+
+        per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
+        average_log_prob_chunk = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
+
+        return average_log_prob_chunk
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        preference_labels_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        use_ref_model=False,
+        ref_input_chunk=None,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+        )
+
+        if use_ref_model:
+            with torch.no_grad():
+                ref_average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+                    ref_input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                )
+            loss_kwargs["ref_average_log_prob_chunk"] = ref_average_log_prob_chunk
+
+        preference_loss_chunk = preference_loss_fn(
+            average_log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+        )
+
+        return preference_loss_chunk
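`LigerFusedLinearUnpairedPreferenceBase` is abstract: a concrete loss supplies `preference_loss_fn` and forwards the remaining arguments to the base `forward`. Below is a minimal sketch of such a subclass; the name `SimpleUnpairedLoss`, the sigmoid-shaped loss, and its `beta` argument are illustrative assumptions only, not part of liger-kernel (the implementation actually shipped on top of this base is `kto_loss.py` from the file list above).

```python
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_unpaired_preference import (
    LigerFusedLinearUnpairedPreferenceBase,
)


class SimpleUnpairedLoss(LigerFusedLinearUnpairedPreferenceBase):
    @staticmethod
    def preference_loss_fn(average_log_prob_chunk, preference_labels_chunk, full_target, beta=0.1):
        # Reward chosen sequences (+1) and penalize rejected ones (-1),
        # normalized by the full batch size so chunked sums add up to a batch mean.
        signs = preference_labels_chunk.float() * 2.0 - 1.0
        return -F.logsigmoid(beta * signs * average_log_prob_chunk).sum() / full_target.shape[0]

    @staticmethod
    def forward(ctx, _input, weight, target, preference_labels, bias=None):
        return LigerFusedLinearUnpairedPreferenceBase.forward(
            ctx,
            _input,
            weight,
            target,
            preference_labels,
            bias=bias,
            loss_fn=SimpleUnpairedLoss.preference_loss_fn,
        )
```

It would then be invoked as `SimpleUnpairedLoss.apply(hidden_states, lm_head_weight, labels, preference_labels)`, with `preference_labels` a boolean tensor of shape `(batch_size,)` as documented in the base class.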
liger_kernel/chunked_loss/grpo_loss.py (new file)

@@ -0,0 +1,160 @@
+import torch
+
+from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase
+
+
+class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
+    @staticmethod
+    def rlhf_loss_fn(
+        log_probs,
+        attention_mask,
+        rewards,
+        ref_log_probs=None,
+        beta=0.1,
+        **kwargs,
+    ):
+        """GRPO Loss Function matching GRPOTrainer implementation."""
+        # Get chosen token probabilities
+        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
+        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
+            -1
+        )  # (batch_size, seq_len)
+
+        # Get reference model probabilities
+        if ref_log_probs is not None:
+            with torch.no_grad():
+                ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
+        else:
+            ref_token_logprobs = chosen_token_logprobs.detach()
+
+        # Compute advantages per batch entry in a grouped fashion
+        mean_grouped_rewards = rewards.mean()  # [batch_size,]
+        std_grouped_rewards = rewards.std()  # [batch_size,]
+
+        # Calculate advantages using the same epsilon as in GRPOTrainer
+        eps = 1e-4
+        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
+
+        # Compute policy gradient loss with importance sampling ratio
+        ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
+        policy_loss = -ratio * advantages.unsqueeze(1)
+
+        # Compute KL penalty
+        kl_div = (
+            torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
+        )
+
+        # Combine losses
+        per_token_loss = policy_loss + beta * kl_div
+
+        # Apply masking and normalize
+        masked_loss = per_token_loss * attention_mask
+        seq_lengths = attention_mask.sum()
+        seq_lengths = torch.clamp(seq_lengths, min=1.0)
+        loss = masked_loss.sum() / seq_lengths
+
+        # Calculate metrics
+        metrics = (
+            chosen_token_logprobs.mean(),  # mean log prob
+            chosen_token_logprobs.std(),  # std log prob
+            log_probs.mean(),  # mean all log probs
+            ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(),  # mean KL div
+        )
+
+        return loss, metrics
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        attention_mask,
+        rewards,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        beta=0.1,
+        compiled=True,
+        use_ref_model=True,
+        num_generations=1,
+    ):
+        return LigerFusedLinearRLHFBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            attention_mask=attention_mask,
+            loss_fn=LigerFusedLinearGRPOFunction.rlhf_loss_fn,
+            rewards=rewards,
+            bias=bias,
+            ref_input=ref_input,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            beta=beta,
+            compiled=compiled,
+            use_ref_model=use_ref_model,
+            num_generations=num_generations,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *grad_metrics):
+        """Backward pass for GRPO loss.
+
+        Args:
+            grad_output: Gradient of the loss (scalar)
+            grad_metrics: Gradients of the metrics (not used in backward computation)
+        """
+        grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
+        return (
+            *grads[:5],  # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias
+            None,  # grad_ref_input
+            None,  # grad_ref_weight
+            None,  # grad_ref_bias
+            None,  # grad_beta
+            None,  # grad_compiled
+            None,  # grad_use_ref_model
+            None,  # grad_num_generations
+        )
+
+
+class LigerFusedLinearGRPOLoss(torch.nn.Module):
+    """Fused linear layer with GRPO loss."""
+
+    def __init__(
+        self,
+        beta: float = 0.1,
+        compiled: bool = True,
+        use_ref_model: bool = True,
+        num_generations: int = 1,
+    ):
+        super().__init__()
+        self.beta = beta
+        self.compiled = compiled
+        self.use_ref_model = use_ref_model
+        self.num_generations = num_generations
+
+    def forward(
+        self,
+        _input,
+        lin_weight,
+        attention_mask,
+        rewards,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+    ):
+        return LigerFusedLinearGRPOFunction.apply(
+            _input,
+            lin_weight,
+            attention_mask,
+            rewards,
+            bias,
+            ref_input,
+            ref_weight,
+            ref_bias,
+            self.beta,
+            self.compiled,
+            self.use_ref_model,
+            self.num_generations,
+        )
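As a quick orientation, here is a hedged sketch of driving `LigerFusedLinearGRPOLoss` with dummy tensors. The sizes are illustrative; passing `use_ref_model=False` assumes `LigerFusedLinearRLHFBase` (defined in `fused_linear_rlhf.py`, not shown in this section) supports running without a reference model, and the exact return structure also comes from that base.

```python
import torch

from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

# Illustrative sizes only.
B, T, H, V = 4, 16, 64, 128  # generations, sequence length, hidden size, vocab size

hidden_states = torch.randn(B, T, H, requires_grad=True)
lm_head_weight = torch.randn(V, H, requires_grad=True)
attention_mask = torch.ones(B, T)
rewards = torch.randn(B)  # one scalar reward per generation

grpo_loss = LigerFusedLinearGRPOLoss(beta=0.1, use_ref_model=False)
output = grpo_loss(hidden_states, lm_head_weight, attention_mask, rewards)

# The RLHF base is expected to return the loss (possibly alongside metrics).
loss = output[0] if isinstance(output, tuple) else output
loss.backward()
```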
liger_kernel/chunked_loss/jsd_loss.py (new file)

@@ -0,0 +1,154 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+        """
+        Compute JSD loss (Jensen-Shannon Divergence Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        Returns:
+            torch.Tensor: Jensen-Shannon Divergence loss
+        """
+        student_log_probs = F.log_softmax(student_logits, dim=-1)
+        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+        # Compute probabilities (only required for mean calculation)
+        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
+        log_mean_probs = mean_probs.log()
+
+        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+
+        # JSD is the weighted average of the KL divergences
+        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+        return jsd_loss
+
+    @staticmethod
+    def forward(
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Fused linear layer with JSD distillation loss.
+        Args:
+            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
+            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
+            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
+            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
+            true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            ignore_index (int): Index to ignore in loss computation
+            temperature (float): Temperature for softening/sharpening distributions
+            compiled (bool): Whether to use torch compile
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearDistillationBase.forward(
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
+            chunk_size=1,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
+
+        return (*grads, None, None, None, None, None, None, None)
+
+
+class LigerFusedLinearJSDLoss(torch.nn.Module):
+    """
+    Fused linear layer with JSD distillation loss.
+    """
+
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Args:
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            ignore_index (int): Index to ignore in the loss
+            temperature (float): Temperature for softening distributions
+            compiled (bool): Whether to use torch compile
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        """
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+    ) -> torch.Tensor:
+        """
+        Compute the JSD distillation loss.
+
+        Args:
+            student_input (torch.Tensor): Student input tensor
+            student_weight (torch.Tensor): Student weight tensor
+            teacher_input (torch.Tensor): Teacher input tensor
+            teacher_weight (torch.Tensor): Teacher weight tensor
+            true_labels (torch.LongTensor): Target labels tensor
+
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearJSDFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+        )
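A minimal usage sketch for `LigerFusedLinearJSDLoss`, with dummy tensors shaped as documented above (inputs flattened to `batch_size * seq_len` rows); the sizes are illustrative only, and the chunked loop itself lives in `LigerFusedLinearDistillationBase` (not shown in this section).

```python
import torch

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss

# Illustrative sizes only.
N, H_STUDENT, H_TEACHER, V = 32, 64, 128, 1000  # flattened tokens, hidden sizes, vocab size

student_hidden = torch.randn(N, H_STUDENT, requires_grad=True)
student_lm_head = torch.randn(V, H_STUDENT, requires_grad=True)
teacher_hidden = torch.randn(N, H_TEACHER)
teacher_lm_head = torch.randn(V, H_TEACHER)
labels = torch.randint(0, V, (N,))

jsd_loss = LigerFusedLinearJSDLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)
loss = jsd_loss(student_hidden, student_lm_head, teacher_hidden, teacher_lm_head, labels)
loss.backward()
```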