liger-kernel-nightly 0.5.5.dev20250331170510__py3-none-any.whl → 0.5.5.dev20250402184001__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2
2
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
3
4
  from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
4
5
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
5
6
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
11
12
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
12
13
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
13
14
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
15
+ liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
@@ -0,0 +1,330 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+
4
+ import torch
5
+ import torch._dynamo.config
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class LigerFusedLinearPPOBase(torch.autograd.Function):
10
+ @abstractmethod
11
+ def ppo_loss_fn(*args, **kwargs):
12
+ """
13
+ To be extended by subclasses.
14
+ """
15
+ raise NotImplementedError("PPO loss function must be implemented.")
16
+
17
+ @staticmethod
18
+ def forward(
19
+ cls,
20
+ ctx,
21
+ _input,
22
+ weight,
23
+ selected_token_ids,
24
+ attention_mask,
25
+ advantages,
26
+ bias=None,
27
+ ref_per_token_logps=None,
28
+ old_per_token_logps=None,
29
+ ref_input=None,
30
+ ref_weight=None,
31
+ ref_bias=None,
32
+ epsilon_low=0.2,
33
+ epsilon_high=0.2,
34
+ beta=0.04,
35
+ temperature=1.0,
36
+ compiled=True,
37
+ use_ref_model=False,
38
+ chunk_size=1,
39
+ ):
40
+ """Chunked forward pass for PPO loss computation.
41
+
42
+ Args:
43
+ cls: The class
44
+ ctx: Context for backward
45
+ _input: Input tensor
46
+ weight: Weight tensor
47
+ selected_token_ids: Selected token ids tensor
48
+ attention_mask: Attention mask tensor
49
+ advantages: Advantages tensor
50
+ bias: Bias tensor
51
+ ref_per_token_logps: Reference model log probs per token tensor
52
+ old_per_token_logps: Old per token log probabilities tensor
53
+ ref_input: Reference model input tensor
54
+ ref_weight: Reference model weight tensor
55
+ ref_bias: Reference model bias tensor
56
+ epsilon_low: Lower bound for clipping the importance sampling ratio
57
+ epsilon_high: Upper bound for clipping the importance sampling ratio
58
+ beta: Weight for the KL penalty
59
+ temperature: Temperature for the logits
60
+ compiled: Whether to use torch compile
61
+ use_ref_model: Whether to use a reference model
62
+ chunk_size: Size of chunks for processing in other loss modules
63
+ """
64
+ if use_ref_model:
65
+ assert ref_per_token_logps is not None or ref_input is not None, (
66
+ "If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
67
+ )
68
+ if ref_per_token_logps is not None and ref_input is not None:
69
+ raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
70
+ # Initialize accumulators
71
+ loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
72
+ grad_weight = torch.zeros_like(weight) # [V, H]
73
+ grad_inputs = []
74
+ grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
75
+ aggregated_metrics = []
76
+
77
+ # Create a partial function with fixed arguments
78
+ compute_loss = partial(
79
+ LigerFusedLinearPPOBase._compute_chunk_loss,
80
+ ref_weight=ref_weight,
81
+ ref_bias=ref_bias,
82
+ full_attention_mask=attention_mask,
83
+ epsilon_low=epsilon_low,
84
+ epsilon_high=epsilon_high,
85
+ beta=beta,
86
+ temperature=temperature,
87
+ use_ref_model=use_ref_model,
88
+ ppo_loss_fn=cls.ppo_loss_fn,
89
+ )
90
+
91
+ def fused_fwd_bwd(
92
+ input_chunk,
93
+ selected_token_ids_chunk,
94
+ attention_mask_chunk,
95
+ advantages_chunk,
96
+ ref_per_token_logps_chunk,
97
+ old_per_token_logps_chunk,
98
+ ref_input_chunk,
99
+ ):
100
+ """Fused forward and backward for a chunk."""
101
+ argnums = (0, 1, 5) if bias is not None else (0, 1)
102
+ return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
103
+ input_chunk, # arg 0
104
+ weight, # arg 1
105
+ selected_token_ids_chunk, # arg 2
106
+ attention_mask_chunk, # arg 3
107
+ advantages_chunk, # arg 4
108
+ bias, # arg 5
109
+ ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
110
+ old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
111
+ ref_input_chunk=ref_input_chunk, # arg 8
112
+ )
113
+
114
+ def accumulate_chunk(
115
+ input_chunk,
116
+ selected_token_ids_chunk,
117
+ attention_mask_chunk,
118
+ advantages_chunk,
119
+ ref_per_token_logps_chunk=None,
120
+ old_per_token_logps_chunk=None,
121
+ ref_input_chunk=None,
122
+ ):
123
+ (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
124
+ input_chunk,
125
+ selected_token_ids_chunk,
126
+ attention_mask_chunk,
127
+ advantages_chunk,
128
+ ref_per_token_logps_chunk,
129
+ old_per_token_logps_chunk,
130
+ ref_input_chunk,
131
+ )
132
+ if bias is not None:
133
+ grad_bias.add_(chunk_grad_bias[0])
134
+
135
+ # Accumulate gradients and loss
136
+ grad_weight.add_(chunk_grad_weight)
137
+ grad_inputs.append(chunk_grad_input)
138
+ loss_acc.add_(chunk_loss)
139
+ # Initialize storage for metrics on first chunk
140
+ if len(aggregated_metrics) == 0:
141
+ for metric in chunk_metrics:
142
+ if metric.ndim == 0:
143
+ aggregated_metrics.append(torch.zeros((), device=metric.device))
144
+ else:
145
+ aggregated_metrics.append([])
146
+
147
+ # Accumulate metrics
148
+ for i, metric in enumerate(chunk_metrics):
149
+ if metric.ndim == 0:
150
+ aggregated_metrics[i].add_(metric)
151
+ else:
152
+ aggregated_metrics[i].append(metric)
153
+
154
+ if compiled:
155
+ # TODO: Figure out what is better to compile here
156
+ # accumulate_chunk = torch.compile(accumulate_chunk)
157
+ fused_fwd_bwd = torch.compile(fused_fwd_bwd)
158
+
159
+ # Process input in chunks based on chunk_size
160
+ chunks = max(1, _input.shape[0] // chunk_size)
161
+ _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
162
+ _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
163
+ _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
164
+ _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
165
+ _ref_per_token_logps_chunks = (
166
+ torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
167
+ if use_ref_model and ref_per_token_logps is not None
168
+ else [None] * chunks
169
+ )
170
+ _old_per_token_logps_chunks = (
171
+ torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
172
+ if old_per_token_logps is not None
173
+ else [None] * chunks
174
+ )
175
+ # if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
176
+ _ref_input_chunks = (
177
+ torch.chunk(ref_input, chunks=chunks, dim=0)
178
+ if use_ref_model and ref_per_token_logps is None
179
+ else [None] * chunks
180
+ )
181
+
182
+ for (
183
+ input_chunk,
184
+ selected_token_ids_chunk,
185
+ attention_mask_chunk,
186
+ advantages_chunk,
187
+ ref_per_token_logps_chunk,
188
+ old_per_token_logps_chunk,
189
+ ref_input_chunk,
190
+ ) in zip(
191
+ _input_chunks,
192
+ _selected_token_ids_chunks,
193
+ _attention_mask_chunks,
194
+ _advantages_chunks,
195
+ _ref_per_token_logps_chunks,
196
+ _old_per_token_logps_chunks,
197
+ _ref_input_chunks,
198
+ ):
199
+ # Mark dynamic dimensions
200
+ torch._dynamo.mark_dynamic(input_chunk, 1)
201
+ torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
202
+ torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
203
+ if ref_per_token_logps_chunk is not None:
204
+ torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
205
+ if ref_input_chunk is not None:
206
+ torch._dynamo.mark_dynamic(ref_input_chunk, 1)
207
+ if old_per_token_logps_chunk is not None:
208
+ torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
209
+
210
+ accumulate_chunk(
211
+ input_chunk,
212
+ selected_token_ids_chunk,
213
+ attention_mask_chunk,
214
+ advantages_chunk,
215
+ ref_per_token_logps_chunk,
216
+ old_per_token_logps_chunk,
217
+ ref_input_chunk,
218
+ )
219
+
220
+ # Combine gradients
221
+ grad_input = torch.cat(grad_inputs, dim=0)
222
+
223
+ # Save for backward
224
+ ctx.save_for_backward(grad_input, grad_weight, grad_bias)
225
+
226
+ # Finalize metrics
227
+ final_metrics = []
228
+ for metric in aggregated_metrics:
229
+ if isinstance(metric, list):
230
+ final_metrics.append(torch.cat(metric, dim=0))
231
+ else:
232
+ final_metrics.append(metric)
233
+
234
+ return loss_acc, tuple(final_metrics)
235
+
236
+ @staticmethod
237
+ def _compute_chunk_loss(
238
+ input_chunk,
239
+ weight,
240
+ selected_token_ids_chunk,
241
+ attention_mask_chunk,
242
+ advantages_chunk,
243
+ bias=None,
244
+ ref_per_token_logps_chunk=None,
245
+ old_per_token_logps_chunk=None,
246
+ ref_input_chunk=None,
247
+ ref_weight=None,
248
+ ref_bias=None,
249
+ full_attention_mask=None,
250
+ epsilon_low=0.2,
251
+ epsilon_high=0.2,
252
+ beta=0.04,
253
+ temperature=1.0,
254
+ use_ref_model=False,
255
+ ppo_loss_fn=None,
256
+ ):
257
+ """Compute loss for a single chunk."""
258
+ # Get policy log probabilities using chunk_forward
259
+ log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
260
+
261
+ # Get reference log probabilities if needed
262
+ ref_log_probs = None
263
+ if use_ref_model and ref_per_token_logps_chunk is None:
264
+ with torch.no_grad():
265
+ ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
266
+ ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
267
+ )
268
+
269
+ # Compute chunk loss and metrics using the provided loss function
270
+ chunk_loss, chunk_metrics = ppo_loss_fn(
271
+ log_probs=log_probs,
272
+ selected_token_ids=selected_token_ids_chunk,
273
+ attention_mask=attention_mask_chunk,
274
+ advantages=advantages_chunk,
275
+ full_attention_mask=full_attention_mask,
276
+ ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
277
+ old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
278
+ ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None
279
+ epsilon_low=epsilon_low,
280
+ epsilon_high=epsilon_high,
281
+ beta=beta,
282
+ )
283
+
284
+ return chunk_loss, chunk_metrics
285
+
286
+ @staticmethod
287
+ def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
288
+ """Forward pass computation for a single chunk without explicit reshaping."""
289
+ # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
290
+ logits = torch.matmul(input_chunk, weight.t())
291
+ if bias is not None:
292
+ logits = logits + bias # Broadcasts bias to [B, T, V]
293
+ if temperature != 1.0:
294
+ logits = logits / temperature
295
+
296
+ # Compute log probabilities using softmax over the last dimension
297
+ log_probs = F.log_softmax(logits.float(), dim=-1)
298
+
299
+ return log_probs, logits
300
+
301
+ @staticmethod
302
+ def backward(ctx, grad_output, *grad_metrics):
303
+ """Backward pass for PPO loss."""
304
+ grad_input, grad_weight, grad_bias = ctx.saved_tensors
305
+ if grad_output != 1.0:
306
+ grad_input = grad_input * grad_output
307
+ grad_weight = grad_weight * grad_output
308
+ if grad_bias is not None:
309
+ grad_bias = grad_bias * grad_output
310
+
311
+ return (
312
+ grad_input,
313
+ grad_weight,
314
+ None, # grad_selected_token_ids
315
+ None, # grad_attention_mask
316
+ None, # grad_advantages
317
+ grad_bias,
318
+ None, # grad_ref_per_token_logps
319
+ None, # grad_old_per_token_logps
320
+ None, # grad_ref_input
321
+ None, # grad_ref_weight
322
+ None, # grad_ref_bias
323
+ None, # grad_epsilon_low
324
+ None, # grad_epsilon_high
325
+ None, # grad_beta
326
+ None, # grad_temperature
327
+ None, # grad_compiled
328
+ None, # grad_use_ref_model
329
+ None, # grad_chunk_size
330
+ )
@@ -1,66 +1,76 @@
1
1
  import torch
