liger-kernel 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
- liger_kernel/chunked_loss/grpo_loss.py +103 -61
- liger_kernel/chunked_loss/jsd_loss.py +12 -7
- liger_kernel/ops/cross_entropy.py +3 -2
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/ops/fused_linear_jsd.py +2 -1
- liger_kernel/ops/jsd.py +30 -11
- liger_kernel/ops/kl_div.py +2 -2
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/model/gemma.py +8 -16
- liger_kernel/transformers/model/gemma2.py +7 -16
- liger_kernel/transformers/model/llama.py +8 -15
- liger_kernel/transformers/model/llava.py +369 -0
- liger_kernel/transformers/model/loss_utils.py +57 -0
- liger_kernel/transformers/model/mistral.py +9 -10
- liger_kernel/transformers/model/mixtral.py +8 -15
- liger_kernel/transformers/model/mllama.py +8 -15
- liger_kernel/transformers/model/olmo2.py +8 -16
- liger_kernel/transformers/model/paligemma.py +397 -0
- liger_kernel/transformers/model/phi3.py +8 -15
- liger_kernel/transformers/model/qwen2.py +8 -15
- liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
- liger_kernel/transformers/model/qwen2_vl.py +9 -10
- liger_kernel/transformers/monkey_patch.py +219 -13
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +9 -6
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/RECORD +34 -29
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
|
@@ -1,240 +0,0 @@
|
|
|
1
|
-
from abc import abstractmethod
|
|
2
|
-
from functools import partial
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
import torch.nn.functional as F
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class LigerFusedLinearRLHFBase(torch.autograd.Function):
|
|
9
|
-
@abstractmethod
|
|
10
|
-
def rlhf_loss_fn(*args, **kwargs):
|
|
11
|
-
"""
|
|
12
|
-
To be extended by subclasses.
|
|
13
|
-
"""
|
|
14
|
-
raise NotImplementedError("RLHF loss function must be implemented.")
|
|
15
|
-
|
|
16
|
-
@staticmethod
|
|
17
|
-
def forward(
|
|
18
|
-
cls,
|
|
19
|
-
ctx,
|
|
20
|
-
_input,
|
|
21
|
-
weight,
|
|
22
|
-
attention_mask,
|
|
23
|
-
rewards,
|
|
24
|
-
bias=None,
|
|
25
|
-
num_generations=4,
|
|
26
|
-
beta=0.1,
|
|
27
|
-
compiled=True,
|
|
28
|
-
use_ref_model=False,
|
|
29
|
-
ref_input=None,
|
|
30
|
-
ref_weight=None,
|
|
31
|
-
ref_bias=None,
|
|
32
|
-
chunk_size=1,
|
|
33
|
-
):
|
|
34
|
-
"""Chunked forward pass for RLHF loss computation.
|
|
35
|
-
|
|
36
|
-
Args:
|
|
37
|
-
cls: The class
|
|
38
|
-
ctx: Context for backward
|
|
39
|
-
_input: Input tensor
|
|
40
|
-
weight: Weight tensor
|
|
41
|
-
attention_mask: Attention mask tensor
|
|
42
|
-
rewards: Rewards tensor
|
|
43
|
-
bias: Bias tensor
|
|
44
|
-
num_generations: Number of generations per prompt
|
|
45
|
-
beta: Weight for the KL penalty
|
|
46
|
-
compiled: Whether to use torch compile
|
|
47
|
-
use_ref_model: Whether to use a reference model
|
|
48
|
-
ref_input: Reference model input tensor
|
|
49
|
-
ref_weight: Reference model weight tensor
|
|
50
|
-
ref_bias: Reference model bias tensor
|
|
51
|
-
chunk_size: Size of chunks for processing in other loss modules
|
|
52
|
-
"""
|
|
53
|
-
# Save for backward
|
|
54
|
-
ctx.beta = beta
|
|
55
|
-
ctx.rewards = rewards
|
|
56
|
-
|
|
57
|
-
# Initialize accumulators
|
|
58
|
-
loss_acc = torch.zeros((), device=_input.device)
|
|
59
|
-
grad_weight = torch.zeros_like(weight) # [V, H]
|
|
60
|
-
grad_inputs = []
|
|
61
|
-
grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
|
|
62
|
-
aggregated_metrics = []
|
|
63
|
-
|
|
64
|
-
# Create a partial function with fixed arguments
|
|
65
|
-
compute_loss = partial(
|
|
66
|
-
LigerFusedLinearRLHFBase._compute_chunk_loss,
|
|
67
|
-
beta=beta,
|
|
68
|
-
use_ref_model=use_ref_model,
|
|
69
|
-
ref_weight=ref_weight,
|
|
70
|
-
ref_bias=ref_bias,
|
|
71
|
-
rlhf_loss_fn=cls.rlhf_loss_fn,
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
|
|
75
|
-
"""Fused forward and backward for a chunk."""
|
|
76
|
-
if bias is not None:
|
|
77
|
-
return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
|
|
78
|
-
input_chunk, # arg 0
|
|
79
|
-
weight, # arg 1
|
|
80
|
-
attention_mask_chunk, # arg 2
|
|
81
|
-
rewards_chunk, # arg 3
|
|
82
|
-
ref_input_chunk, # arg 4
|
|
83
|
-
bias, # arg 5
|
|
84
|
-
)
|
|
85
|
-
else:
|
|
86
|
-
return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
|
|
87
|
-
input_chunk, # arg 0
|
|
88
|
-
weight, # arg 1
|
|
89
|
-
attention_mask_chunk, # arg 2
|
|
90
|
-
rewards_chunk, # arg 3
|
|
91
|
-
ref_input_chunk, # arg 4
|
|
92
|
-
)
|
|
93
|
-
|
|
94
|
-
def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
|
|
95
|
-
if bias is not None:
|
|
96
|
-
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
|
|
97
|
-
input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
|
|
98
|
-
)
|
|
99
|
-
grad_bias.add_(chunk_grad_bias)
|
|
100
|
-
else:
|
|
101
|
-
(chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
|
|
102
|
-
input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
|
|
103
|
-
)
|
|
104
|
-
|
|
105
|
-
# Accumulate gradients and loss
|
|
106
|
-
grad_weight.add_(chunk_grad_weight)
|
|
107
|
-
grad_inputs.append(chunk_grad_input)
|
|
108
|
-
loss_acc.add_(chunk_loss)
|
|
109
|
-
|
|
110
|
-
# Initialize storage for metrics on first chunk
|
|
111
|
-
if len(aggregated_metrics) == 0:
|
|
112
|
-
for metric in chunk_metrics:
|
|
113
|
-
if metric.ndim == 0:
|
|
114
|
-
aggregated_metrics.append(torch.zeros((), device=metric.device))
|
|
115
|
-
else:
|
|
116
|
-
aggregated_metrics.append([])
|
|
117
|
-
|
|
118
|
-
# Accumulate metrics
|
|
119
|
-
for i, metric in enumerate(chunk_metrics):
|
|
120
|
-
if metric.ndim == 0:
|
|
121
|
-
aggregated_metrics[i].add_(metric)
|
|
122
|
-
else:
|
|
123
|
-
aggregated_metrics[i].append(metric)
|
|
124
|
-
|
|
125
|
-
if compiled:
|
|
126
|
-
accumulate_chunk = torch.compile(accumulate_chunk)
|
|
127
|
-
|
|
128
|
-
# Process input in chunks based on num_generations
|
|
129
|
-
chunks = max(1, _input.shape[0] // num_generations)
|
|
130
|
-
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
|
|
131
|
-
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
|
|
132
|
-
_rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
|
|
133
|
-
_ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
|
|
134
|
-
|
|
135
|
-
for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
|
|
136
|
-
_input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
|
|
137
|
-
):
|
|
138
|
-
# Mark dynamic dimensions
|
|
139
|
-
torch._dynamo.mark_dynamic(input_chunk, 1)
|
|
140
|
-
torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
|
|
141
|
-
if ref_input_chunk is not None:
|
|
142
|
-
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
|
|
143
|
-
|
|
144
|
-
accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)
|
|
145
|
-
|
|
146
|
-
# Scale accumulated loss by number of chunks since we're averaging
|
|
147
|
-
loss_acc = loss_acc / chunks
|
|
148
|
-
|
|
149
|
-
# Combine gradients
|
|
150
|
-
grad_input = torch.cat(grad_inputs, dim=0)
|
|
151
|
-
|
|
152
|
-
# Save for backward
|
|
153
|
-
ctx.save_for_backward(grad_input, grad_weight, grad_bias)
|
|
154
|
-
|
|
155
|
-
# Finalize metrics
|
|
156
|
-
final_metrics = []
|
|
157
|
-
for metric in aggregated_metrics:
|
|
158
|
-
if isinstance(metric, list):
|
|
159
|
-
final_metrics.append(torch.cat(metric, dim=0))
|
|
160
|
-
else:
|
|
161
|
-
final_metrics.append(metric / chunks)
|
|
162
|
-
|
|
163
|
-
return loss_acc, tuple(final_metrics)
|
|
164
|
-
|
|
165
|
-
@staticmethod
|
|
166
|
-
def _compute_chunk_loss(
|
|
167
|
-
input_chunk,
|
|
168
|
-
weight,
|
|
169
|
-
attention_mask_chunk,
|
|
170
|
-
rewards_chunk,
|
|
171
|
-
ref_input_chunk=None,
|
|
172
|
-
bias=None,
|
|
173
|
-
beta=0.1,
|
|
174
|
-
use_ref_model=False,
|
|
175
|
-
ref_weight=None,
|
|
176
|
-
ref_bias=None,
|
|
177
|
-
rlhf_loss_fn=None,
|
|
178
|
-
):
|
|
179
|
-
"""Compute loss for a single chunk."""
|
|
180
|
-
# Get policy log probabilities using chunk_forward
|
|
181
|
-
log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)
|
|
182
|
-
|
|
183
|
-
# Get reference log probabilities if needed
|
|
184
|
-
ref_log_probs = None
|
|
185
|
-
if use_ref_model and ref_input_chunk is not None:
|
|
186
|
-
with torch.no_grad():
|
|
187
|
-
ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)
|
|
188
|
-
|
|
189
|
-
# Compute chunk loss and metrics using the provided loss function
|
|
190
|
-
chunk_loss, chunk_metrics = rlhf_loss_fn(
|
|
191
|
-
log_probs=log_probs,
|
|
192
|
-
attention_mask=attention_mask_chunk,
|
|
193
|
-
rewards=rewards_chunk,
|
|
194
|
-
ref_log_probs=ref_log_probs,
|
|
195
|
-
beta=beta,
|
|
196
|
-
)
|
|
197
|
-
|
|
198
|
-
return chunk_loss, (logits_mean, *chunk_metrics)
|
|
199
|
-
|
|
200
|
-
@staticmethod
|
|
201
|
-
def chunk_forward(input_chunk, weight, bias=None):
|
|
202
|
-
"""Forward pass computation for a single chunk without explicit reshaping."""
|
|
203
|
-
# Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
|
|
204
|
-
logits = torch.matmul(input_chunk, weight.t())
|
|
205
|
-
if bias is not None:
|
|
206
|
-
logits = logits + bias # Broadcasts bias to [B, T, V]
|
|
207
|
-
|
|
208
|
-
# Compute log probabilities using softmax over the last dimension
|
|
209
|
-
log_probs = F.log_softmax(logits.float(), dim=-1)
|
|
210
|
-
|
|
211
|
-
# Monitoring: compute mean of logits
|
|
212
|
-
batch_size, seq_len, _ = input_chunk.shape
|
|
213
|
-
logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
|
|
214
|
-
return log_probs, logits, logits_mean
|
|
215
|
-
|
|
216
|
-
@staticmethod
|
|
217
|
-
def backward(ctx, grad_output, *grad_metrics):
|
|
218
|
-
"""Backward pass for RLHF loss."""
|
|
219
|
-
grad_input, grad_weight, grad_bias = ctx.saved_tensors
|
|
220
|
-
if grad_output != 1.0:
|
|
221
|
-
grad_input = grad_input * grad_output
|
|
222
|
-
grad_weight = grad_weight * grad_output
|
|
223
|
-
if grad_bias is not None:
|
|
224
|
-
grad_bias = grad_bias * grad_output
|
|
225
|
-
|
|
226
|
-
return (
|
|
227
|
-
grad_input,
|
|
228
|
-
grad_weight,
|
|
229
|
-
None, # grad_attention_mask
|
|
230
|
-
None, # grad_rewards
|
|
231
|
-
grad_bias,
|
|
232
|
-
None, # grad_num_generations
|
|
233
|
-
None, # grad_beta
|
|
234
|
-
None, # grad_compiled
|
|
235
|
-
None, # grad_use_ref_model
|
|
236
|
-
None, # grad_ref_input
|
|
237
|
-
None, # grad_ref_weight
|
|
238
|
-
None, # grad_ref_bias
|
|
239
|
-
None, # grad_chunk_size
|
|
240
|
-
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|