liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic; see the package's registry page for more details.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from torch.nn import functional as F
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
    """Base autograd.Function for fused linear + unpaired preference losses (e.g. KTO).

    Subclasses implement ``preference_loss_fn``; this base handles chunking the
    batch, running a fused forward/backward per chunk via
    ``torch.func.grad_and_value``, and accumulating loss, gradients, and metrics.
    """

    @abstractmethod
    def preference_loss_fn(*args, **kwargs):
        """
        To be extended by subclasses.
        """
        raise NotImplementedError("Preference loss function must be implemented.")

    @staticmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        target,
        preference_labels,
        bias=None,
        chunk_size=1,
        ignore_index=-100,
        compiled=True,
        use_ref_model=False,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        average_log_prob=False,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with unpaired preference loss like KTO
        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

        The mental model is:

        forward()
        ├── Loop over chunks
            └── compute_loss()
                ├── chunk_forward()  # Compute logits and log probs
                └── prefer_loss()    # Calculate preference loss

        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
            ignore_index (int): Index to ignore for loss computation.
            beta (float): Weight for the preference loss.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples.
                Shape: (batch_size,).
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            average_log_prob (bool): Whether to average the log probability per non-masked token.
            loss_kwargs (dict): Other possible arguments that a loss function might need
        """
        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
        CHUNK_SIZE = chunk_size

        # Gradients to be accumulated
        grad_inputs = []
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias) if bias is not None else None

        # Loss to be accumulated
        loss_acc = torch.zeros((), device=_input.device)

        # Metrics to be recorded
        chosen_logps_sum = torch.zeros((), device=_input.device)
        rejected_logps_sum = torch.zeros((), device=_input.device)
        chosen_logits_sum = torch.zeros((), device=_input.device)
        rejected_logits_sum = torch.zeros((), device=_input.device)
        aggregated_aux_outputs = []

        # Bind all per-call invariants so the chunk loop only varies in its tensors.
        compute_loss = partial(
            LigerFusedLinearUnpairedPreferenceBase._compute_loss,
            preference_loss_fn=cls.preference_loss_fn,
            full_target=target,
            ignore_index=ignore_index,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            average_log_prob=average_log_prob,
            **loss_kwargs,
        )

        def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk):
            """
            Fused forward and backward pass for a chunk of input and target.
            """
            # Differentiate w.r.t. input (0), weight (1), and bias (positional arg 4)
            # only when a bias is present.
            argnums = (0, 1, 4) if bias is not None else (0, 1)
            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
                input_chunk,
                weight,
                target_chunk,
                preference_labels_chunk,
                bias,
                ref_input_chunk=ref_input_chunk,
            )

        def accumulate_chunk(
            input_chunk,
            target_chunk,
            preference_labels_chunk=None,
            ref_input_chunk=None,
        ):
            # grad_and_value returns ((grads...), (loss, aux)); unpack both.
            (
                (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias),
                (
                    chunk_loss,
                    (
                        chunk_chosen_logps_sum,
                        chunk_rejected_logps_sum,
                        chunk_chosen_logits_sum,
                        chunk_rejected_logits_sum,
                        *aux_outputs,
                    ),
                ),
            ) = fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
            if bias is not None:
                grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient

            # Accumulate gradients
            grad_weight.add_(chunk_grad_weight)
            grad_inputs.append(chunk_grad_input)

            # Accumulate loss
            loss_acc.add_(chunk_loss)

            # Accumulate metrics
            chosen_logps_sum.add_(chunk_chosen_logps_sum)
            rejected_logps_sum.add_(chunk_rejected_logps_sum)
            chosen_logits_sum.add_(chunk_chosen_logits_sum)
            rejected_logits_sum.add_(chunk_rejected_logits_sum)

            # aux_outputs
            # Initialize storage for aux_outputs (lazily, on the first chunk)
            if len(aggregated_aux_outputs) == 0:
                for aux in aux_outputs:
                    aggregated_aux_outputs.append(torch.zeros((), device=aux.device))

            # Process each aux_output (only scalar aux outputs are accumulated here)
            for i, aux in enumerate(aux_outputs):
                if aux.ndim == 0:
                    aggregated_aux_outputs[i].add_(aux)

        if compiled:
            fused_fwd_bwd = torch.compile(fused_fwd_bwd)

        # When not paired, use labels to separate chosen and rejected
        assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss"

        # Never produce zero chunks, even when batch < CHUNK_SIZE.
        chunks = max(1, _input.shape[0] // CHUNK_SIZE)
        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
        _target_chunks = torch.chunk(target, chunks=chunks, dim=0)
        _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0)

        if use_ref_model:
            _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0)

        for (
            input_chunk,
            target_chunk,
            ref_input_chunk,
            preference_labels_chunk,
        ) in zip(
            _input_chunks,
            _target_chunks,
            (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)),
            _preference_labels_chunks,
        ):
            # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation
            torch._dynamo.mark_dynamic(input_chunk, 1)
            torch._dynamo.mark_dynamic(target_chunk, 1)
            torch._dynamo.mark_dynamic(target, 1)
            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
            torch._dynamo.mark_dynamic(preference_labels_chunk, 1)

            # accumulate loss, gradients, and metrics
            accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)

        # Aggregate aux outputs lists into tensors
        for i, aux in enumerate(aggregated_aux_outputs):
            if isinstance(aux, list):
                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)

        # Stash pre-computed gradients; backward() only rescales them.
        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )

        return_vars = (
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        )

        return loss_acc, (*return_vars, *aggregated_aux_outputs)

    @staticmethod
    def backward(ctx, *grad_output):
        """Rescale the gradients cached in forward() by the incoming loss gradient.

        Gradients were fully computed during forward(); here they are only
        multiplied by the upstream gradient when it differs from 1.0.
        """
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
            grad_input = grad_input * grad_output[0][0]
            grad_weight = grad_weight * grad_output[0][0]
            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None

        # None entries match the non-tensor forward() inputs (target, preference_labels).
        return grad_input, grad_weight, None, None, grad_bias

    @staticmethod
    def chunk_forward(
        input_chunk,
        weight,
        target_chunk,
        preference_labels_chunk,
        bias=None,
        ignore_index=-100,
        average_log_prob=False,
    ):
        """Compute per-example log-probs plus chosen/rejected sums for one chunk.

        Returns a 5-tuple: (log_probs, chosen_logps_sum, rejected_logps_sum,
        chosen_logits_sum, rejected_logits_sum).
        """
        logits_chunk = input_chunk @ weight.t()
        if bias is not None:
            logits_chunk = logits_chunk + bias
        # log-softmax in float32 for numerical stability regardless of input dtype.
        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
        loss_mask_chunk = target_chunk != ignore_index
        # Replace ignored targets with index 0 so gather() stays in-bounds;
        # the mask zeroes their contribution below.
        label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)

        per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
        if average_log_prob:
            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
        else:
            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1)

        # NOTE(review): these reductions rely on broadcasting the boolean
        # preference_labels_chunk (unsqueezed to a column) against log_probs /
        # logits_chunk — confirm the expected chunk shapes with callers.
        chosen_logps_sum = (log_probs * preference_labels_chunk.unsqueeze(1)).sum()
        rejected_logps_sum = (log_probs * (~preference_labels_chunk).unsqueeze(1)).sum()

        chosen_logits_sum = (logits_chunk * preference_labels_chunk.unsqueeze(1)).sum()
        rejected_logits_sum = (logits_chunk * (~preference_labels_chunk).unsqueeze(1)).sum()

        return (
            log_probs,
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        )

    @staticmethod
    def _compute_loss(
        input_chunk,
        weight,
        target_chunk,
        preference_labels_chunk,
        bias=None,
        preference_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        use_ref_model=False,
        ref_input_chunk=None,
        ref_weight=None,
        ref_bias=None,
        average_log_prob=False,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
        Args:
            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
            ignore_index (int): Index to ignore for loss computation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            average_log_prob (bool): Whether to average the log probability per non-masked token.
            loss_kwargs (dict): Additional arguments for the loss function.
        """
        (
            log_prob_chunk,
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
            input_chunk,
            weight,
            target_chunk,
            preference_labels_chunk,
            bias=bias,
            ignore_index=ignore_index,
            average_log_prob=average_log_prob,
        )

        if use_ref_model:
            # Reference model is frozen: no gradients through its forward pass.
            with torch.no_grad():
                (
                    ref_log_prob_chunk,
                    _,
                    _,
                    _,
                    _,
                ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
                    ref_input_chunk,
                    ref_weight,
                    target_chunk,
                    preference_labels_chunk,
                    ref_bias,
                    ignore_index=ignore_index,
                    average_log_prob=average_log_prob,
                )
            loss_kwargs["ref_log_prob_chunk"] = ref_log_prob_chunk

        preference_loss_outputs = preference_loss_fn(
            log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
        )
        # Loss functions may return just the loss, or (loss, *aux_outputs).
        if isinstance(preference_loss_outputs, tuple):
            preference_loss_chunk, *aux_outputs = preference_loss_outputs
        else:
            preference_loss_chunk, aux_outputs = preference_loss_outputs, []

        return_vars = (
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        )

        return preference_loss_chunk, (*return_vars, *aux_outputs)
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def k3_loss_fn(log_p, log_q):
    """Return the per-element k3 estimator of KL[q, p].

    Uses the low-variance estimator exp(r) - r - 1 with r = log_p - log_q.
    Reference: http://joschu.net/blog/kl-approx.html
    """
    log_ratio = log_p - log_q
    return torch.exp(log_ratio) - log_ratio - 1.0
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def clip_coef_fn(coef, epsilon_low, epsilon_high):
    """Clamp the importance-sampling ratio into [1 - epsilon_low, 1 + epsilon_high]."""
    lower_bound = 1 - epsilon_low
    upper_bound = 1 + epsilon_high
    return coef.clamp(lower_bound, upper_bound)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
    """Autograd function computing the GRPO objective on top of the fused-linear PPO base."""

    @staticmethod
    def ppo_loss_fn(
        log_probs,
        selected_token_ids,
        attention_mask,
        advantages,
        full_attention_mask,
        ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
        old_per_token_logps=None,
        ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
        epsilon_low=0.2,
        epsilon_high=0.2,
        beta=0.04,
        loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
        max_completion_length=None,  # Required for dr_grpo
        importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
        **kwargs,
    ):
        """GRPO Loss Function matching GRPOTrainer implementation.

        Returns a (loss, metrics) pair where metrics is a list containing the
        mean KL penalty (only when beta != 0.0) followed by the clip ratio.
        """
        # Log-prob of each sampled token under the policy.
        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
            -1
        )  # (batch_size, seq_len)

        # Get reference model probabilities
        if ref_per_token_logps is None:
            if ref_log_probs is not None:
                with torch.no_grad():
                    ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
                        -1
                    )
            else:
                # No reference model at all: KL term degenerates to zero
                # (detached copy of the policy's own log-probs).
                ref_per_token_logps = per_token_logps.detach()

        # Compute policy gradient loss with importance sampling ratio
        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
        log_ratio = per_token_logps - old_per_token_logps

        if importance_sampling_level == "token":
            log_importance_weights = log_ratio
        elif importance_sampling_level == "sequence":
            # GSPO-style: one importance weight per sequence (masked mean over tokens).
            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
            log_importance_weights = log_importance_weights.unsqueeze(-1)
        else:
            raise ValueError(
                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
                "and 'sequence'."
            )

        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
        coef_1 = torch.exp(log_importance_weights)
        coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
        # Standard PPO clipped surrogate: take the pessimistic (min) branch.
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if beta != 0.0:
            # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
            kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
            # Combine losses
            per_token_loss = per_token_loss + beta * kl_div

        # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
        # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
        # and TRL GRPO implementation
        # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
        if loss_type == "grpo":
            # Average per-sequence loss
            loss = (
                (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
            ).sum() / full_attention_mask.shape[0]
        elif loss_type == "bnpo":
            # Batch Normalized Per-token loss (original implementation)
            loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
        elif loss_type == "dr_grpo":
            # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
            if max_completion_length is None:
                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

        # Calculate metrics
        metrics = []
        if beta != 0.0:
            metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))

        # Adjust clipping metric calculation based on importance sampling level
        if importance_sampling_level == "token":
            # A token is "clipped" when the unclipped ratio falls outside the trust
            # region on the side that would otherwise decrease the loss.
            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
            )
        else:  # sequence level
            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
            )
            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)

        metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
        return loss, metrics

    @classmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        selected_token_ids,
        attention_mask,
        advantages,
        bias=None,
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        beta=0.04,
        epsilon_low=0.2,
        epsilon_high=0.2,
        loss_type="bnpo",
        max_completion_length=None,
        importance_sampling_level="token",
        temperature=1.0,
        compiled=True,
        use_ref_model=True,
        chunk_size=1,
    ):
        """
        Fused linear layer with GRPO loss.
        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
            selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
            attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
            advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
            ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
            old_per_token_logps (torch.Tensor, optional): Behavior-policy log probs per token. Shape: (batch_size, seq_len)
            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
            beta (float): Weight for the KL penalty
            epsilon_low (float): Lower clip bound for the importance sampling ratio
            epsilon_high (float): Upper clip bound for the importance sampling ratio
            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
            temperature (float): Temperature for the logits
            compiled (bool): Whether to use torch compile
            use_ref_model (bool): Whether to use a reference model
            chunk_size (int): Size of chunks for processing.
        Returns:
            torch.Tensor: Computed loss
        """
        # Delegate chunking/accumulation to the PPO base; ppo_loss_fn above is
        # looked up via cls.
        return super().forward(
            cls=cls,
            ctx=ctx,
            _input=_input,
            weight=weight,
            selected_token_ids=selected_token_ids,
            attention_mask=attention_mask,
            advantages=advantages,
            bias=bias,
            ref_per_token_logps=ref_per_token_logps,
            old_per_token_logps=old_per_token_logps,
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            beta=beta,
            epsilon_low=epsilon_low,
            epsilon_high=epsilon_high,
            loss_type=loss_type,
            max_completion_length=max_completion_length,
            temperature=temperature,
            compiled=compiled,
            use_ref_model=use_ref_model,
            chunk_size=chunk_size,
            importance_sampling_level=importance_sampling_level,
        )

    @staticmethod
    def backward(ctx, grad_output, *grad_metrics):
        """Backward pass for GRPO loss.

        Args:
            grad_output: Gradient of the loss (scalar)
            grad_metrics: Gradients of the metrics (not used in backward computation)
        """
        grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
        # One entry per forward() argument after ctx; non-tensor hyper-parameters
        # get None.
        return (
            *grads[
                :6
            ],  # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
            None,  # grad_ref_per_token_logps
            None,  # grad_old_per_token_logps
            None,  # grad_ref_input
            None,  # grad_ref_weight
            None,  # grad_ref_bias
            None,  # grad_beta
            None,  # grad_epsilon_low
            None,  # grad_epsilon_high
            None,  # grad_loss_type (string, not differentiable)
            None,  # grad_max_completion_length (int, not differentiable)
            None,  # grad_importance_sampling_level (string, not differentiable)
            None,  # grad_temperature
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_chunk_size
        )
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class LigerFusedLinearGRPOLoss(torch.nn.Module):
    """nn.Module wrapper around :class:`LigerFusedLinearGRPOFunction` (fused linear + GRPO loss)."""

    def __init__(
        self,
        beta: float = 0.04,
        compiled: bool = True,
        use_ref_model: bool = True,
        chunk_size: int = 1,
        epsilon_low: float = 0.2,
        epsilon_high: float = 0.2,
        loss_type: str = "bnpo",
        max_completion_length: Optional[int] = None,
        importance_sampling_level: str = "token",
        temperature: float = 1.0,
    ):
        """
        Args:
            beta (float): Weight for the KL penalty.
            compiled (bool): Whether to use torch compile.
            use_ref_model (bool): Whether to use a reference model.
            chunk_size (int): Size of chunks for processing.
            epsilon_low (float): Lower bound for the importance sampling ratio.
            epsilon_high (float): Upper bound for the importance sampling ratio.
            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
            temperature (float): Temperature for the logits.
        """
        super().__init__()
        # Objective hyper-parameters.
        self.beta = beta
        self.epsilon_low = epsilon_low
        self.epsilon_high = epsilon_high
        self.loss_type = loss_type
        self.max_completion_length = max_completion_length
        self.importance_sampling_level = importance_sampling_level
        self.temperature = temperature
        # Execution knobs.
        self.compiled = compiled
        self.use_ref_model = use_ref_model
        self.chunk_size = chunk_size

    def forward(
        self,
        _input,
        lin_weight,
        selected_token_ids,
        attention_mask,
        advantages,
        bias=None,
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        """Compute the fused GRPO loss; see LigerFusedLinearGRPOFunction.forward for argument shapes."""
        # autograd.Function.apply only accepts positional arguments, and the
        # order below must match LigerFusedLinearGRPOFunction.forward exactly.
        fn_args = (
            _input,
            lin_weight,
            selected_token_ids,
            attention_mask,
            advantages,
            bias,
            ref_per_token_logps,
            old_per_token_logps,
            ref_input,
            ref_weight,
            ref_bias,
            self.beta,
            self.epsilon_low,
            self.epsilon_high,
            self.loss_type,
            self.max_completion_length,
            self.importance_sampling_level,
            self.temperature,
            self.compiled,
            self.use_ref_model,
            self.chunk_size,
        )
        return LigerFusedLinearGRPOFunction.apply(*fn_args)
|