2
2
 
3
- from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase
3
+ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
4
4
 
5
5
 
6
- class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
6
+ def k3_loss_fn(log_p, log_q):
7
+ # computes k3 estimate of KL[q, p]
8
+ # ref: http://joschu.net/blog/kl-approx.html
9
+ return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0
10
+
11
+
12
+ def clip_coef_fn(coef, epsilon_low, epsilon_high):
13
+ return torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
14
+
15
+
16
+ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
7
17
  @staticmethod
8
- def rlhf_loss_fn(
18
+ def ppo_loss_fn(
9
19
  log_probs,
20
+ selected_token_ids,
10
21
  attention_mask,
11
- rewards,
12
- ref_log_probs=None,
13
- beta=0.1,
22
+ advantages,
23
+ full_attention_mask,
24
+ ref_per_token_logps=None, # shape: [chunk_size, seq_len]
25
+ old_per_token_logps=None,
26
+ ref_log_probs=None, # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
27
+ epsilon_low=0.2,
28
+ epsilon_high=0.2,
29
+ beta=0.04,
14
30
  **kwargs,
15
31
  ):
16
32
  """GRPO Loss Function matching GRPOTrainer implementation."""
17
- # Get chosen token probabilities
18
- chosen_tokens = log_probs.argmax(dim=-1) # (batch_size, seq_len)
19
- chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
33
+ per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
20
34
  -1
21
35
  ) # (batch_size, seq_len)
