liger-kernel 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +103 -61
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +30 -11
  10. liger_kernel/ops/kl_div.py +2 -2
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/dyt.py +20 -0
  13. liger_kernel/transformers/functional.py +5 -0
  14. liger_kernel/transformers/model/gemma.py +8 -16
  15. liger_kernel/transformers/model/gemma2.py +7 -16
  16. liger_kernel/transformers/model/llama.py +8 -15
  17. liger_kernel/transformers/model/llava.py +369 -0
  18. liger_kernel/transformers/model/loss_utils.py +57 -0
  19. liger_kernel/transformers/model/mistral.py +9 -10
  20. liger_kernel/transformers/model/mixtral.py +8 -15
  21. liger_kernel/transformers/model/mllama.py +8 -15
  22. liger_kernel/transformers/model/olmo2.py +8 -16
  23. liger_kernel/transformers/model/paligemma.py +397 -0
  24. liger_kernel/transformers/model/phi3.py +8 -15
  25. liger_kernel/transformers/model/qwen2.py +8 -15
  26. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  27. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  28. liger_kernel/transformers/monkey_patch.py +219 -13
  29. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +9 -6
  30. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/RECORD +34 -29
  31. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
  32. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2
2
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
3
4
  from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
4
5
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
5
6
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
11
12
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
12
13
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
13
14
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
15
+ liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
@@ -115,9 +115,24 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
115
115
  student_logits_chunk /= temperature
116
116
  teacher_logits_chunk /= temperature
117
117
 
118
+ # If the teacher and student token size is different, pad student logits to match the teacher's.
119
+ # This only applies to cases where they share exactly the same vocab and tokenizer just
120
+ # that teacher logit is padded for some training efficiency such as
121
+ # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
122
+ teacher_vocab_size = teacher_weight.shape[0]
123
+ student_vocab_size = student_weight.shape[0]
124
+ if teacher_vocab_size > student_vocab_size:
125
+ pad_size = teacher_vocab_size - student_vocab_size
126
+ pad_tensor = torch.zeros(
127
+ (*student_logits_chunk.shape[:-1], pad_size),
128
+ dtype=student_logits_chunk.dtype,
129
+ device=student_logits_chunk.device,
130
+ )
131
+ student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
132
+
118
133
  hard_loss /= full_target.shape[0]
119
134
 
120
- soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
135
+ soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
121
136
  soft_loss /= full_target.shape[0]
122
137
 
123
138
  loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -180,9 +195,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
180
195
  ignore_index=ignore_index,
181
196
  weight_hard_loss=weight_hard_loss,
182
197
  weight_soft_loss=weight_soft_loss,
183
- beta=beta,
184
198
  compute_ce_loss=compute_ce_loss,
185
199
  temperature=temperature,
200
+ beta=beta,
186
201
  **loss_kwargs,
187
202
  )
188
203
 
