liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +134 -60
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +32 -12
  10. liger_kernel/ops/kl_div.py +15 -8
  11. liger_kernel/ops/layer_norm.py +14 -1
  12. liger_kernel/ops/rms_norm.py +12 -1
  13. liger_kernel/transformers/__init__.py +133 -15
  14. liger_kernel/transformers/dyt.py +20 -0
  15. liger_kernel/transformers/functional.py +5 -0
  16. liger_kernel/transformers/gema3_rms.py +8 -0
  17. liger_kernel/transformers/model/gemma.py +17 -20
  18. liger_kernel/transformers/model/gemma2.py +17 -21
  19. liger_kernel/transformers/model/gemma3.py +335 -0
  20. liger_kernel/transformers/model/llama.py +17 -19
  21. liger_kernel/transformers/model/llava.py +369 -0
  22. liger_kernel/transformers/model/loss_utils.py +64 -0
  23. liger_kernel/transformers/model/mistral.py +28 -25
  24. liger_kernel/transformers/model/mixtral.py +20 -26
  25. liger_kernel/transformers/model/mllama.py +17 -19
  26. liger_kernel/transformers/model/olmo2.py +17 -20
  27. liger_kernel/transformers/model/paligemma.py +397 -0
  28. liger_kernel/transformers/model/phi3.py +17 -19
  29. liger_kernel/transformers/model/qwen2.py +17 -19
  30. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  31. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  32. liger_kernel/transformers/monkey_patch.py +392 -13
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
  36. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  37. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
  38. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
  39. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/functional.py
@@ -1,5 +1,6 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
 from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
+liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
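The two hunks above are the whole of the functional.py change: GRPO is registered next to the existing chunked-loss entry points. A minimal sketch of what that exposes, using only names visible in this diff:

from liger_kernel.chunked_loss.functional import liger_fused_linear_grpo
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction

# The new alias is simply the autograd Function's `apply`, mirroring the existing
# cpo / dpo / jsd / kto / orpo / simpo aliases defined in the same module.
assert liger_fused_linear_grpo is LigerFusedLinearGRPOFunction.apply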
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -115,9 +115,24 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         student_logits_chunk /= temperature
         teacher_logits_chunk /= temperature
 
+        # If the teacher and student vocab sizes differ, pad the student logits to match the teacher's.
+        # This only applies to cases where they share exactly the same vocab and tokenizer, and the
+        # teacher logits are merely padded for training efficiency, e.g.
+        # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
+        teacher_vocab_size = teacher_weight.shape[0]
+        student_vocab_size = student_weight.shape[0]
+        if teacher_vocab_size > student_vocab_size:
+            pad_size = teacher_vocab_size - student_vocab_size
+            pad_tensor = torch.zeros(
+                (*student_logits_chunk.shape[:-1], pad_size),
+                dtype=student_logits_chunk.dtype,
+                device=student_logits_chunk.device,
+            )
+            student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
+
         hard_loss /= full_target.shape[0]
 
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
         soft_loss /= full_target.shape[0]
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
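To make the new padding branch concrete, here is a standalone sketch of the same logic (the helper name and toy shapes below are illustrative; in the wheel the code runs inline on each logits chunk):

import torch


def pad_student_logits(student_logits: torch.Tensor, teacher_vocab_size: int) -> torch.Tensor:
    """Zero-pad the vocab dimension of the student logits up to the teacher's vocab size."""
    student_vocab_size = student_logits.shape[-1]
    if teacher_vocab_size <= student_vocab_size:
        return student_logits
    pad_size = teacher_vocab_size - student_vocab_size
    pad_tensor = torch.zeros(
        (*student_logits.shape[:-1], pad_size),
        dtype=student_logits.dtype,
        device=student_logits.device,
    )
    return torch.cat([student_logits, pad_tensor], dim=-1)


# Toy example: a student head with 128 logits distilled against a teacher head padded to 160.
student_logits = torch.randn(2, 4, 128)
assert pad_student_logits(student_logits, teacher_vocab_size=160).shape == (2, 4, 160)

The zero padding only aligns the vocab dimension so the soft loss compares distributions over the same index space, under the assumption spelled out in the comment (same vocab and tokenizer, teacher head merely padded).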
@@ -180,9 +195,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
-            beta=beta,
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
+            beta=beta,
             **loss_kwargs,
         )
 
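Together, these two distillation hunks thread the extra loss hyperparameters through to distillation_loss_fn: beta is moved after temperature in the keyword-argument list, and the soft-loss call now forwards **loss_kwargs. A hedged sketch of a loss function that relies on this, assuming a generalized-JSD-style beta (this is not the implementation shipped in jsd_loss.py):

import torch
import torch.nn.functional as F


def jsd_like_loss(student_logits_chunk, teacher_logits_chunk, beta=0.5):
    # Hypothetical generalized JSD: M = beta * P_teacher + (1 - beta) * P_student,
    # loss = beta * KL(P_teacher || M) + (1 - beta) * KL(P_student || M).
    p_student = F.softmax(student_logits_chunk, dim=-1)
    p_teacher = F.softmax(teacher_logits_chunk, dim=-1)
    m = beta * p_teacher + (1 - beta) * p_student
    log_m = m.clamp_min(1e-12).log()
    kl_teacher = (p_teacher * (p_teacher.clamp_min(1e-12).log() - log_m)).sum(dim=-1)
    kl_student = (p_student * (p_student.clamp_min(1e-12).log() - log_m)).sum(dim=-1)
    return (beta * kl_teacher + (1 - beta) * kl_student).mean()

Because the base class now calls distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs), a caller-supplied beta reaches a function like this without any further change to the chunking code.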
liger_kernel/chunked_loss/fused_linear_ppo.py (new file)
@@ -0,0 +1,346 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+import torch._dynamo.config
+import torch.nn.functional as F
+
+
+class LigerFusedLinearPPOBase(torch.autograd.Function):
+    @abstractmethod
+    def ppo_loss_fn(*args, **kwargs):
+        """
+        To be extended by subclasses.
+        """
+        raise NotImplementedError("PPO loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        selected_token_ids,
+        attention_mask,
+        advantages,
+        bias=None,
+        ref_per_token_logps=None,
+        old_per_token_logps=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
+        loss_type="bnpo",
+        max_completion_length=None,
+        temperature=1.0,
+        compiled=True,
+        use_ref_model=False,
+        chunk_size=1,
+    ):
+        # TODO: check torch compile matmul
+        """Chunked forward pass for PPO loss computation.
+
+        Args:
+            cls: The class
+            ctx: Context for backward
+            _input: Input tensor
+            weight: Weight tensor
+            selected_token_ids: Selected token ids tensor
+            attention_mask: Attention mask tensor
+            advantages: Advantages tensor
+            bias: Bias tensor
+            ref_per_token_logps: Reference model log probs per token tensor
+            old_per_token_logps: Old per token log probabilities tensor
+            ref_input: Reference model input tensor
+            ref_weight: Reference model weight tensor
+            ref_bias: Reference model bias tensor
+            epsilon_low: Lower bound for clipping the importance sampling ratio
+            epsilon_high: Upper bound for clipping the importance sampling ratio
+            beta: Weight for the KL penalty
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+            max_completion_length: Maximum completion length required for "dr_grpo"
+            temperature: Temperature for the logits
+            compiled: Whether to use torch compile
+            use_ref_model: Whether to use a reference model
+            chunk_size: Size of chunks for processing in other loss modules
+        """
+        if use_ref_model:
+            assert ref_per_token_logps is not None or ref_input is not None, (
+                "If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
+            )
+            if ref_per_token_logps is not None and ref_input is not None:
+                raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+        if loss_type == "dr_grpo":
+            assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
+        # Initialize accumulators
+        loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
+        grad_weight = torch.zeros_like(weight)  # [V, H]
+        grad_inputs = []
+        grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
+        aggregated_metrics = []
+
+        # Create a partial function with fixed arguments
+        compute_loss = partial(
+            LigerFusedLinearPPOBase._compute_chunk_loss,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            full_attention_mask=attention_mask,
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            beta=beta,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
+            temperature=temperature,
+            use_ref_model=use_ref_model,
+            ppo_loss_fn=cls.ppo_loss_fn,
+        )
+
+        def fused_fwd_bwd(
+            input_chunk,
+            selected_token_ids_chunk,
+            attention_mask_chunk,
+            advantages_chunk,
+            ref_per_token_logps_chunk,
+            old_per_token_logps_chunk,
+            ref_input_chunk,
+        ):
+            """Fused forward and backward for a chunk."""
+            argnums = (0, 1, 5) if bias is not None else (0, 1)
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
+                input_chunk,  # arg 0
+                weight,  # arg 1
+                selected_token_ids_chunk,  # arg 2
+                attention_mask_chunk,  # arg 3
+                advantages_chunk,  # arg 4
+                bias,  # arg 5
+                ref_per_token_logps_chunk=ref_per_token_logps_chunk,  # arg 6
+                old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 7
+                ref_input_chunk=ref_input_chunk,  # arg 8
+            )
+
+        def accumulate_chunk(
+            input_chunk,
+            selected_token_ids_chunk,
+            attention_mask_chunk,
+            advantages_chunk,
+            ref_per_token_logps_chunk=None,
+            old_per_token_logps_chunk=None,
+            ref_input_chunk=None,
+        ):
+            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
+                input_chunk,
+                selected_token_ids_chunk,
+                attention_mask_chunk,
+                advantages_chunk,
+                ref_per_token_logps_chunk,
+                old_per_token_logps_chunk,
+                ref_input_chunk,
+            )
+            if bias is not None:
+                grad_bias.add_(chunk_grad_bias[0])
+
+            # Accumulate gradients and loss
+            grad_weight.add_(chunk_grad_weight)
+            grad_inputs.append(chunk_grad_input)
+            loss_acc.add_(chunk_loss)
+            # Initialize storage for metrics on first chunk
+            if len(aggregated_metrics) == 0:
+                for metric in chunk_metrics:
+                    if metric.ndim == 0:
+                        aggregated_metrics.append(torch.zeros((), device=metric.device))
+                    else:
+                        aggregated_metrics.append([])
+
+            # Accumulate metrics
+            for i, metric in enumerate(chunk_metrics):
+                if metric.ndim == 0:
+                    aggregated_metrics[i].add_(metric)
+                else:
+                    aggregated_metrics[i].append(metric)
+
+        if compiled:
+            # TODO: Figure out what is better to compile here
+            # accumulate_chunk = torch.compile(accumulate_chunk)
+            fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+        # Process input in chunks based on chunk_size
+        chunks = max(1, _input.shape[0] // chunk_size)
+        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+        _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
+        _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
+        _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
+        _ref_per_token_logps_chunks = (
+            torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
+            if use_ref_model and ref_per_token_logps is not None
+            else [None] * chunks
+        )
+        _old_per_token_logps_chunks = (
+            torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
+            if old_per_token_logps is not None
+            else [None] * chunks
+        )
+        # if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
+        _ref_input_chunks = (
+            torch.chunk(ref_input, chunks=chunks, dim=0)
+            if use_ref_model and ref_per_token_logps is None
+            else [None] * chunks
+        )
+
+        for (
+            input_chunk,
+            selected_token_ids_chunk,
+            attention_mask_chunk,
+            advantages_chunk,
+            ref_per_token_logps_chunk,
+            old_per_token_logps_chunk,
+            ref_input_chunk,
+        ) in zip(
+            _input_chunks,
+            _selected_token_ids_chunks,
+            _attention_mask_chunks,
+            _advantages_chunks,
+            _ref_per_token_logps_chunks,
+            _old_per_token_logps_chunks,
+            _ref_input_chunks,
+        ):
+            # Mark dynamic dimensions
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
+            torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
+            if ref_per_token_logps_chunk is not None:
+                torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
+            if ref_input_chunk is not None:
+                torch._dynamo.mark_dynamic(ref_input_chunk, 1)
+            if old_per_token_logps_chunk is not None:
+                torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
+
+            accumulate_chunk(
+                input_chunk,
+                selected_token_ids_chunk,
+                attention_mask_chunk,
+                advantages_chunk,
+                ref_per_token_logps_chunk,
+                old_per_token_logps_chunk,
+                ref_input_chunk,
+            )
+
+        # Combine gradients
+        grad_input = torch.cat(grad_inputs, dim=0)
+
+        # Save for backward
+        ctx.save_for_backward(grad_input, grad_weight, grad_bias)
+
+        # Finalize metrics
+        final_metrics = []
+        for metric in aggregated_metrics:
+            if isinstance(metric, list):
+                final_metrics.append(torch.cat(metric, dim=0))
+            else:
+                final_metrics.append(metric)
+
+        return loss_acc, tuple(final_metrics)
+
+    @staticmethod
+    def _compute_chunk_loss(
+        input_chunk,
+        weight,
+        selected_token_ids_chunk,
+        attention_mask_chunk,
+        advantages_chunk,
+        bias=None,
+        ref_per_token_logps_chunk=None,
+        old_per_token_logps_chunk=None,
+        ref_input_chunk=None,
+        ref_weight=None,
+        ref_bias=None,
+        full_attention_mask=None,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        beta=0.04,
+        loss_type="bnpo",
+        max_completion_length=None,
+        temperature=1.0,
+        use_ref_model=False,
+        ppo_loss_fn=None,
+    ):
+        """Compute loss for a single chunk."""
+        # Get policy log probabilities using chunk_forward
+        log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
+
+        # Get reference log probabilities if needed
+        ref_log_probs = None
+        if use_ref_model and ref_per_token_logps_chunk is None:
+            with torch.no_grad():
+                ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
+                    ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
+                )
+
+        # Compute chunk loss and metrics using the provided loss function
+        chunk_loss, chunk_metrics = ppo_loss_fn(
+            log_probs=log_probs,
+            selected_token_ids=selected_token_ids_chunk,
+            attention_mask=attention_mask_chunk,
+            advantages=advantages_chunk,
+            full_attention_mask=full_attention_mask,
+            ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
+            old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
+            ref_log_probs=ref_log_probs,  # used when ref_per_token_logps is None
+            epsilon_low=epsilon_low,
+            epsilon_high=epsilon_high,
+            beta=beta,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
+        )
+
+        return chunk_loss, chunk_metrics
+
+    @staticmethod
+    def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
+        """Forward pass computation for a single chunk without explicit reshaping."""
+        # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
+        logits = torch.matmul(input_chunk, weight.t())
+        if bias is not None:
+            logits = logits + bias  # Broadcasts bias to [B, T, V]
+        if temperature != 1.0:
+            logits = logits / temperature
+
+        # Compute log probabilities using softmax over the last dimension
+        log_probs = F.log_softmax(logits.float(), dim=-1)
+
+        return log_probs, logits
+
+    @staticmethod
+    def backward(ctx, grad_output, *grad_metrics):
+        """Backward pass for PPO loss."""
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
+        if grad_output != 1.0:
+            grad_input = grad_input * grad_output
+            grad_weight = grad_weight * grad_output
+            if grad_bias is not None:
+                grad_bias = grad_bias * grad_output
+
+        return (
+            grad_input,
+            grad_weight,
+            None,  # grad_selected_token_ids
+            None,  # grad_attention_mask
+            None,  # grad_advantages
+            grad_bias,
+            None,  # grad_ref_per_token_logps
+            None,  # grad_old_per_token_logps
+            None,  # grad_ref_input
+            None,  # grad_ref_weight
+            None,  # grad_ref_bias
+            None,  # grad_epsilon_low
+            None,  # grad_epsilon_high
+            None,  # grad_beta
+            None,  # grad_temperature
+            None,  # grad_compiled
+            None,  # grad_use_ref_model
+            None,  # grad_chunk_size
+            None,  # grad_loss_type
+            None,  # grad_max_completion_length
+        )
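The contract a subclass must satisfy is visible in _compute_chunk_loss above: ppo_loss_fn receives the chunked log probs, the selected token ids, masks, advantages, and the clipping/KL hyperparameters, and must return a scalar chunk loss plus a tuple of metric tensors (0-dim metrics are summed across chunks, higher-dim ones concatenated along dim 0). Below is a hedged toy implementation of that interface using a plain clipped-PPO surrogate; it is illustrative only, not the GRPO loss shipped in grpo_loss.py:

import torch


def toy_ppo_loss_fn(
    log_probs,               # [B, T, V] log-softmax output from chunk_forward
    selected_token_ids,      # [B, T] sampled token ids
    attention_mask,          # [B, T] 1 for completion tokens, 0 for padding
    advantages,              # [B] per-sequence advantages
    full_attention_mask=None,
    ref_per_token_logps=None,
    old_per_token_logps=None,
    ref_log_probs=None,
    epsilon_low=0.2,
    epsilon_high=0.2,
    beta=0.04,
    loss_type="bnpo",
    max_completion_length=None,
):
    # Per-token log prob of the selected tokens: [B, T]
    per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)
    # Importance ratio against the old policy (fall back to a detached copy, i.e. ratio == 1)
    old = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
    ratio = (per_token_logps - old).exp()
    clipped_ratio = torch.clamp(ratio, 1.0 - epsilon_low, 1.0 + epsilon_high)
    adv = advantages.unsqueeze(1)  # [B, 1], broadcast over tokens
    per_token_loss = -torch.min(ratio * adv, clipped_ratio * adv)
    mask = attention_mask.to(per_token_loss.dtype)
    loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
    # Metrics follow the accumulation convention in forward(): 0-dim tensors are
    # summed across chunks, higher-dim tensors are concatenated along dim 0.
    metrics = (ratio.detach().mean(),)
    return loss, metrics

A concrete loss (the GRPO function in grpo_loss.py, for instance) would supply its own ppo_loss_fn and let this base class handle chunking, the optional torch.compile of the fused forward/backward, and gradient accumulation.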