22
36
 
23
37
  # Get reference model probabilities
24
- if ref_log_probs is not None:
25
- with torch.no_grad():
26
- ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
27
- else:
28
- ref_token_logprobs = chosen_token_logprobs.detach()
29
-
30
- # Compute advantages per batch entry in a grouped fashion
31
- mean_grouped_rewards = rewards.mean() # [batch_size,]
32
- std_grouped_rewards = rewards.std() # [batch_size,]
33
-
34
- # Calculate advantages using the same epsilon as in GRPOTrainer
35
- eps = 1e-4
36
- advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
38
+ if ref_per_token_logps is None:
39
+ if ref_log_probs is not None:
40
+ with torch.no_grad():
41
+ ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
42
+ -1
43
+ )
44
+ else:
45
+ ref_per_token_logps = per_token_logps.detach()
37
46
 
38
47
  # Compute policy gradient loss with importance sampling ratio
39
- ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
40
- policy_loss = -ratio * advantages.unsqueeze(1)
41
-
42
- # Compute KL penalty
43
- kl_div = (
44
- torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
45
- )
46
-
47
- # Combine losses
48
- per_token_loss = policy_loss + beta * kl_div
49
-
50
- # Apply masking and normalize
51
- masked_loss = per_token_loss * attention_mask
52
- seq_lengths = attention_mask.sum()
53
- seq_lengths = torch.clamp(seq_lengths, min=1.0)
54
- loss = masked_loss.sum() / seq_lengths
48
+ old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
49
+ coef_1 = torch.exp(per_token_logps - old_per_token_logps)
50
+ coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
51
+ per_token_loss1 = coef_1 * advantages.unsqueeze(1)
52
+ per_token_loss2 = coef_2 * advantages.unsqueeze(1)
53
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
54
+ if beta != 0.0:
55
+ # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
56
+ kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
57
+ # Combine losses
58
+ per_token_loss = per_token_loss + beta * kl_div
59
+
60
+ # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
61
+ # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
62
+ # and TRL GRPO implementation
63
+ # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
64
+ loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
55
65
 