@@ -0,0 +1,331 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+
4
+ import torch
5
+ import torch._dynamo.config
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class LigerFusedLinearPPOBase(torch.autograd.Function):
10
+ @abstractmethod
11
+ def ppo_loss_fn(*args, **kwargs):
12
+ """
13
+ To be extended by subclasses.
14
+ """
15
+ raise NotImplementedError("PPO loss function must be implemented.")
16
+
17
+ @staticmethod
18
+ def forward(
19
+ cls,
20
+ ctx,
21
+ _input,
22
+ weight,
23
+ selected_token_ids,
24
+ attention_mask,
25
+ advantages,
26
+ bias=None,
27
+ ref_per_token_logps=None,
28
+ old_per_token_logps=None,
29
+ ref_input=None,
30
+ ref_weight=None,
31
+ ref_bias=None,
32
+ epsilon_low=0.2,
33
+ epsilon_high=0.2,
34
+ beta=0.04,
35
+ temperature=1.0,
36
+ compiled=True,
37
+ use_ref_model=False,
38
+ chunk_size=1,
39
+ ):
40
+ # TODO: check torch compile matmul
41
+ """Chunked forward pass for PPO loss computation.
42
+
43
+ Args:
44
+ cls: The class
45
+ ctx: Context for backward
46
+ _input: Input tensor
47
+ weight: Weight tensor
48
+ selected_token_ids: Selected token ids tensor
49
+ attention_mask: Attention mask tensor
50
+ advantages: Advantages tensor
51
+ bias: Bias tensor
52
+ ref_per_token_logps: Reference model log probs per token tensor
53
+ old_per_token_logps: Old per token log probabilities tensor
54
+ ref_input: Reference model input tensor
55
+ ref_weight: Reference model weight tensor
56
+ ref_bias: Reference model bias tensor
57
+ epsilon_low: Lower bound for clipping the importance sampling ratio
58
+ epsilon_high: Upper bound for clipping the importance sampling ratio
59
+ beta: Weight for the KL penalty
60
+ temperature: Temperature for the logits
61
+ compiled: Whether to use torch compile
62
+ use_ref_model: Whether to use a reference model
63
+ chunk_size: Size of chunks for processing in other loss modules
64
+ """
65
+ if use_ref_model:
66
+ assert ref_per_token_logps is not None or ref_input is not None, (
67
+ "If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
68
+ )
69
+ if ref_per_token_logps is not None and ref_input is not None:
70
+ raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
71
+ # Initialize accumulators
72
+ loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
73
+ grad_weight = torch.zeros_like(weight) # [V, H]
74
+ grad_inputs = []
75
+ grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
76
+ aggregated_metrics = []
77
+
78
+ # Create a partial function with fixed arguments
79
+ compute_loss = partial(
80
+ LigerFusedLinearPPOBase._compute_chunk_loss,
81
+ ref_weight=ref_weight,
82
+ ref_bias=ref_bias,
83
+ full_attention_mask=attention_mask,
84
+ epsilon_low=epsilon_low,
85
+ epsilon_high=epsilon_high,
86
+ beta=beta,
87
+ temperature=temperature,
88
+ use_ref_model=use_ref_model,
89
+ ppo_loss_fn=cls.ppo_loss_fn,
90
+ )
91
+
92
+ def fused_fwd_bwd(
93
+ input_chunk,
94
+ selected_token_ids_chunk,
95
+ attention_mask_chunk,
96
+ advantages_chunk,
97
+ ref_per_token_logps_chunk,
98
+ old_per_token_logps_chunk,
99
+ ref_input_chunk,
100
+ ):
101
+ """Fused forward and backward for a chunk."""
102
+ argnums = (0, 1, 5) if bias is not None else (0, 1)
103
+ return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
104
+ input_chunk, # arg 0
105
+ weight, # arg 1
106
+ selected_token_ids_chunk, # arg 2
107
+ attention_mask_chunk, # arg 3
108
+ advantages_chunk, # arg 4
109
+ bias, # arg 5
110
+ ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
111
+ old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
112
+ ref_input_chunk=ref_input_chunk, # arg 8
113
+ )
114
+
115
+ def accumulate_chunk(
116
+ input_chunk,
117
+ selected_token_ids_chunk,
118
+ attention_mask_chunk,
119
+ advantages_chunk,
120
+ ref_per_token_logps_chunk=None,
121
+ old_per_token_logps_chunk=None,
122
+ ref_input_chunk=None,
123
+ ):
124
+ (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
125
+ input_chunk,
126
+ selected_token_ids_chunk,
127
+ attention_mask_chunk,
128
+ advantages_chunk,
129
+ ref_per_token_logps_chunk,
130
+ old_per_token_logps_chunk,
131
+ ref_input_chunk,
132
+ )
133
+ if bias is not None:
134
+ grad_bias.add_(chunk_grad_bias[0])
135
+
136
+ # Accumulate gradients and loss
137
+ grad_weight.add_(chunk_grad_weight)
138
+ grad_inputs.append(chunk_grad_input)
139
+ loss_acc.add_(chunk_loss)
140
+ # Initialize storage for metrics on first chunk
141
+ if len(aggregated_metrics) == 0:
142
+ for metric in chunk_metrics:
143
+ if metric.ndim == 0:
144
+ aggregated_metrics.append(torch.zeros((), device=metric.device))
145
+ else:
146
+ aggregated_metrics.append([])
147
+
148
+ # Accumulate metrics
149
+ for i, metric in enumerate(chunk_metrics):
150
+ if metric.ndim == 0:
151
+ aggregated_metrics[i].add_(metric)
152
+ else:
153
+ aggregated_metrics[i].append(metric)
154
+
155
+ if compiled:
156
+ # TODO: Figure out what is better to compile here
157
+ # accumulate_chunk = torch.compile(accumulate_chunk)
158
+ fused_fwd_bwd = torch.compile(fused_fwd_bwd)
159
+
160
+ # Process input in chunks based on chunk_size
161
+ chunks = max(1, _input.shape[0] // chunk_size)
162
+ _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
163
+ _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
164
+ _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
165
+ _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
166
+ _ref_per_token_logps_chunks = (
167
+ torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
168
+ if use_ref_model and ref_per_token_logps is not None
169
+ else [None] * chunks
170
+ )
171
+ _old_per_token_logps_chunks = (
172
+ torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
173
+ if old_per_token_logps is not None
174
+ else [None] * chunks
175
+ )
176
+ # if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
177
+ _ref_input_chunks = (
178
+ torch.chunk(ref_input, chunks=chunks, dim=0)
179
+ if use_ref_model and ref_per_token_logps is None
180
+ else [None] * chunks
181
+ )
182
+
183
+ for (
184
+ input_chunk,
185
+ selected_token_ids_chunk,
186
+ attention_mask_chunk,
187
+ advantages_chunk,
188
+ ref_per_token_logps_chunk,
189
+ old_per_token_logps_chunk,
190
+ ref_input_chunk,
191
+ ) in zip(
192
+ _input_chunks,
193
+ _selected_token_ids_chunks,
194
+ _attention_mask_chunks,
195
+ _advantages_chunks,
196
+ _ref_per_token_logps_chunks,
197
+ _old_per_token_logps_chunks,
198
+ _ref_input_chunks,
199
+ ):
200
+ # Mark dynamic dimensions
201
+ torch._dynamo.mark_dynamic(input_chunk, 1)
202
+ torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
203
+ torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
204
+ if ref_per_token_logps_chunk is not None:
205
+ torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
206
+ if ref_input_chunk is not None:
207
+ torch._dynamo.mark_dynamic(ref_input_chunk, 1)
208
+ if old_per_token_logps_chunk is not None:
209
+ torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
210
+
211
+ accumulate_chunk(
212
+ input_chunk,
213
+ selected_token_ids_chunk,
214
+ attention_mask_chunk,
215
+ advantages_chunk,
216
+ ref_per_token_logps_chunk,
217
+ old_per_token_logps_chunk,
218
+ ref_input_chunk,
219
+ )
220
+
221
+ # Combine gradients
222
+ grad_input = torch.cat(grad_inputs, dim=0)
223
+
224
+ # Save for backward
225
+ ctx.save_for_backward(grad_input, grad_weight, grad_bias)
226
+
227
+ # Finalize metrics
228
+ final_metrics = []
229
+ for metric in aggregated_metrics:
230
+ if isinstance(metric, list):
231
+ final_metrics.append(torch.cat(metric, dim=0))
232
+ else:
233
+ final_metrics.append(metric)
234
+
235
+ return loss_acc, tuple(final_metrics)
236
+
237
+ @staticmethod
238
+ def _compute_chunk_loss(
239
+ input_chunk,
240
+ weight,
241
+ selected_token_ids_chunk,
242
+ attention_mask_chunk,
243
+ advantages_chunk,
244
+ bias=None,
245
+ ref_per_token_logps_chunk=None,
246
+ old_per_token_logps_chunk=None,
247
+ ref_input_chunk=None,
248
+ ref_weight=None,
249
+ ref_bias=None,
250
+ full_attention_mask=None,
251
+ epsilon_low=0.2,
252
+ epsilon_high=0.2,
253
+ beta=0.04,
254
+ temperature=1.0,
255
+ use_ref_model=False,
256
+ ppo_loss_fn=None,
257
+ ):
258
+ """Compute loss for a single chunk."""
259
+ # Get policy log probabilities using chunk_forward
260
+ log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
261
+
262
+ # Get reference log probabilities if needed
263
+ ref_log_probs = None
264
+ if use_ref_model and ref_per_token_logps_chunk is None:
265
+ with torch.no_grad():
266
+ ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
267
+ ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
268
+ )
269
+
270
+ # Compute chunk loss and metrics using the provided loss function
271
+ chunk_loss, chunk_metrics = ppo_loss_fn(
272
+ log_probs=log_probs,
273
+ selected_token_ids=selected_token_ids_chunk,
274
+ attention_mask=attention_mask_chunk,
275
+ advantages=advantages_chunk,
276
+ full_attention_mask=full_attention_mask,
277
+ ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
278
+ old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
279
+ ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None
280
+ epsilon_low=epsilon_low,
281
+ epsilon_high=epsilon_high,
282
+ beta=beta,
283
+ )
284
+
285
+ return chunk_loss, chunk_metrics
286
+
287
+ @staticmethod
288
+ def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
289
+ """Forward pass computation for a single chunk without explicit reshaping."""
290
+ # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
291
+ logits = torch.matmul(input_chunk, weight.t())
292
+ if bias is not None:
293
+ logits = logits + bias # Broadcasts bias to [B, T, V]
294
+ if temperature != 1.0:
295
+ logits = logits / temperature
296
+
297
+ # Compute log probabilities using softmax over the last dimension
298
+ log_probs = F.log_softmax(logits.float(), dim=-1)
299
+
300
+ return log_probs, logits
301
+
302
+ @staticmethod
303
+ def backward(ctx, grad_output, *grad_metrics):
304
+ """Backward pass for PPO loss."""
305
+ grad_input, grad_weight, grad_bias = ctx.saved_tensors
306
+ if grad_output != 1.0:
307
+ grad_input = grad_input * grad_output
308
+ grad_weight = grad_weight * grad_output
309
+ if grad_bias is not None:
310
+ grad_bias = grad_bias * grad_output
311
+
312
+ return (
313
+ grad_input,
314
+ grad_weight,
315
+ None, # grad_selected_token_ids
316
+ None, # grad_attention_mask
317
+ None, # grad_advantages
318
+ grad_bias,
319
+ None, # grad_ref_per_token_logps
320
+ None, # grad_old_per_token_logps
321
+ None, # grad_ref_input
322
+ None, # grad_ref_weight
323
+ None, # grad_ref_bias
324
+ None, # grad_epsilon_low
325
+ None, # grad_epsilon_high
326
+ None, # grad_beta
327
+ None, # grad_temperature
328
+ None, # grad_compiled
329
+ None, # grad_use_ref_model
330
+ None, # grad_chunk_size
331
+ )
@@ -1,66 +1,76 @@
1
1
  import torch
