liger-kernel-nightly 0.5.5.dev20250331042257__py3-none-any.whl → 0.5.5.dev20250402184001__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly has been flagged as potentially problematic; see the registry's advisory page for this release for more details.
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +330 -0
- liger_kernel/chunked_loss/grpo_loss.py +103 -61
- liger_kernel/transformers/model/llava.py +20 -34
- {liger_kernel_nightly-0.5.5.dev20250331042257.dist-info → liger_kernel_nightly-0.5.5.dev20250402184001.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331042257.dist-info → liger_kernel_nightly-0.5.5.dev20250402184001.dist-info}/RECORD +10 -10
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
- {liger_kernel_nightly-0.5.5.dev20250331042257.dist-info → liger_kernel_nightly-0.5.5.dev20250402184001.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257.dist-info → liger_kernel_nightly-0.5.5.dev20250402184001.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257.dist-info → liger_kernel_nightly-0.5.5.dev20250402184001.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331042257.dist-info → liger_kernel_nightly-0.5.5.dev20250402184001.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
|
|
2
2
|
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
|
|
3
|
+
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
|
|
3
4
|
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
|
|
4
5
|
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
|
|
5
6
|
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
|
|
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
|
|
|
11
12
|
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
|
|
12
13
|
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
|
|
13
14
|
liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
|
|
15
|
+
liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch._dynamo.config
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
10
|
+
@abstractmethod
|
|
11
|
+
def ppo_loss_fn(*args, **kwargs):
|
|
12
|
+
"""
|
|
13
|
+
To be extended by subclasses.
|
|
14
|
+
"""
|
|
15
|
+
raise NotImplementedError("PPO loss function must be implemented.")
|
|
16
|
+
|
|
17
|
+
@staticmethod
|
|
18
|
+
def forward(
|
|
19
|
+
cls,
|
|
20
|
+
ctx,
|
|
21
|
+
_input,
|
|
22
|
+
weight,
|
|
23
|
+
selected_token_ids,
|
|
24
|
+
attention_mask,
|
|
25
|
+
advantages,
|
|
26
|
+
bias=None,
|
|
27
|
+
ref_per_token_logps=None,
|
|
28
|
+
old_per_token_logps=None,
|
|
29
|
+
ref_input=None,
|
|
30
|
+
ref_weight=None,
|
|
31
|
+
ref_bias=None,
|
|
32
|
+
epsilon_low=0.2,
|
|
33
|
+
epsilon_high=0.2,
|
|
34
|
+
beta=0.04,
|
|
35
|
+
temperature=1.0,
|
|
36
|
+
compiled=True,
|
|
37
|
+
use_ref_model=False,
|
|
38
|
+
chunk_size=1,
|
|
39
|
+
):
|
|
40
|
+
"""Chunked forward pass for PPO loss computation.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
cls: The class
|
|
44
|
+
ctx: Context for backward
|
|
45
|
+
_input: Input tensor
|
|
46
|
+
weight: Weight tensor
|
|
47
|
+
selected_token_ids: Selected token ids tensor
|
|
48
|
+
attention_mask: Attention mask tensor
|
|
49
|
+
advantages: Advantages tensor
|
|
50
|
+
bias: Bias tensor
|
|
51
|
+
ref_per_token_logps: Reference model log probs per token tensor
|
|
52
|
+
old_per_token_logps: Old per token log probabilities tensor
|
|
53
|
+
ref_input: Reference model input tensor
|
|
54
|
+
ref_weight: Reference model weight tensor
|
|
55
|
+
ref_bias: Reference model bias tensor
|
|
56
|
+
epsilon_low: Lower bound for clipping the importance sampling ratio
|
|
57
|
+
epsilon_high: Upper bound for clipping the importance sampling ratio
|
|
58
|
+
beta: Weight for the KL penalty
|
|
59
|
+
temperature: Temperature for the logits
|
|
60
|
+
compiled: Whether to use torch compile
|
|
61
|
+
use_ref_model: Whether to use a reference model
|
|
62
|
+
chunk_size: Size of chunks for processing in other loss modules
|
|
63
|
+
"""
|
|
64
|
+
if use_ref_model:
|
|
65
|
+
assert ref_per_token_logps is not None or ref_input is not None, (
|
|
66
|
+
"If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
|
|
67
|
+
)
|
|
68
|
+
if ref_per_token_logps is not None and ref_input is not None:
|
|
69
|
+
raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
|
|
70
|
+
# Initialize accumulators
|
|
71
|
+
loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
|
|
72
|
+
grad_weight = torch.zeros_like(weight) # [V, H]
|
|
73
|
+
grad_inputs = []
|
|
74
|
+
grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
|
|
75
|
+
aggregated_metrics = []
|
|
76
|
+
|
|
77
|
+
# Create a partial function with fixed arguments
|
|
78
|
+
compute_loss = partial(
|
|
79
|
+
LigerFusedLinearPPOBase._compute_chunk_loss,
|
|
80
|
+
ref_weight=ref_weight,
|
|
81
|
+
ref_bias=ref_bias,
|
|
82
|
+
full_attention_mask=attention_mask,
|
|
83
|
+
epsilon_low=epsilon_low,
|
|
84
|
+
epsilon_high=epsilon_high,
|
|
85
|
+
beta=beta,
|
|
86
|
+
temperature=temperature,
|
|
87
|
+
use_ref_model=use_ref_model,
|
|
88
|
+
ppo_loss_fn=cls.ppo_loss_fn,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
def fused_fwd_bwd(
|
|
92
|
+
input_chunk,
|
|
93
|
+
selected_token_ids_chunk,
|
|
94
|
+
attention_mask_chunk,
|
|
95
|
+
advantages_chunk,
|
|
96
|
+
ref_per_token_logps_chunk,
|
|
97
|
+
old_per_token_logps_chunk,
|
|
98
|
+
ref_input_chunk,
|
|
99
|
+
):
|
|
100
|
+
"""Fused forward and backward for a chunk."""
|
|
101
|
+
argnums = (0, 1, 5) if bias is not None else (0, 1)
|
|
102
|
+
return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
|
|
103
|
+
input_chunk, # arg 0
|
|
104
|
+
weight, # arg 1
|
|
105
|
+
selected_token_ids_chunk, # arg 2
|
|
106
|
+
attention_mask_chunk, # arg 3
|
|
107
|
+
advantages_chunk, # arg 4
|
|
108
|
+
bias, # arg 5
|
|
109
|
+
ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
|
|
110
|
+
old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
|
|
111
|
+
ref_input_chunk=ref_input_chunk, # arg 8
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
def accumulate_chunk(
|
|
115
|
+
input_chunk,
|
|
116
|
+
selected_token_ids_chunk,
|
|
117
|
+
attention_mask_chunk,
|
|
118
|
+
advantages_chunk,
|
|
119
|
+
ref_per_token_logps_chunk=None,
|
|
120
|
+
old_per_token_logps_chunk=None,
|
|
121
|
+
ref_input_chunk=None,
|
|
122
|
+
):
|
|
123
|
+
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
|
|
124
|
+
input_chunk,
|
|
125
|
+
selected_token_ids_chunk,
|
|
126
|
+
attention_mask_chunk,
|
|
127
|
+
advantages_chunk,
|
|
128
|
+
ref_per_token_logps_chunk,
|
|
129
|
+
old_per_token_logps_chunk,
|
|
130
|
+
ref_input_chunk,
|
|
131
|
+
)
|
|
132
|
+
if bias is not None:
|
|
133
|
+
grad_bias.add_(chunk_grad_bias[0])
|
|
134
|
+
|
|
135
|
+
# Accumulate gradients and loss
|
|
136
|
+
grad_weight.add_(chunk_grad_weight)
|
|
137
|
+
grad_inputs.append(chunk_grad_input)
|
|
138
|
+
loss_acc.add_(chunk_loss)
|
|
139
|
+
# Initialize storage for metrics on first chunk
|
|
140
|
+
if len(aggregated_metrics) == 0:
|
|
141
|
+
for metric in chunk_metrics:
|
|
142
|
+
if metric.ndim == 0:
|
|
143
|
+
aggregated_metrics.append(torch.zeros((), device=metric.device))
|
|
144
|
+
else:
|
|
145
|
+
aggregated_metrics.append([])
|
|
146
|
+
|
|
147
|
+
# Accumulate metrics
|
|
148
|
+
for i, metric in enumerate(chunk_metrics):
|
|
149
|
+
if metric.ndim == 0:
|
|
150
|
+
aggregated_metrics[i].add_(metric)
|
|
151
|
+
else:
|
|
152
|
+
aggregated_metrics[i].append(metric)
|
|
153
|
+
|
|
154
|
+
if compiled:
|
|
155
|
+
# TODO: Figure out what is better to compile here
|
|
156
|
+
# accumulate_chunk = torch.compile(accumulate_chunk)
|
|
157
|
+
fused_fwd_bwd = torch.compile(fused_fwd_bwd)
|
|
158
|
+
|
|
159
|
+
# Process input in chunks based on chunk_size
|
|
160
|
+
chunks = max(1, _input.shape[0] // chunk_size)
|
|
161
|
+
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
|
|
162
|
+
_selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
|
|
163
|
+
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
|
|
164
|
+
_advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
|
|
165
|
+
_ref_per_token_logps_chunks = (
|
|
166
|
+
torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
|
|
167
|
+
if use_ref_model and ref_per_token_logps is not None
|
|
168
|
+
else [None] * chunks
|
|
169
|
+
)
|
|
170
|
+
_old_per_token_logps_chunks = (
|
|
171
|
+
torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
|
|
172
|
+
if old_per_token_logps is not None
|
|
173
|
+
else [None] * chunks
|
|
174
|
+
)
|
|
175
|
+
# if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
|
|
176
|
+
_ref_input_chunks = (
|
|
177
|
+
torch.chunk(ref_input, chunks=chunks, dim=0)
|
|
178
|
+
if use_ref_model and ref_per_token_logps is None
|
|
179
|
+
else [None] * chunks
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
for (
|
|
183
|
+
input_chunk,
|
|
184
|
+
selected_token_ids_chunk,
|
|
185
|
+
attention_mask_chunk,
|
|
186
|
+
advantages_chunk,
|
|
187
|
+
ref_per_token_logps_chunk,
|
|
188
|
+
old_per_token_logps_chunk,
|
|
189
|
+
ref_input_chunk,
|
|
190
|
+
) in zip(
|
|
191
|
+
_input_chunks,
|
|
192
|
+
_selected_token_ids_chunks,
|
|
193
|
+
_attention_mask_chunks,
|
|
194
|
+
_advantages_chunks,
|
|
195
|
+
_ref_per_token_logps_chunks,
|
|
196
|
+
_old_per_token_logps_chunks,
|
|
197
|
+
_ref_input_chunks,
|
|
198
|
+
):
|
|
199
|
+
# Mark dynamic dimensions
|
|
200
|
+
torch._dynamo.mark_dynamic(input_chunk, 1)
|
|
201
|
+
torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
|
|
202
|
+
torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
|
|
203
|
+
if ref_per_token_logps_chunk is not None:
|
|
204
|
+
torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
|
|
205
|
+
if ref_input_chunk is not None:
|
|
206
|
+
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
|
|
207
|
+
if old_per_token_logps_chunk is not None:
|
|
208
|
+
torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
|
|
209
|
+
|
|
210
|
+
accumulate_chunk(
|
|
211
|
+
input_chunk,
|
|
212
|
+
selected_token_ids_chunk,
|
|
213
|
+
attention_mask_chunk,
|
|
214
|
+
advantages_chunk,
|
|
215
|
+
ref_per_token_logps_chunk,
|
|
216
|
+
old_per_token_logps_chunk,
|
|
217
|
+
ref_input_chunk,
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
# Combine gradients
|
|
221
|
+
grad_input = torch.cat(grad_inputs, dim=0)
|
|
222
|
+
|
|
223
|
+
# Save for backward
|
|
224
|
+
ctx.save_for_backward(grad_input, grad_weight, grad_bias)
|
|
225
|
+
|
|
226
|
+
# Finalize metrics
|
|
227
|
+
final_metrics = []
|
|
228
|
+
for metric in aggregated_metrics:
|
|
229
|
+
if isinstance(metric, list):
|
|
230
|
+
final_metrics.append(torch.cat(metric, dim=0))
|
|
231
|
+
else:
|
|
232
|
+
final_metrics.append(metric)
|
|
233
|
+
|
|
234
|
+
return loss_acc, tuple(final_metrics)
|
|
235
|
+
|
|
236
|
+
@staticmethod
|
|
237
|
+
def _compute_chunk_loss(
|
|
238
|
+
input_chunk,
|
|
239
|
+
weight,
|
|
240
|
+
selected_token_ids_chunk,
|
|
241
|
+
attention_mask_chunk,
|
|
242
|
+
advantages_chunk,
|
|
243
|
+
bias=None,
|
|
244
|
+
ref_per_token_logps_chunk=None,
|
|
245
|
+
old_per_token_logps_chunk=None,
|
|
246
|
+
ref_input_chunk=None,
|
|
247
|
+
ref_weight=None,
|
|
248
|
+
ref_bias=None,
|
|
249
|
+
full_attention_mask=None,
|
|
250
|
+
epsilon_low=0.2,
|
|
251
|
+
epsilon_high=0.2,
|
|
252
|
+
beta=0.04,
|
|
253
|
+
temperature=1.0,
|
|
254
|
+
use_ref_model=False,
|
|
255
|
+
ppo_loss_fn=None,
|
|
256
|
+
):
|
|
257
|
+
"""Compute loss for a single chunk."""
|
|
258
|
+
# Get policy log probabilities using chunk_forward
|
|
259
|
+
log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
|
|
260
|
+
|
|
261
|
+
# Get reference log probabilities if needed
|
|
262
|
+
ref_log_probs = None
|
|
263
|
+
if use_ref_model and ref_per_token_logps_chunk is None:
|
|
264
|
+
with torch.no_grad():
|
|
265
|
+
ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
|
|
266
|
+
ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
|
|
267
|
+
)
|
|
268
|
+
|
|
269
|
+
# Compute chunk loss and metrics using the provided loss function
|
|
270
|
+
chunk_loss, chunk_metrics = ppo_loss_fn(
|
|
271
|
+
log_probs=log_probs,
|
|
272
|
+
selected_token_ids=selected_token_ids_chunk,
|
|
273
|
+
attention_mask=attention_mask_chunk,
|
|
274
|
+
advantages=advantages_chunk,
|
|
275
|
+
full_attention_mask=full_attention_mask,
|
|
276
|
+
ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
|
|
277
|
+
old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
|
|
278
|
+
ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None
|
|
279
|
+
epsilon_low=epsilon_low,
|
|
280
|
+
epsilon_high=epsilon_high,
|
|
281
|
+
beta=beta,
|
|
282
|
+
)
|
|
283
|
+
|
|
284
|
+
return chunk_loss, chunk_metrics
|
|
285
|
+
|
|
286
|
+
@staticmethod
|
|
287
|
+
def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
|
|
288
|
+
"""Forward pass computation for a single chunk without explicit reshaping."""
|
|
289
|
+
# Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
|
|
290
|
+
logits = torch.matmul(input_chunk, weight.t())
|
|
291
|
+
if bias is not None:
|
|
292
|
+
logits = logits + bias # Broadcasts bias to [B, T, V]
|
|
293
|
+
if temperature != 1.0:
|
|
294
|
+
logits = logits / temperature
|
|
295
|
+
|
|
296
|
+
# Compute log probabilities using softmax over the last dimension
|
|
297
|
+
log_probs = F.log_softmax(logits.float(), dim=-1)
|
|
298
|
+
|
|
299
|
+
return log_probs, logits
|
|
300
|
+
|
|
301
|
+
@staticmethod
|
|
302
|
+
def backward(ctx, grad_output, *grad_metrics):
|
|
303
|
+
"""Backward pass for PPO loss."""
|
|
304
|
+
grad_input, grad_weight, grad_bias = ctx.saved_tensors
|
|
305
|
+
if grad_output != 1.0:
|
|
306
|
+
grad_input = grad_input * grad_output
|
|
307
|
+
grad_weight = grad_weight * grad_output
|
|
308
|
+
if grad_bias is not None:
|
|
309
|
+
grad_bias = grad_bias * grad_output
|
|
310
|
+
|
|
311
|
+
return (
|
|
312
|
+
grad_input,
|
|
313
|
+
grad_weight,
|
|
314
|
+
None, # grad_selected_token_ids
|
|
315
|
+
None, # grad_attention_mask
|
|
316
|
+
None, # grad_advantages
|
|
317
|
+
grad_bias,
|
|
318
|
+
None, # grad_ref_per_token_logps
|
|
319
|
+
None, # grad_old_per_token_logps
|
|
320
|
+
None, # grad_ref_input
|
|
321
|
+
None, # grad_ref_weight
|
|
322
|
+
None, # grad_ref_bias
|
|
323
|
+
None, # grad_epsilon_low
|
|
324
|
+
None, # grad_epsilon_high
|
|
325
|
+
None, # grad_beta
|
|
326
|
+
None, # grad_temperature
|
|
327
|
+
None, # grad_compiled
|
|
328
|
+
None, # grad_use_ref_model
|
|
329
|
+
None, # grad_chunk_size
|
|
330
|
+
)
|
|
@@ -1,66 +1,76 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
|
|
3
|
-
from liger_kernel.chunked_loss.
|
|
3
|
+
from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
|
|
6
|
+
def k3_loss_fn(log_p, log_q):
    """Elementwise k3 estimate of KL[q, p] from log-probabilities.

    With ``r = exp(log_p - log_q)`` the result is ``r - log(r) - 1``, which is
    never negative because ``log(r) <= r - 1``.

    Reference: http://joschu.net/blog/kl-approx.html
    """
    log_ratio = log_p - log_q
    return torch.exp(log_ratio) - log_ratio - 1.0
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def clip_coef_fn(coef, epsilon_low, epsilon_high):
    """Clamp importance-sampling ratios into ``[1 - epsilon_low, 1 + epsilon_high]``."""
    lower_bound = 1 - epsilon_low
    upper_bound = 1 + epsilon_high
    return coef.clamp(min=lower_bound, max=upper_bound)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
|
|
7
17
|
@staticmethod
|
|
8
|
-
def
|
|
18
|
+
def ppo_loss_fn(
|
|
9
19
|
log_probs,
|
|
20
|
+
selected_token_ids,
|
|
10
21
|
attention_mask,
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
22
|
+
advantages,
|
|
23
|
+
full_attention_mask,
|
|
24
|
+
ref_per_token_logps=None, # shape: [chunk_size, seq_len]
|
|
25
|
+
old_per_token_logps=None,
|
|
26
|
+
ref_log_probs=None, # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
|
|
27
|
+
epsilon_low=0.2,
|
|
28
|
+
epsilon_high=0.2,
|
|
29
|
+
beta=0.04,
|
|
14
30
|
**kwargs,
|
|
15
31
|
):
|
|
16
32
|
"""GRPO Loss Function matching GRPOTrainer implementation."""
|
|
17
|
-
|
|
18
|
-
chosen_tokens = log_probs.argmax(dim=-1) # (batch_size, seq_len)
|
|
19
|
-
chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
|
|
33
|
+
per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
|
|
20
34
|
-1
|
|
21
35
|
) # (batch_size, seq_len)
|
|
22
36
|
|
|
23
37
|
# Get reference model probabilities
|
|
24
|
-
if
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
std_grouped_rewards = rewards.std() # [batch_size,]
|
|
33
|
-
|
|
34
|
-
# Calculate advantages using the same epsilon as in GRPOTrainer
|
|
35
|
-
eps = 1e-4
|
|
36
|
-
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
|
|
38
|
+
if ref_per_token_logps is None:
|
|
39
|
+
if ref_log_probs is not None:
|
|
40
|
+
with torch.no_grad():
|
|
41
|
+
ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
|
|
42
|
+
-1
|
|
43
|
+
)
|
|
44
|
+
else:
|
|
45
|
+
ref_per_token_logps = per_token_logps.detach()
|
|
37
46
|
|
|
38
47
|
# Compute policy gradient loss with importance sampling ratio
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
48
|
+
old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
|
|
49
|
+
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
|
|
50
|
+
coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
|
|
51
|
+
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
|
|
52
|
+
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
|
|
53
|
+
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
|
|
54
|
+
if beta != 0.0:
|
|
55
|
+
# Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
|
|
56
|
+
kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
|
|
57
|
+
# Combine losses
|
|
58
|
+
per_token_loss = per_token_loss + beta * kl_div
|
|
59
|
+
|
|
60
|
+
# Note: We normalize by the number of tokens in the batch (using full_attention_mask),
|
|
61
|
+
# which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
|
|
62
|
+
# and TRL GRPO implementation
|
|
63
|
+
# (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
|
|
64
|
+
loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
|
|
55
65
|
|
|
56
66
|
# Calculate metrics
|
|
57
|
-
metrics =
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
(
|
|
67
|
+
metrics = []
|
|
68
|
+
if beta != 0.0:
|
|
69
|
+
metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
|
|
70
|
+
is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
|
|
71
|
+
(coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
|
|
62
72
|
)
|
|
63
|
-
|
|
73
|
+
metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
|
|
64
74
|
return loss, metrics
|
|
65
75
|
|
|
66
76
|
@classmethod
|
|
@@ -69,16 +79,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
|
|
|
69
79
|
ctx,
|
|
70
80
|
_input,
|
|
71
81
|
weight,
|
|
82
|
+
selected_token_ids,
|
|
72
83
|
attention_mask,
|
|
73
|
-
|
|
84
|
+
advantages,
|
|
74
85
|
bias=None,
|
|
86
|
+
ref_per_token_logps=None,
|
|
87
|
+
old_per_token_logps=None,
|
|
75
88
|
ref_input=None,
|
|
76
89
|
ref_weight=None,
|
|
77
90
|
ref_bias=None,
|
|
78
|
-
beta=0.
|
|
91
|
+
beta=0.04,
|
|
92
|
+
epsilon_low=0.2,
|
|
93
|
+
epsilon_high=0.2,
|
|
94
|
+
temperature=1.0,
|
|
79
95
|
compiled=True,
|
|
80
96
|
use_ref_model=True,
|
|
81
|
-
num_generations=1,
|
|
82
97
|
chunk_size=1,
|
|
83
98
|
):
|
|
84
99
|
"""
|
|
@@ -86,16 +101,18 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
|
|
|
86
101
|
Args:
|
|
87
102
|
_input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
88
103
|
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
|
|
104
|
+
selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
|
|
89
105
|
attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
|
|
90
|
-
|
|
106
|
+
advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
|
|
91
107
|
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
|
|
108
|
+
ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
|
|
92
109
|
ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
|
|
93
110
|
ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
|
|
94
111
|
ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
|
|
95
112
|
beta (float): Weight for the KL penalty
|
|
113
|
+
temperature (float): Temperature for the logits
|
|
96
114
|
compiled (bool): Whether to use torch compile
|
|
97
115
|
use_ref_model (bool): Whether to use a reference model
|
|
98
|
-
num_generations (int): Number of generations per prompt
|
|
99
116
|
chunk_size (int): Size of chunks for processing.
|
|
100
117
|
Returns:
|
|
101
118
|
torch.Tensor: Computed loss
|
|
@@ -105,16 +122,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
|
|
|
105
122
|
ctx=ctx,
|
|
106
123
|
_input=_input,
|
|
107
124
|
weight=weight,
|
|
125
|
+
selected_token_ids=selected_token_ids,
|
|
108
126
|
attention_mask=attention_mask,
|
|
109
|
-
|
|
127
|
+
advantages=advantages,
|
|
110
128
|
bias=bias,
|
|
129
|
+
ref_per_token_logps=ref_per_token_logps,
|
|
130
|
+
old_per_token_logps=old_per_token_logps,
|
|
111
131
|
ref_input=ref_input,
|
|
112
132
|
ref_weight=ref_weight,
|
|
113
133
|
ref_bias=ref_bias,
|
|
114
134
|
beta=beta,
|
|
135
|
+
epsilon_low=epsilon_low,
|
|
136
|
+
epsilon_high=epsilon_high,
|
|
137
|
+
temperature=temperature,
|
|
115
138
|
compiled=compiled,
|
|
116
139
|
use_ref_model=use_ref_model,
|
|
117
|
-
num_generations=num_generations,
|
|
118
140
|
chunk_size=chunk_size,
|
|
119
141
|
)
|
|
120
142
|
|
|
@@ -126,16 +148,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
|
|
|
126
148
|
grad_output: Gradient of the loss (scalar)
|
|
127
149
|
grad_metrics: Gradients of the metrics (not used in backward computation)
|
|
128
150
|
"""
|
|
129
|
-
grads =
|
|
151
|
+
grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
|
|
130
152
|
return (
|
|
131
|
-
*grads[
|
|
153
|
+
*grads[
|
|
154
|
+
:6
|
|
155
|
+
], # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
|
|
156
|
+
None, # grad_ref_per_token_logps
|
|
157
|
+
None, # grad_old_per_token_logps
|
|
132
158
|
None, # grad_ref_input
|
|
133
159
|
None, # grad_ref_weight
|
|
134
160
|
None, # grad_ref_bias
|
|
135
161
|
None, # grad_beta
|
|
162
|
+
None, # grad_epsilon_low
|
|
163
|
+
None, # grad_epsilon_high
|
|
164
|
+
None, # grad_temperature
|
|
136
165
|
None, # grad_compiled
|
|
137
166
|
None, # grad_use_ref_model
|
|
138
|
-
None, # grad_num_generations
|
|
139
167
|
None, # grad_chunk_size
|
|
140
168
|
)
|
|
141
169
|
|
|
@@ -145,34 +173,43 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
|
|
|
145
173
|
|
|
146
174
|
def __init__(
|
|
147
175
|
self,
|
|
148
|
-
beta: float = 0.
|
|
176
|
+
beta: float = 0.04,
|
|
149
177
|
compiled: bool = True,
|
|
150
178
|
use_ref_model: bool = True,
|
|
151
|
-
num_generations: int = 1,
|
|
152
179
|
chunk_size: int = 1,
|
|
180
|
+
epsilon_low: float = 0.2,
|
|
181
|
+
epsilon_high: float = 0.2,
|
|
182
|
+
temperature: float = 1.0,
|
|
153
183
|
):
|
|
154
184
|
"""
|
|
155
185
|
Args:
|
|
156
186
|
beta (float): Weight for the KL penalty.
|
|
157
187
|
compiled (bool): Whether to use torch compile.
|
|
158
188
|
use_ref_model (bool): Whether to use a reference model.
|
|
159
|
-
num_generations (int): Number of generations per prompt.
|
|
160
189
|
chunk_size (int): Size of chunks for processing.
|
|
190
|
+
epsilon_low (float): Lower bound for the importance sampling ratio.
|
|
191
|
+
epsilon_high (float): Upper bound for the importance sampling ratio.
|
|
192
|
+
temperature (float): Temperature for the logits.
|
|
161
193
|
"""
|
|
162
194
|
super().__init__()
|
|
163
195
|
self.beta = beta
|
|
164
196
|
self.compiled = compiled
|
|
165
197
|
self.use_ref_model = use_ref_model
|
|
166
|
-
self.num_generations = num_generations
|
|
167
198
|
self.chunk_size = chunk_size
|
|
199
|
+
self.epsilon_low = epsilon_low
|
|
200
|
+
self.epsilon_high = epsilon_high
|
|
201
|
+
self.temperature = temperature
|
|
168
202
|
|
|
169
203
|
def forward(
|
|
170
204
|
self,
|
|
171
205
|
_input,
|
|
172
206
|
lin_weight,
|
|
207
|
+
selected_token_ids,
|
|
173
208
|
attention_mask,
|
|
174
|
-
|
|
209
|
+
advantages,
|
|
175
210
|
bias=None,
|
|
211
|
+
ref_per_token_logps=None,
|
|
212
|
+
old_per_token_logps=None,
|
|
176
213
|
ref_input=None,
|
|
177
214
|
ref_weight=None,
|
|
178
215
|
ref_bias=None,
|
|
@@ -180,15 +217,20 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
|
|
|
180
217
|
return LigerFusedLinearGRPOFunction.apply(
|
|
181
218
|
_input,
|
|
182
219
|
lin_weight,
|
|
220
|
+
selected_token_ids,
|
|
183
221
|
attention_mask,
|
|
184
|
-
|
|
222
|
+
advantages,
|
|
185
223
|
bias,
|
|
224
|
+
ref_per_token_logps,
|
|
225
|
+
old_per_token_logps,
|
|
186
226
|
ref_input,
|
|
187
227
|
ref_weight,
|
|
188
228
|
ref_bias,
|
|
189
229
|
self.beta,
|
|
230
|
+
self.epsilon_low,
|
|
231
|
+
self.epsilon_high,
|
|
232
|
+
self.temperature,
|
|
190
233
|
self.compiled,
|
|
191
234
|
self.use_ref_model,
|
|
192
|
-
self.num_generations,
|
|
193
235
|
self.chunk_size,
|
|
194
236
|
)
|
|
@@ -8,7 +8,6 @@ import torch
|
|
|
8
8
|
from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
|
|
9
9
|
from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
|
|
10
10
|
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
|
|
11
|
-
from transformers.models.llava.modeling_llava import logger
|
|
12
11
|
from transformers.utils import add_start_docstrings_to_model_forward
|
|
13
12
|
from transformers.utils import is_torchdynamo_compiling
|
|
14
13
|
from transformers.utils import replace_return_docstrings
|
|
@@ -34,8 +33,6 @@ def lce_forward_deprecated(
|
|
|
34
33
|
output_attentions: Optional[bool] = None,
|
|
35
34
|
output_hidden_states: Optional[bool] = None,
|
|
36
35
|
return_dict: Optional[bool] = None,
|
|
37
|
-
cache_position: Optional[torch.LongTensor] = None,
|
|
38
|
-
num_logits_to_keep: int = 0,
|
|
39
36
|
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
|
|
40
37
|
r"""
|
|
41
38
|
Args:
|
|
@@ -96,39 +93,32 @@ def lce_forward_deprecated(
|
|
|
96
93
|
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
|
97
94
|
)
|
|
98
95
|
|
|
99
|
-
legacy_processing = False
|
|
100
96
|
if inputs_embeds is None:
|
|
97
|
+
# 1. Extra the input embeddings
|
|
101
98
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
102
99
|
|
|
103
|
-
#
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
if legacy_processing and image_features is not None:
|
|
119
|
-
logger.warning_once(
|
|
120
|
-
"Expanding inputs for image tokens in LLaVa should be done in processing. "
|
|
121
|
-
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
|
122
|
-
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
|
123
|
-
"Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
|
|
124
|
-
)
|
|
125
|
-
# prefill stage vs decoding stage (legacy behavior copied)
|
|
126
|
-
if input_ids.shape[1] != 1:
|
|
100
|
+
# 2. Merge text and images
|
|
101
|
+
if pixel_values is not None and input_ids.shape[1] != 1:
|
|
102
|
+
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
|
103
|
+
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
|
|
104
|
+
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
|
105
|
+
|
|
106
|
+
if vision_feature_select_strategy == "default":
|
|
107
|
+
selected_image_feature = selected_image_feature[:, 1:]
|
|
108
|
+
elif vision_feature_select_strategy == "full":
|
|
109
|
+
selected_image_feature = selected_image_feature
|
|
110
|
+
else:
|
|
111
|
+
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
|
112
|
+
|
|
113
|
+
image_features = self.multi_modal_projector(selected_image_feature)
|
|
114
|
+
inputs_embeds = inputs_embeds.to(image_features.dtype)
|
|
127
115
|
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
|
128
116
|
image_features, inputs_embeds, input_ids, attention_mask, labels
|
|
129
117
|
)
|
|
130
|
-
|
|
131
|
-
|
|
118
|
+
|
|
119
|
+
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
|
120
|
+
# generation with cache
|
|
121
|
+
elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
|
132
122
|
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
|
133
123
|
# that are set to 0
|
|
134
124
|
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
|
@@ -158,7 +148,6 @@ def lce_forward_deprecated(
|
|
|
158
148
|
|
|
159
149
|
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
|
160
150
|
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
|
161
|
-
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
|
|
162
151
|
|
|
163
152
|
# TODO: @raushan retain only the new behavior after v4.47
|
|
164
153
|
elif image_features is not None:
|
|
@@ -184,8 +173,6 @@ def lce_forward_deprecated(
|
|
|
184
173
|
output_attentions=output_attentions,
|
|
185
174
|
output_hidden_states=output_hidden_states,
|
|
186
175
|
return_dict=return_dict,
|
|
187
|
-
cache_position=cache_position,
|
|
188
|
-
num_logits_to_keep=num_logits_to_keep,
|
|
189
176
|
)
|
|
190
177
|
hidden_states = outputs[0]
|
|
191
178
|
|
|
@@ -220,7 +207,6 @@ def lce_forward_deprecated(
|
|
|
220
207
|
past_key_values=outputs.past_key_values,
|
|
221
208
|
hidden_states=outputs.hidden_states,
|
|
222
209
|
attentions=outputs.attentions,
|
|
223
|
-
image_hidden_states=image_features if pixel_values is not None else None,
|
|
224
210
|
)
|
|
225
211
|
|
|
226
212
|
|
|
@@ -5,12 +5,12 @@ liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EB
|
|
|
5
5
|
liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
|
|
6
6
|
liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
|
|
7
7
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
|
|
8
|
-
liger_kernel/chunked_loss/functional.py,sha256=
|
|
8
|
+
liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
|
|
9
9
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
|
|
10
|
+
liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=2_UvvIksUP45RBw3c-88-jOtjGATf04vaWopcqtX4Oo,12688
|
|
10
11
|
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
|
|
11
|
-
liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=wGujqwLz91mOE9MmdenhBIKvbmswhwtINMCpcP7D74c,9050
|
|
12
12
|
liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
|
|
13
|
-
liger_kernel/chunked_loss/grpo_loss.py,sha256=
|
|
13
|
+
liger_kernel/chunked_loss/grpo_loss.py,sha256=6Mb4ZT6MfnOr4Xo681rMR0LKkhzJhInvQp8wp2YVMK0,8913
|
|
14
14
|
liger_kernel/chunked_loss/jsd_loss.py,sha256=u2ahkuHsbhpNaKcpBCz5gCMDk9ou-P04DHji592dIBo,7067
|
|
15
15
|
liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
|
|
16
16
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
|
|
@@ -57,7 +57,7 @@ liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm
|
|
|
57
57
|
liger_kernel/transformers/model/gemma.py,sha256=7cBTljzh-8_ACBhYl6NUfj5_ux92YRlmnAU5gfDAQAI,9312
|
|
58
58
|
liger_kernel/transformers/model/gemma2.py,sha256=X0FOIhvFlTrmWI7Ws06wUkutgHW3lWtLOnnHp1NgZ3A,10403
|
|
59
59
|
liger_kernel/transformers/model/llama.py,sha256=d9rBaK8e8RSMCFHdgom9ZHuXOlnh6U_o-GkAFGRNGOY,9989
|
|
60
|
-
liger_kernel/transformers/model/llava.py,sha256=
|
|
60
|
+
liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
|
|
61
61
|
liger_kernel/transformers/model/loss_utils.py,sha256=Z-fUrf-cUDUjUIH7Tl9OL2hT8nmtx7ES3kg8syuWKy4,1476
|
|
62
62
|
liger_kernel/transformers/model/mistral.py,sha256=o7tyl1sPWPfZwwrBLRlryHlSI8I55viuJoMI5Bh5Nww,5014
|
|
63
63
|
liger_kernel/transformers/model/mixtral.py,sha256=T0ITv2-PkR8VErVOVUizoS4EzjmARyR7GFh0tXDB_i4,11089
|
|
@@ -72,9 +72,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
|
72
72
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
|
|
73
73
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
|
74
74
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
|
75
|
-
liger_kernel_nightly-0.5.5.
|
|
76
|
-
liger_kernel_nightly-0.5.5.
|
|
77
|
-
liger_kernel_nightly-0.5.5.
|
|
78
|
-
liger_kernel_nightly-0.5.5.
|
|
79
|
-
liger_kernel_nightly-0.5.5.
|
|
80
|
-
liger_kernel_nightly-0.5.5.
|
|
75
|
+
liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
76
|
+
liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/METADATA,sha256=DLGGPCgn1-dKSQVP5sYIzzRoh7c9wBUjM7JFujYn1KI,22959
|
|
77
|
+
liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
78
|
+
liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
79
|
+
liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
80
|
+
liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/RECORD,,
|
|
@@ -1,240 +0,0 @@
|
|
|
1
|
-
from abc import abstractmethod
|
|
2
|
-
from functools import partial
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
import torch.nn.functional as F
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class LigerFusedLinearRLHFBase(torch.autograd.Function):
|
|
9
|
-
@abstractmethod
|
|
10
|
-
def rlhf_loss_fn(*args, **kwargs):
|
|
11
|
-
"""
|
|
12
|
-
To be extended by subclasses.
|
|
13
|
-
"""
|
|
14
|
-
raise NotImplementedError("RLHF loss function must be implemented.")
|
|
15
|
-
|
|
16
|
-
@staticmethod
|
|
17
|
-
def forward(
|
|
18
|
-
cls,
|
|
19
|
-
ctx,
|
|
20
|
-
_input,
|
|
21
|
-
weight,
|
|
22
|
-
attention_mask,
|
|
23
|
-
rewards,
|
|
24
|
-
bias=None,
|
|
25
|
-
num_generations=4,
|
|
26
|
-
beta=0.1,
|
|
27
|
-
compiled=True,
|
|
28
|
-
use_ref_model=False,
|
|
29
|
-
ref_input=None,
|
|
30
|
-
ref_weight=None,
|
|
31
|
-
ref_bias=None,
|
|
32
|
-
chunk_size=1,
|
|
33
|
-
):
|
|
34
|
-
"""Chunked forward pass for RLHF loss computation.
|
|
35
|
-
|
|
36
|
-
Args:
|
|
37
|
-
cls: The class
|
|
38
|
-
ctx: Context for backward
|
|
39
|
-
_input: Input tensor
|
|
40
|
-
weight: Weight tensor
|
|
41
|
-
attention_mask: Attention mask tensor
|
|
42
|
-
rewards: Rewards tensor
|
|
43
|
-
bias: Bias tensor
|
|
44
|
-
num_generations: Number of generations per prompt
|
|
45
|
-
beta: Weight for the KL penalty
|
|
46
|
-
compiled: Whether to use torch compile
|
|
47
|
-
use_ref_model: Whether to use a reference model
|
|
48
|
-
ref_input: Reference model input tensor
|
|
49
|
-
ref_weight: Reference model weight tensor
|
|
50
|
-
ref_bias: Reference model bias tensor
|
|
51
|
-
chunk_size: Size of chunks for processing in other loss modules
|
|
52
|
-
"""
|
|
53
|
-
# Save for backward
|
|
54
|
-
ctx.beta = beta
|
|
55
|
-
ctx.rewards = rewards
|
|
56
|
-
|
|
57
|
-
# Initialize accumulators
|
|
58
|
-
loss_acc = torch.zeros((), device=_input.device)
|
|
59
|
-
grad_weight = torch.zeros_like(weight) # [V, H]
|
|
60
|
-
grad_inputs = []
|
|
61
|
-
grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
|
|
62
|
-
aggregated_metrics = []
|
|
63
|
-
|
|
64
|
-
# Create a partial function with fixed arguments
|
|
65
|
-
compute_loss = partial(
|
|
66
|
-
LigerFusedLinearRLHFBase._compute_chunk_loss,
|
|
67
|
-
beta=beta,
|
|
68
|
-
use_ref_model=use_ref_model,
|
|
69
|
-
ref_weight=ref_weight,
|
|
70
|
-
ref_bias=ref_bias,
|
|
71
|
-
rlhf_loss_fn=cls.rlhf_loss_fn,
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
|
|
75
|
-
"""Fused forward and backward for a chunk."""
|
|
76
|
-
if bias is not None:
|
|
77
|
-
return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
|
|
78
|
-
input_chunk, # arg 0
|
|
79
|
-
weight, # arg 1
|
|
80
|
-
attention_mask_chunk, # arg 2
|
|
81
|
-
rewards_chunk, # arg 3
|
|
82
|
-
ref_input_chunk, # arg 4
|
|
83
|
-
bias, # arg 5
|
|
84
|
-
)
|
|
85
|
-
else:
|
|
86
|
-
return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
|
|
87
|
-
input_chunk, # arg 0
|
|
88
|
-
weight, # arg 1
|
|
89
|
-
attention_mask_chunk, # arg 2
|
|
90
|
-
rewards_chunk, # arg 3
|
|
91
|
-
ref_input_chunk, # arg 4
|
|
92
|
-
)
|
|
93
|
-
|
|
94
|
-
def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
|
|
95
|
-
if bias is not None:
|
|
96
|
-
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
|
|
97
|
-
input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
|
|
98
|
-
)
|
|
99
|
-
grad_bias.add_(chunk_grad_bias)
|
|
100
|
-
else:
|
|
101
|
-
(chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
|
|
102
|
-
input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
|
|
103
|
-
)
|
|
104
|
-
|
|
105
|
-
# Accumulate gradients and loss
|
|
106
|
-
grad_weight.add_(chunk_grad_weight)
|
|
107
|
-
grad_inputs.append(chunk_grad_input)
|
|
108
|
-
loss_acc.add_(chunk_loss)
|
|
109
|
-
|
|
110
|
-
# Initialize storage for metrics on first chunk
|
|
111
|
-
if len(aggregated_metrics) == 0:
|
|
112
|
-
for metric in chunk_metrics:
|
|
113
|
-
if metric.ndim == 0:
|
|
114
|
-
aggregated_metrics.append(torch.zeros((), device=metric.device))
|
|
115
|
-
else:
|
|
116
|
-
aggregated_metrics.append([])
|
|
117
|
-
|
|
118
|
-
# Accumulate metrics
|
|
119
|
-
for i, metric in enumerate(chunk_metrics):
|
|
120
|
-
if metric.ndim == 0:
|
|
121
|
-
aggregated_metrics[i].add_(metric)
|
|
122
|
-
else:
|
|
123
|
-
aggregated_metrics[i].append(metric)
|
|
124
|
-
|
|
125
|
-
if compiled:
|
|
126
|
-
accumulate_chunk = torch.compile(accumulate_chunk)
|
|
127
|
-
|
|
128
|
-
# Process input in chunks based on num_generations
|
|
129
|
-
chunks = max(1, _input.shape[0] // num_generations)
|
|
130
|
-
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
|
|
131
|
-
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
|
|
132
|
-
_rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
|
|
133
|
-
_ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
|
|
134
|
-
|
|
135
|
-
for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
|
|
136
|
-
_input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
|
|
137
|
-
):
|
|
138
|
-
# Mark dynamic dimensions
|
|
139
|
-
torch._dynamo.mark_dynamic(input_chunk, 1)
|
|
140
|
-
torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
|
|
141
|
-
if ref_input_chunk is not None:
|
|
142
|
-
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
|
|
143
|
-
|
|
144
|
-
accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)
|
|
145
|
-
|
|
146
|
-
# Scale accumulated loss by number of chunks since we're averaging
|
|
147
|
-
loss_acc = loss_acc / chunks
|
|
148
|
-
|
|
149
|
-
# Combine gradients
|
|
150
|
-
grad_input = torch.cat(grad_inputs, dim=0)
|
|
151
|
-
|
|
152
|
-
# Save for backward
|
|
153
|
-
ctx.save_for_backward(grad_input, grad_weight, grad_bias)
|
|
154
|
-
|
|
155
|
-
# Finalize metrics
|
|
156
|
-
final_metrics = []
|
|
157
|
-
for metric in aggregated_metrics:
|
|
158
|
-
if isinstance(metric, list):
|
|
159
|
-
final_metrics.append(torch.cat(metric, dim=0))
|
|
160
|
-
else:
|
|
161
|
-
final_metrics.append(metric / chunks)
|
|
162
|
-
|
|
163
|
-
return loss_acc, tuple(final_metrics)
|
|
164
|
-
|
|
165
|
-
@staticmethod
|
|
166
|
-
def _compute_chunk_loss(
|
|
167
|
-
input_chunk,
|
|
168
|
-
weight,
|
|
169
|
-
attention_mask_chunk,
|
|
170
|
-
rewards_chunk,
|
|
171
|
-
ref_input_chunk=None,
|
|
172
|
-
bias=None,
|
|
173
|
-
beta=0.1,
|
|
174
|
-
use_ref_model=False,
|
|
175
|
-
ref_weight=None,
|
|
176
|
-
ref_bias=None,
|
|
177
|
-
rlhf_loss_fn=None,
|
|
178
|
-
):
|
|
179
|
-
"""Compute loss for a single chunk."""
|
|
180
|
-
# Get policy log probabilities using chunk_forward
|
|
181
|
-
log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)
|
|
182
|
-
|
|
183
|
-
# Get reference log probabilities if needed
|
|
184
|
-
ref_log_probs = None
|
|
185
|
-
if use_ref_model and ref_input_chunk is not None:
|
|
186
|
-
with torch.no_grad():
|
|
187
|
-
ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)
|
|
188
|
-
|
|
189
|
-
# Compute chunk loss and metrics using the provided loss function
|
|
190
|
-
chunk_loss, chunk_metrics = rlhf_loss_fn(
|
|
191
|
-
log_probs=log_probs,
|
|
192
|
-
attention_mask=attention_mask_chunk,
|
|
193
|
-
rewards=rewards_chunk,
|
|
194
|
-
ref_log_probs=ref_log_probs,
|
|
195
|
-
beta=beta,
|
|
196
|
-
)
|
|
197
|
-
|
|
198
|
-
return chunk_loss, (logits_mean, *chunk_metrics)
|
|
199
|
-
|
|
200
|
-
@staticmethod
|
|
201
|
-
def chunk_forward(input_chunk, weight, bias=None):
|
|
202
|
-
"""Forward pass computation for a single chunk without explicit reshaping."""
|
|
203
|
-
# Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
|
|
204
|
-
logits = torch.matmul(input_chunk, weight.t())
|
|
205
|
-
if bias is not None:
|
|
206
|
-
logits = logits + bias # Broadcasts bias to [B, T, V]
|
|
207
|
-
|
|
208
|
-
# Compute log probabilities using softmax over the last dimension
|
|
209
|
-
log_probs = F.log_softmax(logits.float(), dim=-1)
|
|
210
|
-
|
|
211
|
-
# Monitoring: compute mean of logits
|
|
212
|
-
batch_size, seq_len, _ = input_chunk.shape
|
|
213
|
-
logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
|
|
214
|
-
return log_probs, logits, logits_mean
|
|
215
|
-
|
|
216
|
-
@staticmethod
|
|
217
|
-
def backward(ctx, grad_output, *grad_metrics):
|
|
218
|
-
"""Backward pass for RLHF loss."""
|
|
219
|
-
grad_input, grad_weight, grad_bias = ctx.saved_tensors
|
|
220
|
-
if grad_output != 1.0:
|
|
221
|
-
grad_input = grad_input * grad_output
|
|
222
|
-
grad_weight = grad_weight * grad_output
|
|
223
|
-
if grad_bias is not None:
|
|
224
|
-
grad_bias = grad_bias * grad_output
|
|
225
|
-
|
|
226
|
-
return (
|
|
227
|
-
grad_input,
|
|
228
|
-
grad_weight,
|
|
229
|
-
None, # grad_attention_mask
|
|
230
|
-
None, # grad_rewards
|
|
231
|
-
grad_bias,
|
|
232
|
-
None, # grad_num_generations
|
|
233
|
-
None, # grad_beta
|
|
234
|
-
None, # grad_compiled
|
|
235
|
-
None, # grad_use_ref_model
|
|
236
|
-
None, # grad_ref_input
|
|
237
|
-
None, # grad_ref_weight
|
|
238
|
-
None, # grad_ref_bias
|
|
239
|
-
None, # grad_chunk_size
|
|
240
|
-
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|