56
66
  # Calculate metrics
57
- metrics = (
58
- chosen_token_logprobs.mean(), # mean log prob
59
- chosen_token_logprobs.std(), # std log prob
60
- log_probs.mean(), # mean all log probs
61
- ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(), # mean KL div
67
+ metrics = []
68
+ if beta != 0.0:
69
+ metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
70
+ is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
71
+ (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
62
72
  )
63
-
73
+ metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
64
74
  return loss, metrics
65
75
 
66
76
  @classmethod
@@ -69,16 +79,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
69
79
  ctx,
70
80
  _input,
71
81
  weight,
82
+ selected_token_ids,
72
83
  attention_mask,
73
- rewards,
84
+ advantages,
74
85
  bias=None,
86
+ ref_per_token_logps=None,
87
+ old_per_token_logps=None,
75
88
  ref_input=None,
76
89
  ref_weight=None,
77
90
  ref_bias=None,
78
- beta=0.1,
91
+ beta=0.04,
92
+ epsilon_low=0.2,
93
+ epsilon_high=0.2,
94
+ temperature=1.0,
79
95
  compiled=True,
80
96
  use_ref_model=True,
81
- num_generations=1,
82
97
  chunk_size=1,
83
98
  ):
84
99
  """
@@ -86,16 +101,18 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
86
101
  Args:
87
102
  _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
88
103
  weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
104
+ selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
89
105
  attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
90
- rewards (torch.Tensor): Rewards tensor. Shape: (batch_size,)
106
+ advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
91
107
  bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
108
+ ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
92
109
  ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
93
110
  ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
94
111
  ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
95
112
  beta (float): Weight for the KL penalty
113
+ temperature (float): Temperature for the logits
96
114
  compiled (bool): Whether to use torch compile
97
115
  use_ref_model (bool): Whether to use a reference model
98
- num_generations (int): Number of generations per prompt
99
116
  chunk_size (int): Size of chunks for processing.
100
117
  Returns:
101
118
  torch.Tensor: Computed loss
@@ -105,16 +122,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
105
122
  ctx=ctx,
106
123
  _input=_input,
107
124
  weight=weight,
125
+ selected_token_ids=selected_token_ids,
108
126
  attention_mask=attention_mask,
109
- rewards=rewards,
127
+ advantages=advantages,
110
128
  bias=bias,
129
+ ref_per_token_logps=ref_per_token_logps,
130
+ old_per_token_logps=old_per_token_logps,
111
131
  ref_input=ref_input,
112
132
  ref_weight=ref_weight,
113
133
  ref_bias=ref_bias,
114
134
  beta=beta,
135
+ epsilon_low=epsilon_low,
136
+ epsilon_high=epsilon_high,
137
+ temperature=temperature,
115
138
  compiled=compiled,
116
139
  use_ref_model=use_ref_model,
117
- num_generations=num_generations,
118
140
  chunk_size=chunk_size,
119
141
  )
120
142
 
@@ -126,16 +148,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
126
148
  grad_output: Gradient of the loss (scalar)
127
149
  grad_metrics: Gradients of the metrics (not used in backward computation)
128
150
  """