2
2
 
3
- from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase
3
+ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
4
4
 
5
5
 
6
- class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
6
+ def k3_loss_fn(log_p, log_q):
7
+ # computes k3 estimate of KL[q, p]
8
+ # ref: http://joschu.net/blog/kl-approx.html
9
+ return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0
10
+
11
+
12
+ def clip_coef_fn(coef, epsilon_low, epsilon_high):
13
+ return torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
14
+
15
+
16
+ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
7
17
  @staticmethod
8
- def rlhf_loss_fn(
18
+ def ppo_loss_fn(
9
19
  log_probs,
20
+ selected_token_ids,
10
21
  attention_mask,
11
- rewards,
12
- ref_log_probs=None,
13
- beta=0.1,
22
+ advantages,
23
+ full_attention_mask,
24
+ ref_per_token_logps=None, # shape: [chunk_size, seq_len]
25
+ old_per_token_logps=None,
26
+ ref_log_probs=None, # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
27
+ epsilon_low=0.2,
28
+ epsilon_high=0.2,
29
+ beta=0.04,
14
30
  **kwargs,
15
31
  ):
16
32
  """GRPO Loss Function matching GRPOTrainer implementation."""
17
- # Get chosen token probabilities
18
- chosen_tokens = log_probs.argmax(dim=-1) # (batch_size, seq_len)
19
- chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
33
+ per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
20
34
  -1
21
35
  ) # (batch_size, seq_len)
