liger-kernel-nightly 0.5.5.dev20250331170510__py3-none-any.whl → 0.5.5.dev20250402185606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +330 -0
- liger_kernel/chunked_loss/grpo_loss.py +103 -61
- liger_kernel/ops/cross_entropy.py +3 -2
- {liger_kernel_nightly-0.5.5.dev20250331170510.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.5.dev20250331170510.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/RECORD +10 -10
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
- {liger_kernel_nightly-0.5.5.dev20250331170510.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250331170510.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/functional.py

@@ -1,5 +1,6 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
 from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
+liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
liger_kernel/chunked_loss/fused_linear_ppo.py (new file)

@@ -0,0 +1,330 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+import torch._dynamo.config
+import torch.nn.functional as F
+
+
+class LigerFusedLinearPPOBase(torch.autograd.Function):
+    @abstractmethod
+    def ppo_loss_fn(*args, **kwargs):
+        """
+        To be extended by subclasses.
+        """
+        raise NotImplementedError("PPO loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        selected_token_ids,
+        attention_mask,
+        advantages,
+        bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
+        temperature=1.0,
+        compiled=True,
+        use_ref_model=False,
+        chunk_size=1,
+    ):
+        """Chunked forward pass for PPO loss computation.
+
+        Args:
+            cls: The class
+            ctx: Context for backward
+            _input: Input tensor
+            weight: Weight tensor
+            selected_token_ids: Selected token ids tensor
+            attention_mask: Attention mask tensor
+            advantages: Advantages tensor
+            bias: Bias tensor
+            ref_per_token_logps: Reference model log probs per token tensor
+            old_per_token_logps: Old per token log probabilities tensor
+            ref_input: Reference model input tensor
+            ref_weight: Reference model weight tensor
+            ref_bias: Reference model bias tensor
+            epsilon_low: Lower bound for clipping the importance sampling ratio
+            epsilon_high: Upper bound for clipping the importance sampling ratio
+            beta: Weight for the KL penalty
+            temperature: Temperature for the logits
+            compiled: Whether to use torch compile
+            use_ref_model: Whether to use a reference model
+            chunk_size: Size of chunks for processing in other loss modules
+        """
+        if use_ref_model:
+            assert ref_per_token_logps is not None or ref_input is not None, (
+                "If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
+            )
+            if ref_per_token_logps is not None and ref_input is not None:
+                raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+        # Initialize accumulators
+        loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
+        grad_weight = torch.zeros_like(weight)  # [V, H]
+        grad_inputs = []
+        grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
+        aggregated_metrics = []
+
+        # Create a partial function with fixed arguments
+        compute_loss = partial(
+            LigerFusedLinearPPOBase._compute_chunk_loss,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            full_attention_mask=attention_mask,
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            beta=beta,
+            temperature=temperature,
+            use_ref_model=use_ref_model,
+            ppo_loss_fn=cls.ppo_loss_fn,
+        )
+
+        def fused_fwd_bwd(
+            input_chunk,
+            selected_token_ids_chunk,
+            attention_mask_chunk,
+            advantages_chunk,
+            ref_per_token_logps_chunk,
+            old_per_token_logps_chunk,
+            ref_input_chunk,
+        ):
+            """Fused forward and backward for a chunk."""
+            argnums = (0, 1, 5) if bias is not None else (0, 1)
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
+                input_chunk,  # arg 0
+                weight,  # arg 1
+                selected_token_ids_chunk,  # arg 2
+                attention_mask_chunk,  # arg 3
+                advantages_chunk,  # arg 4
+                bias,  # arg 5
+                ref_per_token_logps_chunk=ref_per_token_logps_chunk,  # arg 6
+                old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 7
+                ref_input_chunk=ref_input_chunk,  # arg 8
+            )
+
+        def accumulate_chunk(
+            input_chunk,
+            selected_token_ids_chunk,
+            attention_mask_chunk,
+            advantages_chunk,
+            ref_per_token_logps_chunk=None,
+            old_per_token_logps_chunk=None,
+            ref_input_chunk=None,
+        ):
+            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
+                input_chunk,
+                selected_token_ids_chunk,
+                attention_mask_chunk,
+                advantages_chunk,
+                ref_per_token_logps_chunk,
+                old_per_token_logps_chunk,
+                ref_input_chunk,
+            )
+            if bias is not None:
+                grad_bias.add_(chunk_grad_bias[0])
+
+            # Accumulate gradients and loss
+            grad_weight.add_(chunk_grad_weight)
+            grad_inputs.append(chunk_grad_input)
+            loss_acc.add_(chunk_loss)
+            # Initialize storage for metrics on first chunk
+            if len(aggregated_metrics) == 0:
+                for metric in chunk_metrics:
+                    if metric.ndim == 0:
+                        aggregated_metrics.append(torch.zeros((), device=metric.device))
+                    else:
+                        aggregated_metrics.append([])
+
+            # Accumulate metrics
+            for i, metric in enumerate(chunk_metrics):
+                if metric.ndim == 0:
+                    aggregated_metrics[i].add_(metric)
+                else:
+                    aggregated_metrics[i].append(metric)
+
+        if compiled:
+            # TODO: Figure out what is better to compile here
+            # accumulate_chunk = torch.compile(accumulate_chunk)
+            fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+        # Process input in chunks based on chunk_size
+        chunks = max(1, _input.shape[0] // chunk_size)
+        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+        _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
+        _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
+        _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
+        _ref_per_token_logps_chunks = (
+            torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
+            if use_ref_model and ref_per_token_logps is not None
+            else [None] * chunks
+        )
+        _old_per_token_logps_chunks = (
+            torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
+            if old_per_token_logps is not None
+            else [None] * chunks
+        )
+        # if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
+        _ref_input_chunks = (
+            torch.chunk(ref_input, chunks=chunks, dim=0)
+            if use_ref_model and ref_per_token_logps is None
+            else [None] * chunks
+        )
+
+        for (
+            input_chunk,
+            selected_token_ids_chunk,
+            attention_mask_chunk,
+            advantages_chunk,
+            ref_per_token_logps_chunk,
+            old_per_token_logps_chunk,
+            ref_input_chunk,
+        ) in zip(
+            _input_chunks,
+            _selected_token_ids_chunks,
+            _attention_mask_chunks,
+            _advantages_chunks,
+            _ref_per_token_logps_chunks,
+            _old_per_token_logps_chunks,
+            _ref_input_chunks,
+        ):
+            # Mark dynamic dimensions
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
+            torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
+            if ref_per_token_logps_chunk is not None:
+                torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
+            if ref_input_chunk is not None:
+                torch._dynamo.mark_dynamic(ref_input_chunk, 1)
+            if old_per_token_logps_chunk is not None:
+                torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
+
+            accumulate_chunk(
+                input_chunk,
+                selected_token_ids_chunk,
+                attention_mask_chunk,
+                advantages_chunk,
+                ref_per_token_logps_chunk,
+                old_per_token_logps_chunk,
+                ref_input_chunk,
+            )
+
+        # Combine gradients
+        grad_input = torch.cat(grad_inputs, dim=0)
+
+        # Save for backward
+        ctx.save_for_backward(grad_input, grad_weight, grad_bias)
+
+        # Finalize metrics
+        final_metrics = []
+        for metric in aggregated_metrics:
+            if isinstance(metric, list):
+                final_metrics.append(torch.cat(metric, dim=0))
+            else:
+                final_metrics.append(metric)
+
+        return loss_acc, tuple(final_metrics)
+
+    @staticmethod
+    def _compute_chunk_loss(
+        input_chunk,
+        weight,
+        selected_token_ids_chunk,
+        attention_mask_chunk,
+        advantages_chunk,
+        bias=None,
+        ref_per_token_logps_chunk=None,
+        old_per_token_logps_chunk=None,
+        ref_input_chunk=None,
+        ref_weight=None,
+        ref_bias=None,
+        full_attention_mask=None,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
+        temperature=1.0,
+        use_ref_model=False,
+        ppo_loss_fn=None,
+    ):
+        """Compute loss for a single chunk."""
+        # Get policy log probabilities using chunk_forward
+        log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
+
+        # Get reference log probabilities if needed
+        ref_log_probs = None
+        if use_ref_model and ref_per_token_logps_chunk is None:
+            with torch.no_grad():
+                ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
+                    ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
+                )
+
+        # Compute chunk loss and metrics using the provided loss function
+        chunk_loss, chunk_metrics = ppo_loss_fn(
+            log_probs=log_probs,
+            selected_token_ids=selected_token_ids_chunk,
+            attention_mask=attention_mask_chunk,
+            advantages=advantages_chunk,
+            full_attention_mask=full_attention_mask,
+            ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
+            old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
+            ref_log_probs=ref_log_probs,  # used when ref_per_token_logps is None
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            beta=beta,
+        )
+
+        return chunk_loss, chunk_metrics
+
+    @staticmethod
+    def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
+        """Forward pass computation for a single chunk without explicit reshaping."""
+        # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
+        logits = torch.matmul(input_chunk, weight.t())
+        if bias is not None:
+            logits = logits + bias  # Broadcasts bias to [B, T, V]
+        if temperature != 1.0:
+            logits = logits / temperature
+
+        # Compute log probabilities using softmax over the last dimension
+        log_probs = F.log_softmax(logits.float(), dim=-1)
+
+        return log_probs, logits
+
+    @staticmethod
+    def backward(ctx, grad_output, *grad_metrics):
+        """Backward pass for PPO loss."""
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if grad_output != 1.0:
+            grad_input = grad_input * grad_output
+            grad_weight = grad_weight * grad_output
+            if grad_bias is not None:
+                grad_bias = grad_bias * grad_output
+
+        return (
+            grad_input,
+            grad_weight,
+            None,  # grad_selected_token_ids
+            None,  # grad_attention_mask
+            None,  # grad_advantages
+            grad_bias,
+            None,  # grad_ref_per_token_logps
+            None,  # grad_old_per_token_logps
+            None,  # grad_ref_input
+            None,  # grad_ref_weight
+            None,  # grad_ref_bias
+            None,  # grad_epsilon_low
+            None,  # grad_epsilon_high
+            None,  # grad_beta
+            None,  # grad_temperature
+            None,  # grad_compiled
+            None,  # grad_use_ref_model
+            None,  # grad_chunk_size
+        )
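The pattern worth noting in the new `fused_linear_ppo.py` above is that the full `[batch, seq_len, vocab]` logits tensor is never materialized at once: each chunk's loss and its gradients are computed together with `torch.func.grad_and_value`, then accumulated into running buffers. The sketch below reproduces only that pattern with a toy stand-in loss; the shapes, names, and the loss itself are illustrative assumptions, not the library's API.

```python
# Minimal sketch of the chunked grad-accumulation pattern used by
# LigerFusedLinearPPOBase.forward (toy loss and shapes; not the library code).
import torch
import torch.nn.functional as F


def chunk_loss(input_chunk, weight, token_ids_chunk):
    # [b, T, H] @ [H, V] -> [b, T, V]; gather the log-prob of each sampled token.
    log_probs = F.log_softmax(torch.matmul(input_chunk, weight.t()).float(), dim=-1)
    per_token_logps = log_probs.gather(-1, token_ids_chunk.unsqueeze(-1)).squeeze(-1)
    return -per_token_logps.mean()  # stand-in for the real PPO objective


B, T, H, V, chunk_size = 4, 8, 16, 32, 1
x = torch.randn(B, T, H)
w = torch.randn(V, H)
ids = torch.randint(0, V, (B, T))

grad_w = torch.zeros_like(w)
grad_x_chunks = []
loss_acc = torch.zeros(())
for x_c, ids_c in zip(x.split(chunk_size), ids.split(chunk_size)):
    # grad_and_value returns ((dloss/dx_c, dloss/dw), loss) for this chunk only,
    # so peak memory is bounded by one chunk's logits rather than the full batch's.
    (gx, gw), chunk_val = torch.func.grad_and_value(chunk_loss, argnums=(0, 1))(x_c, w, ids_c)
    grad_x_chunks.append(gx)
    grad_w.add_(gw)
    loss_acc.add_(chunk_val)

grad_x = torch.cat(grad_x_chunks, dim=0)  # same shape as x
```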
liger_kernel/chunked_loss/grpo_loss.py

@@ -1,66 +1,76 @@
 import torch
 
-from liger_kernel.chunked_loss.
+from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
 
 
-class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
+def k3_loss_fn(log_p, log_q):
+    # computes k3 estimate of KL[q, p]
+    # ref: http://joschu.net/blog/kl-approx.html
+    return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0
+
+
+def clip_coef_fn(coef, epsilon_low, epsilon_high):
+    return torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
+
+
+class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
     @staticmethod
-    def 
+    def ppo_loss_fn(
         log_probs,
+        selected_token_ids,
         attention_mask,
-
-
-
+        advantages,
+        full_attention_mask,
+        ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
+        old_per_token_logps=None,
+        ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
-
-        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
-        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
+        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
             -1
         )  # (batch_size, seq_len)
 
         # Get reference model probabilities
-        if 
-
-
-
-
-
-
-
-        std_grouped_rewards = rewards.std()  # [batch_size,]
-
-        # Calculate advantages using the same epsilon as in GRPOTrainer
-        eps = 1e-4
-        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
+        if ref_per_token_logps is None:
+            if ref_log_probs is not None:
+                with torch.no_grad():
+                    ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
+                        -1
+                    )
+            else:
+                ref_per_token_logps = per_token_logps.detach()
 
         # Compute policy gradient loss with importance sampling ratio
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        if beta != 0.0:
+            # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
+            kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
+            # Combine losses
+            per_token_loss = per_token_loss + beta * kl_div
+
+        # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
+        # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
+        # and TRL GRPO implementation
+        # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
+        loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
 
         # Calculate metrics
-        metrics = 
-
-
-
-        (
+        metrics = []
+        if beta != 0.0:
+            metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
+        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
         )
-
+        metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics
 
     @classmethod
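To make the added `ppo_loss_fn` concrete, here is a toy, hand-written evaluation of the same per-token objective (clipped importance ratio plus the k3 KL estimate). The tensor values and shapes are made up for illustration and this is not library code.

```python
# Toy illustration of the per-token GRPO objective added above (values are made up).
import torch

per_token_logps = torch.tensor([[-1.2, -0.7]])      # current policy, [B, T]
old_per_token_logps = torch.tensor([[-1.0, -0.9]])  # policy that generated the tokens
ref_per_token_logps = torch.tensor([[-1.1, -0.8]])  # frozen reference model
advantages = torch.tensor([0.5])                    # one advantage per sequence, [B]
eps_low, eps_high, beta = 0.2, 0.2, 0.04

coef_1 = torch.exp(per_token_logps - old_per_token_logps)  # importance sampling ratio
coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)    # clipped ratio
per_token_loss = -torch.min(coef_1 * advantages.unsqueeze(1),
                            coef_2 * advantages.unsqueeze(1))

# k3 estimator of KL(policy || ref): exp(dlog) - dlog - 1, with dlog = ref - policy
dlog = ref_per_token_logps - per_token_logps
per_token_loss = per_token_loss + beta * (torch.exp(dlog) - dlog - 1.0)

# The kernel then masks padded tokens and divides by the total token count
# (full_attention_mask.sum()) to obtain the scalar loss.
print(per_token_loss)  # shape [B, T]
```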
@@ -69,16 +79,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
         ctx,
         _input,
         weight,
+        selected_token_ids,
         attention_mask,
-
+        advantages,
         bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
-        beta=0.
+        beta=0.04,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        temperature=1.0,
         compiled=True,
         use_ref_model=True,
-        num_generations=1,
         chunk_size=1,
     ):
         """
@@ -86,16 +101,18 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
         Args:
             _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
             attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
-
+            advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
             ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
+            temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
-            num_generations (int): Number of generations per prompt
             chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
@@ -105,16 +122,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
            ctx=ctx,
            _input=_input,
            weight=weight,
+           selected_token_ids=selected_token_ids,
            attention_mask=attention_mask,
-
+           advantages=advantages,
            bias=bias,
+           ref_per_token_logps=ref_per_token_logps,
+           old_per_token_logps=old_per_token_logps,
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            beta=beta,
+           epsilon_low=epsilon_low,
+           epsilon_high=epsilon_high,
+           temperature=temperature,
            compiled=compiled,
            use_ref_model=use_ref_model,
-           num_generations=num_generations,
            chunk_size=chunk_size,
        )
 
@@ -126,16 +148,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
             grad_output: Gradient of the loss (scalar)
             grad_metrics: Gradients of the metrics (not used in backward computation)
         """
-        grads = 
+        grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
         return (
-            *grads[
+            *grads[
+                :6
+            ],  # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
+            None,  # grad_ref_per_token_logps
+            None,  # grad_old_per_token_logps
            None,  # grad_ref_input
            None,  # grad_ref_weight
            None,  # grad_ref_bias
            None,  # grad_beta
+           None,  # grad_epsilon_low
+           None,  # grad_epsilon_high
+           None,  # grad_temperature
            None,  # grad_compiled
            None,  # grad_use_ref_model
-           None,  # grad_num_generations
            None,  # grad_chunk_size
        )
 
@@ -145,34 +173,43 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 
     def __init__(
         self,
-        beta: float = 0.
+        beta: float = 0.04,
         compiled: bool = True,
         use_ref_model: bool = True,
-        num_generations: int = 1,
         chunk_size: int = 1,
+        epsilon_low: float = 0.2,
+        epsilon_high: float = 0.2,
+        temperature: float = 1.0,
     ):
         """
         Args:
             beta (float): Weight for the KL penalty.
             compiled (bool): Whether to use torch compile.
             use_ref_model (bool): Whether to use a reference model.
-            num_generations (int): Number of generations per prompt.
             chunk_size (int): Size of chunks for processing.
+            epsilon_low (float): Lower bound for the importance sampling ratio.
+            epsilon_high (float): Upper bound for the importance sampling ratio.
+            temperature (float): Temperature for the logits.
         """
         super().__init__()
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
-        self.num_generations = num_generations
         self.chunk_size = chunk_size
+        self.epsilon_low = epsilon_low
+        self.epsilon_high = epsilon_high
+        self.temperature = temperature
 
     def forward(
         self,
        _input,
        lin_weight,
+       selected_token_ids,
        attention_mask,
-
+       advantages,
        bias=None,
+       ref_per_token_logps=None,
+       old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
@@ -180,15 +217,20 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         return LigerFusedLinearGRPOFunction.apply(
            _input,
            lin_weight,
+           selected_token_ids,
            attention_mask,
-
+           advantages,
            bias,
+           ref_per_token_logps,
+           old_per_token_logps,
            ref_input,
            ref_weight,
            ref_bias,
            self.beta,
+           self.epsilon_low,
+           self.epsilon_high,
+           self.temperature,
            self.compiled,
            self.use_ref_model,
-           self.num_generations,
            self.chunk_size,
        )
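Putting the new signature together, here is a hedged usage sketch of the updated `LigerFusedLinearGRPOLoss` module. Argument names and order follow the diff above; the shapes and the `use_ref_model=False` configuration are assumptions chosen only to keep the example self-contained.

```python
# Hedged usage sketch of the updated GRPO loss module; shapes are assumed for illustration.
import torch
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

B, T, H, V = 2, 16, 64, 128
hidden = torch.randn(B, T, H, requires_grad=True)        # last hidden states
lm_head_weight = torch.randn(V, H, requires_grad=True)   # lm_head weight, [V, H]
selected_token_ids = torch.randint(0, V, (B, T))         # sampled completion tokens
attention_mask = torch.ones(B, T)                        # 1 for real tokens, 0 for padding
advantages = torch.randn(B)                              # one advantage per sequence

loss_fn = LigerFusedLinearGRPOLoss(
    beta=0.04,
    epsilon_low=0.2,
    epsilon_high=0.2,
    temperature=1.0,
    use_ref_model=False,  # skip the reference model to keep the sketch minimal
    chunk_size=1,
)
loss, metrics = loss_fn(
    hidden,
    lm_head_weight,
    selected_token_ids,
    attention_mask,
    advantages,
)
loss.backward()  # gradients reach `hidden` and `lm_head_weight` via the custom backward
```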
liger_kernel/ops/cross_entropy.py

@@ -9,6 +9,7 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (
+    n_non_ignore (float): The number of non-ignored elements in the batch.
     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
     weight_sum (float): The sum of weight tensor.
     ignore_index (int): The index to ignore in the target.
@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning
 
 
 def cross_entropy_forward(
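The cross-entropy change lowers the maximum fused block size when the package runs on Intel XPU. As a rough sketch of how such a cap usually feeds block-size selection, the `min`/`next_power_of_2` pattern below is an assumption for illustration, not a quote from `cross_entropy.py`.

```python
# Hedged sketch: how a device-dependent MAX_FUSED_SIZE typically caps the Triton block width.
import triton


def pick_block_size(n_cols: int, device: str) -> int:
    max_fused_size = 4096 if device == "xpu" else 65536 // 2
    return min(max_fused_size, triton.next_power_of_2(n_cols))


print(pick_block_size(32000, "cuda"))  # 32768 -- vocab rounded up to a power of two
print(pick_block_size(32000, "xpu"))   # 4096  -- capped to reduce resource pressure on XPU
```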
{liger_kernel_nightly-0.5.5.dev20250331170510.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/RECORD

@@ -5,18 +5,18 @@ liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EB
 liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
 liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
 liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
-liger_kernel/chunked_loss/functional.py,sha256=
+liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
+liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=2_UvvIksUP45RBw3c-88-jOtjGATf04vaWopcqtX4Oo,12688
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
-liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=wGujqwLz91mOE9MmdenhBIKvbmswhwtINMCpcP7D74c,9050
 liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
-liger_kernel/chunked_loss/grpo_loss.py,sha256=
+liger_kernel/chunked_loss/grpo_loss.py,sha256=6Mb4ZT6MfnOr4Xo681rMR0LKkhzJhInvQp8wp2YVMK0,8913
 liger_kernel/chunked_loss/jsd_loss.py,sha256=u2ahkuHsbhpNaKcpBCz5gCMDk9ou-P04DHji592dIBo,7067
 liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
 liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
 liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=
+liger_kernel/ops/cross_entropy.py,sha256=T5oSsqOS1y-Iea5o9v_BSU-_mIEXqWAT1oX_m59NcA4,18941
 liger_kernel/ops/dyt.py,sha256=YD1-buHz9VmIX838VKzLc-lm5CeUQ4LAskGDWBUMQHA,6187
 liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1Y3Uk_TCSjqKgoG2eot1ptnWXJXXQESqGvOmqAW1gsM,10912
 liger_kernel/ops/fused_linear_jsd.py,sha256=Seshez2qaM6HiTQ8_HEqSwhaeVruNT1SvIM4ZrAPBEU,9602
@@ -72,9 +72,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.5.
-liger_kernel_nightly-0.5.5.
-liger_kernel_nightly-0.5.5.
-liger_kernel_nightly-0.5.5.
-liger_kernel_nightly-0.5.5.
-liger_kernel_nightly-0.5.5.
+liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/METADATA,sha256=XQaGc9bnsEFdwtLh1Mv5_fX-TIejLbcHk1SP-FEY5ew,22959
+liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/RECORD,,
liger_kernel/chunked_loss/fused_linear_rlhf.py (removed)

@@ -1,240 +0,0 @@
-from abc import abstractmethod
-from functools import partial
-
-import torch
-import torch.nn.functional as F
-
-
-class LigerFusedLinearRLHFBase(torch.autograd.Function):
-    @abstractmethod
-    def rlhf_loss_fn(*args, **kwargs):
-        """
-        To be extended by subclasses.
-        """
-        raise NotImplementedError("RLHF loss function must be implemented.")
-
-    @staticmethod
-    def forward(
-        cls,
-        ctx,
-        _input,
-        weight,
-        attention_mask,
-        rewards,
-        bias=None,
-        num_generations=4,
-        beta=0.1,
-        compiled=True,
-        use_ref_model=False,
-        ref_input=None,
-        ref_weight=None,
-        ref_bias=None,
-        chunk_size=1,
-    ):
-        """Chunked forward pass for RLHF loss computation.
-
-        Args:
-            cls: The class
-            ctx: Context for backward
-            _input: Input tensor
-            weight: Weight tensor
-            attention_mask: Attention mask tensor
-            rewards: Rewards tensor
-            bias: Bias tensor
-            num_generations: Number of generations per prompt
-            beta: Weight for the KL penalty
-            compiled: Whether to use torch compile
-            use_ref_model: Whether to use a reference model
-            ref_input: Reference model input tensor
-            ref_weight: Reference model weight tensor
-            ref_bias: Reference model bias tensor
-            chunk_size: Size of chunks for processing in other loss modules
-        """
-        # Save for backward
-        ctx.beta = beta
-        ctx.rewards = rewards
-
-        # Initialize accumulators
-        loss_acc = torch.zeros((), device=_input.device)
-        grad_weight = torch.zeros_like(weight)  # [V, H]
-        grad_inputs = []
-        grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
-        aggregated_metrics = []
-
-        # Create a partial function with fixed arguments
-        compute_loss = partial(
-            LigerFusedLinearRLHFBase._compute_chunk_loss,
-            beta=beta,
-            use_ref_model=use_ref_model,
-            ref_weight=ref_weight,
-            ref_bias=ref_bias,
-            rlhf_loss_fn=cls.rlhf_loss_fn,
-        )
-
-        def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
-            """Fused forward and backward for a chunk."""
-            if bias is not None:
-                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
-                    input_chunk,  # arg 0
-                    weight,  # arg 1
-                    attention_mask_chunk,  # arg 2
-                    rewards_chunk,  # arg 3
-                    ref_input_chunk,  # arg 4
-                    bias,  # arg 5
-                )
-            else:
-                return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
-                    input_chunk,  # arg 0
-                    weight,  # arg 1
-                    attention_mask_chunk,  # arg 2
-                    rewards_chunk,  # arg 3
-                    ref_input_chunk,  # arg 4
-                )
-
-        def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
-            if bias is not None:
-                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
-                    input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
-                )
-                grad_bias.add_(chunk_grad_bias)
-            else:
-                (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
-                    input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
-                )
-
-            # Accumulate gradients and loss
-            grad_weight.add_(chunk_grad_weight)
-            grad_inputs.append(chunk_grad_input)
-            loss_acc.add_(chunk_loss)
-
-            # Initialize storage for metrics on first chunk
-            if len(aggregated_metrics) == 0:
-                for metric in chunk_metrics:
-                    if metric.ndim == 0:
-                        aggregated_metrics.append(torch.zeros((), device=metric.device))
-                    else:
-                        aggregated_metrics.append([])
-
-            # Accumulate metrics
-            for i, metric in enumerate(chunk_metrics):
-                if metric.ndim == 0:
-                    aggregated_metrics[i].add_(metric)
-                else:
-                    aggregated_metrics[i].append(metric)
-
-        if compiled:
-            accumulate_chunk = torch.compile(accumulate_chunk)
-
-        # Process input in chunks based on num_generations
-        chunks = max(1, _input.shape[0] // num_generations)
-        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
-        _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
-        _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
-        _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
-
-        for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
-            _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
-        ):
-            # Mark dynamic dimensions
-            torch._dynamo.mark_dynamic(input_chunk, 1)
-            torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
-            if ref_input_chunk is not None:
-                torch._dynamo.mark_dynamic(ref_input_chunk, 1)
-
-            accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)
-
-        # Scale accumulated loss by number of chunks since we're averaging
-        loss_acc = loss_acc / chunks
-
-        # Combine gradients
-        grad_input = torch.cat(grad_inputs, dim=0)
-
-        # Save for backward
-        ctx.save_for_backward(grad_input, grad_weight, grad_bias)
-
-        # Finalize metrics
-        final_metrics = []
-        for metric in aggregated_metrics:
-            if isinstance(metric, list):
-                final_metrics.append(torch.cat(metric, dim=0))
-            else:
-                final_metrics.append(metric / chunks)
-
-        return loss_acc, tuple(final_metrics)
-
-    @staticmethod
-    def _compute_chunk_loss(
-        input_chunk,
-        weight,
-        attention_mask_chunk,
-        rewards_chunk,
-        ref_input_chunk=None,
-        bias=None,
-        beta=0.1,
-        use_ref_model=False,
-        ref_weight=None,
-        ref_bias=None,
-        rlhf_loss_fn=None,
-    ):
-        """Compute loss for a single chunk."""
-        # Get policy log probabilities using chunk_forward
-        log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)
-
-        # Get reference log probabilities if needed
-        ref_log_probs = None
-        if use_ref_model and ref_input_chunk is not None:
-            with torch.no_grad():
-                ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)
-
-        # Compute chunk loss and metrics using the provided loss function
-        chunk_loss, chunk_metrics = rlhf_loss_fn(
-            log_probs=log_probs,
-            attention_mask=attention_mask_chunk,
-            rewards=rewards_chunk,
-            ref_log_probs=ref_log_probs,
-            beta=beta,
-        )
-
-        return chunk_loss, (logits_mean, *chunk_metrics)
-
-    @staticmethod
-    def chunk_forward(input_chunk, weight, bias=None):
-        """Forward pass computation for a single chunk without explicit reshaping."""
-        # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
-        logits = torch.matmul(input_chunk, weight.t())
-        if bias is not None:
-            logits = logits + bias  # Broadcasts bias to [B, T, V]
-
-        # Compute log probabilities using softmax over the last dimension
-        log_probs = F.log_softmax(logits.float(), dim=-1)
-
-        # Monitoring: compute mean of logits
-        batch_size, seq_len, _ = input_chunk.shape
-        logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
-        return log_probs, logits, logits_mean
-
-    @staticmethod
-    def backward(ctx, grad_output, *grad_metrics):
-        """Backward pass for RLHF loss."""
-        grad_input, grad_weight, grad_bias = ctx.saved_tensors
-        if grad_output != 1.0:
-            grad_input = grad_input * grad_output
-            grad_weight = grad_weight * grad_output
-            if grad_bias is not None:
-                grad_bias = grad_bias * grad_output
-
-        return (
-            grad_input,
-            grad_weight,
-            None,  # grad_attention_mask
-            None,  # grad_rewards
-            grad_bias,
-            None,  # grad_num_generations
-            None,  # grad_beta
-            None,  # grad_compiled
-            None,  # grad_use_ref_model
-            None,  # grad_ref_input
-            None,  # grad_ref_weight
-            None,  # grad_ref_bias
-            None,  # grad_chunk_size
-        )