liger-kernel 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (35)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +103 -61
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +30 -11
  10. liger_kernel/ops/kl_div.py +2 -2
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/dyt.py +20 -0
  13. liger_kernel/transformers/functional.py +5 -0
  14. liger_kernel/transformers/model/gemma.py +8 -16
  15. liger_kernel/transformers/model/gemma2.py +7 -16
  16. liger_kernel/transformers/model/llama.py +8 -15
  17. liger_kernel/transformers/model/llava.py +369 -0
  18. liger_kernel/transformers/model/loss_utils.py +57 -0
  19. liger_kernel/transformers/model/mistral.py +9 -10
  20. liger_kernel/transformers/model/mixtral.py +8 -15
  21. liger_kernel/transformers/model/mllama.py +8 -15
  22. liger_kernel/transformers/model/olmo2.py +8 -16
  23. liger_kernel/transformers/model/paligemma.py +397 -0
  24. liger_kernel/transformers/model/phi3.py +8 -15
  25. liger_kernel/transformers/model/qwen2.py +8 -15
  26. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  27. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  28. liger_kernel/transformers/monkey_patch.py +219 -13
  29. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +9 -6
  30. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/RECORD +34 -29
  31. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
  32. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_rlhf.py (removed)
@@ -1,240 +0,0 @@
- from abc import abstractmethod
- from functools import partial
-
- import torch
- import torch.nn.functional as F
-
-
- class LigerFusedLinearRLHFBase(torch.autograd.Function):
-     @abstractmethod
-     def rlhf_loss_fn(*args, **kwargs):
-         """
-         To be extended by subclasses.
-         """
-         raise NotImplementedError("RLHF loss function must be implemented.")
-
-     @staticmethod
-     def forward(
-         cls,
-         ctx,
-         _input,
-         weight,
-         attention_mask,
-         rewards,
-         bias=None,
-         num_generations=4,
-         beta=0.1,
-         compiled=True,
-         use_ref_model=False,
-         ref_input=None,
-         ref_weight=None,
-         ref_bias=None,
-         chunk_size=1,
-     ):
-         """Chunked forward pass for RLHF loss computation.
-
-         Args:
-             cls: The class
-             ctx: Context for backward
-             _input: Input tensor
-             weight: Weight tensor
-             attention_mask: Attention mask tensor
-             rewards: Rewards tensor
-             bias: Bias tensor
-             num_generations: Number of generations per prompt
-             beta: Weight for the KL penalty
-             compiled: Whether to use torch compile
-             use_ref_model: Whether to use a reference model
-             ref_input: Reference model input tensor
-             ref_weight: Reference model weight tensor
-             ref_bias: Reference model bias tensor
-             chunk_size: Size of chunks for processing in other loss modules
-         """
-         # Save for backward
-         ctx.beta = beta
-         ctx.rewards = rewards
-
-         # Initialize accumulators
-         loss_acc = torch.zeros((), device=_input.device)
-         grad_weight = torch.zeros_like(weight)  # [V, H]
-         grad_inputs = []
-         grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
-         aggregated_metrics = []
-
-         # Create a partial function with fixed arguments
-         compute_loss = partial(
-             LigerFusedLinearRLHFBase._compute_chunk_loss,
-             beta=beta,
-             use_ref_model=use_ref_model,
-             ref_weight=ref_weight,
-             ref_bias=ref_bias,
-             rlhf_loss_fn=cls.rlhf_loss_fn,
-         )
-
-         def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
-             """Fused forward and backward for a chunk."""
-             if bias is not None:
-                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
-                     input_chunk,  # arg 0
-                     weight,  # arg 1
-                     attention_mask_chunk,  # arg 2
-                     rewards_chunk,  # arg 3
-                     ref_input_chunk,  # arg 4
-                     bias,  # arg 5
-                 )
-             else:
-                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
-                     input_chunk,  # arg 0
-                     weight,  # arg 1
-                     attention_mask_chunk,  # arg 2
-                     rewards_chunk,  # arg 3
-                     ref_input_chunk,  # arg 4
-                 )
-
-         def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
-             if bias is not None:
-                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
-                     input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
-                 )
-                 grad_bias.add_(chunk_grad_bias)
-             else:
-                 (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
-                     input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
-                 )
-
-             # Accumulate gradients and loss
-             grad_weight.add_(chunk_grad_weight)
-             grad_inputs.append(chunk_grad_input)
-             loss_acc.add_(chunk_loss)
-
-             # Initialize storage for metrics on first chunk
-             if len(aggregated_metrics) == 0:
-                 for metric in chunk_metrics:
-                     if metric.ndim == 0:
-                         aggregated_metrics.append(torch.zeros((), device=metric.device))
-                     else:
-                         aggregated_metrics.append([])
-
-             # Accumulate metrics
-             for i, metric in enumerate(chunk_metrics):
-                 if metric.ndim == 0:
-                     aggregated_metrics[i].add_(metric)
-                 else:
-                     aggregated_metrics[i].append(metric)
-
-         if compiled:
-             accumulate_chunk = torch.compile(accumulate_chunk)
-
-         # Process input in chunks based on num_generations
-         chunks = max(1, _input.shape[0] // num_generations)
-         _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
-         _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
-         _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
-         _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
-
-         for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
-             _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
-         ):
-             # Mark dynamic dimensions
-             torch._dynamo.mark_dynamic(input_chunk, 1)
-             torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
-             if ref_input_chunk is not None:
-                 torch._dynamo.mark_dynamic(ref_input_chunk, 1)
-
-             accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)
-
-         # Scale accumulated loss by number of chunks since we're averaging
-         loss_acc = loss_acc / chunks
-
-         # Combine gradients
-         grad_input = torch.cat(grad_inputs, dim=0)
-
-         # Save for backward
-         ctx.save_for_backward(grad_input, grad_weight, grad_bias)
-
-         # Finalize metrics
-         final_metrics = []
-         for metric in aggregated_metrics:
-             if isinstance(metric, list):
-                 final_metrics.append(torch.cat(metric, dim=0))
-             else:
-                 final_metrics.append(metric / chunks)
-
-         return loss_acc, tuple(final_metrics)
-
-     @staticmethod
-     def _compute_chunk_loss(
-         input_chunk,
-         weight,
-         attention_mask_chunk,
-         rewards_chunk,
-         ref_input_chunk=None,
-         bias=None,
-         beta=0.1,
-         use_ref_model=False,
-         ref_weight=None,
-         ref_bias=None,
-         rlhf_loss_fn=None,
-     ):
-         """Compute loss for a single chunk."""
-         # Get policy log probabilities using chunk_forward
-         log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)
-
-         # Get reference log probabilities if needed
-         ref_log_probs = None
-         if use_ref_model and ref_input_chunk is not None:
-             with torch.no_grad():
-                 ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)
-
-         # Compute chunk loss and metrics using the provided loss function
-         chunk_loss, chunk_metrics = rlhf_loss_fn(
-             log_probs=log_probs,
-             attention_mask=attention_mask_chunk,
-             rewards=rewards_chunk,
-             ref_log_probs=ref_log_probs,
-             beta=beta,
-         )
-
-         return chunk_loss, (logits_mean, *chunk_metrics)
-
-     @staticmethod
-     def chunk_forward(input_chunk, weight, bias=None):
-         """Forward pass computation for a single chunk without explicit reshaping."""
-         # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
-         logits = torch.matmul(input_chunk, weight.t())
-         if bias is not None:
-             logits = logits + bias  # Broadcasts bias to [B, T, V]
-
-         # Compute log probabilities using softmax over the last dimension
-         log_probs = F.log_softmax(logits.float(), dim=-1)
-
-         # Monitoring: compute mean of logits
-         batch_size, seq_len, _ = input_chunk.shape
-         logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
-         return log_probs, logits, logits_mean
-
-     @staticmethod
-     def backward(ctx, grad_output, *grad_metrics):
-         """Backward pass for RLHF loss."""
-         grad_input, grad_weight, grad_bias = ctx.saved_tensors
-         if grad_output != 1.0:
-             grad_input = grad_input * grad_output
-             grad_weight = grad_weight * grad_output
-             if grad_bias is not None:
-                 grad_bias = grad_bias * grad_output
-
-         return (
-             grad_input,
-             grad_weight,
-             None,  # grad_attention_mask
-             None,  # grad_rewards
-             grad_bias,
-             None,  # grad_num_generations
-             None,  # grad_beta
-             None,  # grad_compiled
-             None,  # grad_use_ref_model
-             None,  # grad_ref_input
-             None,  # grad_ref_weight
-             None,  # grad_ref_bias
-             None,  # grad_chunk_size
-         )
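
For context on the removal: LigerFusedLinearRLHFBase delegated the actual objective to an rlhf_loss_fn hook, handing it the chunked policy log-probabilities, the attention mask, the rewards, the optional reference log-probabilities, and beta, and expecting a scalar loss plus a tuple of metric tensors in return. The sketch below illustrates only that contract; the subclass name and the placeholder reward/KL objective are hypothetical and do not correspond to any class shipped in 0.5.5 or 0.5.6.

# Hypothetical sketch, not part of the package: illustrates the rlhf_loss_fn
# interface the removed base class expected from its subclasses.
import torch

from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase  # 0.5.5 import path


class ToyRLHFLoss(LigerFusedLinearRLHFBase):
    @staticmethod
    def rlhf_loss_fn(log_probs=None, attention_mask=None, rewards=None, ref_log_probs=None, beta=0.1):
        # log_probs: [B, T, V] policy log-probabilities from chunk_forward
        # attention_mask: [B, T]; rewards: [B]
        mask = attention_mask.unsqueeze(-1).to(log_probs.dtype)  # [B, T, 1]
        probs = log_probs.exp()

        # Placeholder policy term: reward-weighted, mask-averaged E_p[log p] per sequence.
        seq_score = (probs * log_probs * mask).sum(dim=(1, 2)) / attention_mask.sum(dim=1).clamp(min=1)
        loss = -(rewards * seq_score).mean()

        # Optional KL(policy || reference) penalty weighted by beta, per the base class docstring.
        kl_mean = torch.zeros((), device=log_probs.device)
        if ref_log_probs is not None:
            kl = (probs * (log_probs - ref_log_probs)).sum(dim=-1)  # [B, T]
            kl_mean = (kl * attention_mask.to(kl.dtype)).sum() / attention_mask.sum().clamp(min=1)
            loss = loss + beta * kl_mean

        # The base class prepends logits_mean and averages scalar metrics across chunks.
        return loss, (kl_mean.detach(),)

In 0.5.6 the same chunked forward/backward pattern appears to move to the new fused_linear_ppo.py, which the reworked grpo_loss.py builds on in place of this base class.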