22
36
 
23
37
  # Get reference model probabilities
24
- if ref_log_probs is not None:
25
- with torch.no_grad():
26
- ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
27
- else:
28
- ref_token_logprobs = chosen_token_logprobs.detach()
29
-
30
- # Compute advantages per batch entry in a grouped fashion
31
- mean_grouped_rewards = rewards.mean() # [batch_size,]
32
- std_grouped_rewards = rewards.std() # [batch_size,]
33
-
34
- # Calculate advantages using the same epsilon as in GRPOTrainer
35
- eps = 1e-4
36
- advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
38
+ if ref_per_token_logps is None:
39
+ if ref_log_probs is not None:
40
+ with torch.no_grad():
41
+ ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
42
+ -1
43
+ )
44
+ else:
45
+ ref_per_token_logps = per_token_logps.detach()
37
46
 
38
47
  # Compute policy gradient loss with importance sampling ratio
39
- ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
40
- policy_loss = -ratio * advantages.unsqueeze(1)
41
-
42
- # Compute KL penalty
43
- kl_div = (
44
- torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
45
- )
46
-
47
- # Combine losses
48
- per_token_loss = policy_loss + beta * kl_div
49
-
50
- # Apply masking and normalize
51
- masked_loss = per_token_loss * attention_mask
52
- seq_lengths = attention_mask.sum()
53
- seq_lengths = torch.clamp(seq_lengths, min=1.0)
54
- loss = masked_loss.sum() / seq_lengths
48
+ old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
49
+ coef_1 = torch.exp(per_token_logps - old_per_token_logps)
50
+ coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
51
+ per_token_loss1 = coef_1 * advantages.unsqueeze(1)
52
+ per_token_loss2 = coef_2 * advantages.unsqueeze(1)
53
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
54
+ if beta != 0.0:
55
+ # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
56
+ kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
57
+ # Combine losses
58
+ per_token_loss = per_token_loss + beta * kl_div
59
+
60
+ # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
61
+ # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
62
+ # and TRL GRPO implementation
63
+ # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
64
+ loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
55
65
 