129
- grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
151
+ grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
130
152
  return (
131
- *grads[:5], # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias
153
+ *grads[
154
+ :6
155
+ ], # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
156
+ None, # grad_ref_per_token_logps
157
+ None, # grad_old_per_token_logps
132
158
  None, # grad_ref_input
133
159
  None, # grad_ref_weight
134
160
  None, # grad_ref_bias
135
161
  None, # grad_beta
162
+ None, # grad_epsilon_low
163
+ None, # grad_epsilon_high
164
+ None, # grad_temperature
136
165
  None, # grad_compiled
137
166
  None, # grad_use_ref_model
138
- None, # grad_num_generations
139
167
  None, # grad_chunk_size
140
168
  )
141
169
 
@@ -145,34 +173,43 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
145
173
 
146
174
  def __init__(
147
175
  self,
148
- beta: float = 0.1,
176
+ beta: float = 0.04,
149
177
  compiled: bool = True,
150
178
  use_ref_model: bool = True,
151
- num_generations: int = 1,
152
179
  chunk_size: int = 1,
180
+ epsilon_low: float = 0.2,
181
+ epsilon_high: float = 0.2,
182
+ temperature: float = 1.0,
153
183
  ):
154
184
  """
155
185
  Args:
156
186
  beta (float): Weight for the KL penalty.
157
187
  compiled (bool): Whether to use torch compile.
158
188
  use_ref_model (bool): Whether to use a reference model.
159
- num_generations (int): Number of generations per prompt.
160
189
  chunk_size (int): Size of chunks for processing.
190
+ epsilon_low (float): Lower bound for the importance sampling ratio.
191
+ epsilon_high (float): Upper bound for the importance sampling ratio.
192
+ temperature (float): Temperature for the logits.
161
193
  """