56
66
  # Calculate metrics
57
- metrics = (
58
- chosen_token_logprobs.mean(), # mean log prob
59
- chosen_token_logprobs.std(), # std log prob
60
- log_probs.mean(), # mean all log probs
61
- ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(), # mean KL div
67
+ metrics = []
68
+ if beta != 0.0:
69
+ metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
70
+ is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
71
+ (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
62
72
  )
63
-
73
+ metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
64
74
  return loss, metrics
65
75
 
66
76
  @classmethod
@@ -69,16 +79,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
69
79
  ctx,
70
80
  _input,
71
81
  weight,
82
+ selected_token_ids,
72
83
  attention_mask,
73
- rewards,
84
+ advantages,
74
85
  bias=None,
86
+ ref_per_token_logps=None,
87
+ old_per_token_logps=None,
75
88
  ref_input=None,
76
89
  ref_weight=None,
77
90
  ref_bias=None,
78
- beta=0.1,
91
+ beta=0.04,
92
+ epsilon_low=0.2,
93
+ epsilon_high=0.2,
94
+ temperature=1.0,
79
95
  compiled=True,
80
96
  use_ref_model=True,
81
- num_generations=1,
82
97
  chunk_size=1,
83
98
  ):
84
99
  """
@@ -86,16 +101,18 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
86
101
  Args:
87
102
  _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
88
103
  weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
104
+ selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
89
105
  attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
90
- rewards (torch.Tensor): Rewards tensor. Shape: (batch_size,)
106
+ advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
91
107
  bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
108
+ ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
92
109
  ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
93
110
  ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
94
111
  ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
95
112
  beta (float): Weight for the KL penalty
113
+ temperature (float): Temperature for the logits
96
114
  compiled (bool): Whether to use torch compile
97
115
  use_ref_model (bool): Whether to use a reference model
98
- num_generations (int): Number of generations per prompt
99
116
  chunk_size (int): Size of chunks for processing.
100
117
  Returns:
101
118
  torch.Tensor: Computed loss
@@ -105,16 +122,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
105
122
  ctx=ctx,
106
123
  _input=_input,
107
124
  weight=weight,
125
+ selected_token_ids=selected_token_ids,
108
126
  attention_mask=attention_mask,
109
- rewards=rewards,
127
+ advantages=advantages,
110
128
  bias=bias,
129
+ ref_per_token_logps=ref_per_token_logps,
130
+ old_per_token_logps=old_per_token_logps,
111
131
  ref_input=ref_input,
112
132
  ref_weight=ref_weight,
113
133
  ref_bias=ref_bias,
114
134
  beta=beta,
135
+ epsilon_low=epsilon_low,
136
+ epsilon_high=epsilon_high,
137
+ temperature=temperature,
115
138
  compiled=compiled,
116
139
  use_ref_model=use_ref_model,
117
- num_generations=num_generations,
118
140
  chunk_size=chunk_size,
119
141
  )
120
142
 
@@ -126,16 +148,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
126
148
  grad_output: Gradient of the loss (scalar)
127
149
  grad_metrics: Gradients of the metrics (not used in backward computation)
128
150
  """
129
- grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
151
+ grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
130
152
  return (
131
- *grads[:5], # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias
153
+ *grads[
154
+ :6
155
+ ], # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
156
+ None, # grad_ref_per_token_logps
157
+ None, # grad_old_per_token_logps
132
158
  None, # grad_ref_input
133
159
  None, # grad_ref_weight
134
160
  None, # grad_ref_bias
135
161
  None, # grad_beta
162
+ None, # grad_epsilon_low
163
+ None, # grad_epsilon_high
164
+ None, # grad_temperature
136
165
  None, # grad_compiled
137
166
  None, # grad_use_ref_model
138
- None, # grad_num_generations
139
167
  None, # grad_chunk_size
140
168
  )
141
169
 
@@ -145,34 +173,43 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
145
173
 
146
174
  def __init__(
147
175
  self,
148
- beta: float = 0.1,
176
+ beta: float = 0.04,
149
177
  compiled: bool = True,
150
178
  use_ref_model: bool = True,
151
- num_generations: int = 1,
152
179
  chunk_size: int = 1,
180
+ epsilon_low: float = 0.2,
181
+ epsilon_high: float = 0.2,
182
+ temperature: float = 1.0,
153
183
  ):
154
184
  """
155
185
  Args:
156
186
  beta (float): Weight for the KL penalty.
157
187
  compiled (bool): Whether to use torch compile.
158
188
  use_ref_model (bool): Whether to use a reference model.
159
- num_generations (int): Number of generations per prompt.
160
189
  chunk_size (int): Size of chunks for processing.
190
+ epsilon_low (float): Lower bound for the importance sampling ratio.
191
+ epsilon_high (float): Upper bound for the importance sampling ratio.
192
+ temperature (float): Temperature for the logits.
161
193
  """
162
194
  super().__init__()
163
195
  self.beta = beta
164
196
  self.compiled = compiled
165
197
  self.use_ref_model = use_ref_model
166
- self.num_generations = num_generations
167
198
  self.chunk_size = chunk_size
199
+ self.epsilon_low = epsilon_low
200
+ self.epsilon_high = epsilon_high
201
+ self.temperature = temperature
168
202
 
169
203
  def forward(
170
204
  self,
171
205
  _input,
172
206
  lin_weight,
207
+ selected_token_ids,
173
208
  attention_mask,
174
- rewards,
209
+ advantages,
175
210
  bias=None,
211
+ ref_per_token_logps=None,
212
+ old_per_token_logps=None,
176
213
  ref_input=None,
177
214
  ref_weight=None,
178
215
  ref_bias=None,
@@ -180,15 +217,20 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
180
217
  return LigerFusedLinearGRPOFunction.apply(
181
218
  _input,
182
219
  lin_weight,
220
+ selected_token_ids,
183
221
  attention_mask,
184
- rewards,
222
+ advantages,
185
223
  bias,
224
+ ref_per_token_logps,
225
+ old_per_token_logps,
186
226
  ref_input,
187
227
  ref_weight,
188
228
  ref_bias,
189
229
  self.beta,
230
+ self.epsilon_low,
231
+ self.epsilon_high,
232
+ self.temperature,
190
233
  self.compiled,
191
234
  self.use_ref_model,
192
- self.num_generations,
193
235
  self.chunk_size,
194
236
  )