162
194
  super().__init__()
163
195
  self.beta = beta
164
196
  self.compiled = compiled
165
197
  self.use_ref_model = use_ref_model
166
- self.num_generations = num_generations
167
198
  self.chunk_size = chunk_size
199
+ self.epsilon_low = epsilon_low
200
+ self.epsilon_high = epsilon_high
201
+ self.temperature = temperature
168
202
 
169
203
  def forward(
170
204
  self,
171
205
  _input,
172
206
  lin_weight,
207
+ selected_token_ids,
173
208
  attention_mask,
174
- rewards,
209
+ advantages,
175
210
  bias=None,
211
+ ref_per_token_logps=None,
212
+ old_per_token_logps=None,
176
213
  ref_input=None,
177
214
  ref_weight=None,
178
215
  ref_bias=None,
@@ -180,15 +217,20 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
180
217
  return LigerFusedLinearGRPOFunction.apply(
181
218
  _input,
182
219
  lin_weight,
220
+ selected_token_ids,
183
221
  attention_mask,
184
- rewards,
222
+ advantages,
185
223
  bias,
224
+ ref_per_token_logps,
225
+ old_per_token_logps,
186
226
  ref_input,
187
227
  ref_weight,
188
228
  ref_bias,
189
229
  self.beta,
230
+ self.epsilon_low,
231
+ self.epsilon_high,
232
+ self.temperature,
190
233
  self.compiled,
191
234
  self.use_ref_model,
192
- self.num_generations,
193
235
  self.chunk_size,
194
236
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.5.dev20250331170510
3
+ Version: 0.5.5.dev20250402184001
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -5,12 +5,12 @@ liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EB
5
5
  liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
6
6
  liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
7
7
  liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
8
- liger_kernel/chunked_loss/functional.py,sha256=THWWpCnRVhTVfnPnyvQjdBvo1JDtxhwLmtZE_yiBBqM,817
8
+ liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
9
9
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
10
+ liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=2_UvvIksUP45RBw3c-88-jOtjGATf04vaWopcqtX4Oo,12688
10
11
  liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
11
- liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=wGujqwLz91mOE9MmdenhBIKvbmswhwtINMCpcP7D74c,9050
12
12
  liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
13
- liger_kernel/chunked_loss/grpo_loss.py,sha256=axED3628yKODu1v7PMAvSd08WZqwNQvJOTUYMgcihdQ,6665
13
+ liger_kernel/chunked_loss/grpo_loss.py,sha256=6Mb4ZT6MfnOr4Xo681rMR0LKkhzJhInvQp8wp2YVMK0,8913
14
14
  liger_kernel/chunked_loss/jsd_loss.py,sha256=u2ahkuHsbhpNaKcpBCz5gCMDk9ou-P04DHji592dIBo,7067
15
15
  liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
16
16
  liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
@@ -72,9 +72,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
72
72
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
73
73
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
74
74
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
75
- liger_kernel_nightly-0.5.5.dev20250331170510.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
76
- liger_kernel_nightly-0.5.5.dev20250331170510.dist-info/METADATA,sha256=KEjXLNI8PYfmvipid4KUVeM0XE5oKXd5Pl7ikrZbAqU,22959
77
- liger_kernel_nightly-0.5.5.dev20250331170510.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
78
- liger_kernel_nightly-0.5.5.dev20250331170510.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
79
- liger_kernel_nightly-0.5.5.dev20250331170510.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
80
- liger_kernel_nightly-0.5.5.dev20250331170510.dist-info/RECORD,,
75
+ liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
76
+ liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/METADATA,sha256=DLGGPCgn1-dKSQVP5sYIzzRoh7c9wBUjM7JFujYn1KI,22959
77
+ liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
78
+ liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
79
+ liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
80
+ liger_kernel_nightly-0.5.5.dev20250402184001.dist-info/RECORD,,
@@ -1,240 +0,0 @@
1
- from abc import abstractmethod
2
- from functools import partial
3
-
4
- import torch
5
- import torch.nn.functional as F
6
-
7
-
8
- class LigerFusedLinearRLHFBase(torch.autograd.Function):
9
- @abstractmethod
10
- def rlhf_loss_fn(*args, **kwargs):
11
- """
12
- To be extended by subclasses.
13
- """
14
- raise NotImplementedError("RLHF loss function must be implemented.")
15
-
16
- @staticmethod
17
- def forward(
18
- cls,
19
- ctx,
20
- _input,
21
- weight,
22
- attention_mask,
23
- rewards,
24
- bias=None,
25
- num_generations=4,
26
- beta=0.1,
27
- compiled=True,
28
- use_ref_model=False,
29
- ref_input=None,
30
- ref_weight=None,
31
- ref_bias=None,
32
- chunk_size=1,
33
- ):
34
- """Chunked forward pass for RLHF loss computation.
35
-
36
- Args:
37
- cls: The class
38
- ctx: Context for backward
39
- _input: Input tensor
40
- weight: Weight tensor
41
- attention_mask: Attention mask tensor
42
- rewards: Rewards tensor
43
- bias: Bias tensor
44
- num_generations: Number of generations per prompt
45
- beta: Weight for the KL penalty
46
- compiled: Whether to use torch compile
47
- use_ref_model: Whether to use a reference model
48
- ref_input: Reference model input tensor
49
- ref_weight: Reference model weight tensor
50
- ref_bias: Reference model bias tensor
51
- chunk_size: Size of chunks for processing in other loss modules
52
- """
53
- # Save for backward
54
- ctx.beta = beta
55
- ctx.rewards = rewards
56
-
57
- # Initialize accumulators
58
- loss_acc = torch.zeros((), device=_input.device)
59
- grad_weight = torch.zeros_like(weight) # [V, H]
60
- grad_inputs = []
61
- grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
62
- aggregated_metrics = []
63
-
64
- # Create a partial function with fixed arguments
65
- compute_loss = partial(
66
- LigerFusedLinearRLHFBase._compute_chunk_loss,
67
- beta=beta,
68
- use_ref_model=use_ref_model,
69
- ref_weight=ref_weight,
70
- ref_bias=ref_bias,
71
- rlhf_loss_fn=cls.rlhf_loss_fn,
72
- )
73
-
74
- def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
75
- """Fused forward and backward for a chunk."""
76
- if bias is not None:
77
- return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
78
- input_chunk, # arg 0
79
- weight, # arg 1
80
- attention_mask_chunk, # arg 2
81
- rewards_chunk, # arg 3
82
- ref_input_chunk, # arg 4
83
- bias, # arg 5
84
- )
85
- else:
86
- return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
87
- input_chunk, # arg 0
88
- weight, # arg 1
89
- attention_mask_chunk, # arg 2
90
- rewards_chunk, # arg 3
91
- ref_input_chunk, # arg 4
92
- )
93
-
94
- def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
95
- if bias is not None:
96
- (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
97
- input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
98
- )
99
- grad_bias.add_(chunk_grad_bias)
100
- else:
101
- (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
102
- input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
103
- )
104
-
105
- # Accumulate gradients and loss
106
- grad_weight.add_(chunk_grad_weight)
107
- grad_inputs.append(chunk_grad_input)
108
- loss_acc.add_(chunk_loss)
109
-
110
- # Initialize storage for metrics on first chunk
111
- if len(aggregated_metrics) == 0:
112
- for metric in chunk_metrics:
113
- if metric.ndim == 0:
114
- aggregated_metrics.append(torch.zeros((), device=metric.device))
115
- else:
116
- aggregated_metrics.append([])
117
-
118
- # Accumulate metrics
119
- for i, metric in enumerate(chunk_metrics):
120
- if metric.ndim == 0:
121
- aggregated_metrics[i].add_(metric)
122
- else:
123
- aggregated_metrics[i].append(metric)
124
-
125
- if compiled:
126
- accumulate_chunk = torch.compile(accumulate_chunk)
127
-
128
- # Process input in chunks based on num_generations
129
- chunks = max(1, _input.shape[0] // num_generations)
130
- _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
131
- _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
132
- _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
133
- _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
134
-
135
- for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
136
- _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
137
- ):
138
- # Mark dynamic dimensions
139
- torch._dynamo.mark_dynamic(input_chunk, 1)
140
- torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
141
- if ref_input_chunk is not None:
142
- torch._dynamo.mark_dynamic(ref_input_chunk, 1)
143
-
144
- accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)
145
-
146
- # Scale accumulated loss by number of chunks since we're averaging
147
- loss_acc = loss_acc / chunks
148
-
149
- # Combine gradients
150
- grad_input = torch.cat(grad_inputs, dim=0)
151
-
152
- # Save for backward
153
- ctx.save_for_backward(grad_input, grad_weight, grad_bias)
154
-
155
- # Finalize metrics
156
- final_metrics = []
157
- for metric in aggregated_metrics:
158
- if isinstance(metric, list):
159
- final_metrics.append(torch.cat(metric, dim=0))
160
- else:
161
- final_metrics.append(metric / chunks)
162
-
163
- return loss_acc, tuple(final_metrics)
164
-
165
- @staticmethod
166
- def _compute_chunk_loss(
167
- input_chunk,
168
- weight,
169
- attention_mask_chunk,
170
- rewards_chunk,
171
- ref_input_chunk=None,
172
- bias=None,
173
- beta=0.1,
174
- use_ref_model=False,
175
- ref_weight=None,
176
- ref_bias=None,
177
- rlhf_loss_fn=None,
178
- ):
179
- """Compute loss for a single chunk."""
180
- # Get policy log probabilities using chunk_forward
181
- log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)
182
-
183
- # Get reference log probabilities if needed
184
- ref_log_probs = None
185
- if use_ref_model and ref_input_chunk is not None:
186
- with torch.no_grad():
187
- ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)
188
-
189
- # Compute chunk loss and metrics using the provided loss function
190
- chunk_loss, chunk_metrics = rlhf_loss_fn(
191
- log_probs=log_probs,
192
- attention_mask=attention_mask_chunk,
193
- rewards=rewards_chunk,
194
- ref_log_probs=ref_log_probs,
195
- beta=beta,
196
- )
197
-
198
- return chunk_loss, (logits_mean, *chunk_metrics)
199
-
200
- @staticmethod
201
- def chunk_forward(input_chunk, weight, bias=None):
202
- """Forward pass computation for a single chunk without explicit reshaping."""
203
- # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
204
- logits = torch.matmul(input_chunk, weight.t())
205
- if bias is not None:
206
- logits = logits + bias # Broadcasts bias to [B, T, V]
207
-
208
- # Compute log probabilities using softmax over the last dimension
209
- log_probs = F.log_softmax(logits.float(), dim=-1)
210
-
211
- # Monitoring: compute mean of logits
212
- batch_size, seq_len, _ = input_chunk.shape
213
- logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
214
- return log_probs, logits, logits_mean
215
-
216
- @staticmethod
217
- def backward(ctx, grad_output, *grad_metrics):
218
- """Backward pass for RLHF loss."""
219
- grad_input, grad_weight, grad_bias = ctx.saved_tensors
220
- if grad_output != 1.0:
221
- grad_input = grad_input * grad_output
222
- grad_weight = grad_weight * grad_output
223
- if grad_bias is not None:
224
- grad_bias = grad_bias * grad_output
225
-
226
- return (
227
- grad_input,
228
- grad_weight,
229
- None, # grad_attention_mask
230
- None, # grad_rewards
231
- grad_bias,
232
- None, # grad_num_generations
233
- None, # grad_beta
234
- None, # grad_compiled
235
- None, # grad_use_ref_model
236
- None, # grad_ref_input
237
- None, # grad_ref_weight
238
- None, # grad_ref_bias
239
- None, # grad_chunk_size
240
- )