liger-kernel 0.5.3__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
- liger_kernel/chunked_loss/grpo_loss.py +160 -0
- liger_kernel/chunked_loss/kto_loss.py +9 -9
- liger_kernel/ops/cross_entropy.py +3 -3
- liger_kernel/ops/fused_linear_cross_entropy.py +3 -3
- liger_kernel/ops/fused_linear_jsd.py +3 -3
- liger_kernel/ops/jsd.py +3 -3
- liger_kernel/ops/layer_norm.py +20 -7
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +1 -2
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/cross_entropy.py +3 -3
- liger_kernel/transformers/functional.py +17 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +3 -3
- liger_kernel/transformers/group_norm.py +6 -6
- liger_kernel/transformers/model/olmo2.py +124 -0
- liger_kernel/transformers/monkey_patch.py +171 -27
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/utils.py +49 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +17 -3
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/RECORD +26 -21
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +0 -0
- {liger_kernel-0.5.3.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
|
|
2
2
|
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
|
|
3
|
+
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
|
|
3
4
|
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
|
|
4
5
|
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
|
|
5
6
|
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
|
|
from functools import partial

import torch
import torch.nn.functional as F


class LigerFusedLinearRLHFBase(torch.autograd.Function):
    """Base autograd Function for chunked, fused linear + RLHF-style losses.

    The final projection (hidden states -> vocabulary logits) and the loss are
    computed chunk by chunk along the batch dimension so the full logits tensor
    is never materialized at once. Gradients for each chunk are obtained via
    ``torch.func.grad_and_value`` during the forward pass and merely replayed
    in ``backward``.
    """

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        attention_mask,
        rewards,
        bias=None,
        loss_fn=None,
        num_generations=4,
        beta=0.1,
        compiled=True,
        use_ref_model=False,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        """Chunked forward pass for RLHF loss computation.

        Args:
            _input: Hidden states, shape ``[B, T, H]``.
            weight: Projection weight, shape ``[V, H]``.
            attention_mask: Mask over tokens, shape ``[B, T]``.
            rewards: Per-sequence rewards, shape ``[B]``.
            bias: Optional projection bias, shape ``[V]``.
            loss_fn: Callable computing ``(loss, metrics)`` from log-probs.
            num_generations: Generations per prompt; one chunk spans one group.
            beta: KL-penalty weight forwarded to ``loss_fn``.
            compiled: Wrap the per-chunk accumulation in ``torch.compile``.
            use_ref_model: Whether reference-model log-probs are computed.
            ref_input / ref_weight / ref_bias: Reference model tensors.

        Returns:
            Tuple of ``(loss, metrics)`` where ``metrics`` is a tuple of
            aggregated scalar or concatenated per-chunk tensors.
        """
        # Stash non-tensor state for backward.
        ctx.beta = beta
        ctx.rewards = rewards

        # Accumulators filled in as each chunk is processed.
        loss_acc = torch.zeros((), device=_input.device)
        grad_weight = torch.zeros_like(weight)  # [V, H]
        grad_inputs = []
        grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
        aggregated_metrics = []

        # Bind the static configuration once; chunk tensors vary per call.
        compute_loss = partial(
            LigerFusedLinearRLHFBase._compute_chunk_loss,
            beta=beta,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            rlhf_loss_fn=loss_fn,
        )

        def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
            """Fused forward and backward for a chunk."""
            # argnums selects which positional args receive gradients:
            # 0 = input_chunk, 1 = weight, 5 = bias (when present).
            if bias is not None:
                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
                    input_chunk,  # arg 0
                    weight,  # arg 1
                    attention_mask_chunk,  # arg 2
                    rewards_chunk,  # arg 3
                    ref_input_chunk,  # arg 4
                    bias,  # arg 5
                )
            else:
                return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
                    input_chunk,  # arg 0
                    weight,  # arg 1
                    attention_mask_chunk,  # arg 2
                    rewards_chunk,  # arg 3
                    ref_input_chunk,  # arg 4
                )

        def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
            # Run fused fwd/bwd for this chunk, unpacking grads per bias case.
            if bias is not None:
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
                    input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
                )
                grad_bias.add_(chunk_grad_bias)
            else:
                (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
                    input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
                )

            # Fold this chunk's contributions into the running accumulators.
            grad_weight.add_(chunk_grad_weight)
            grad_inputs.append(chunk_grad_input)
            loss_acc.add_(chunk_loss)

            # Lazily set up per-metric storage on the first chunk: scalar
            # metrics get a running sum, tensor metrics get a list to concat.
            if len(aggregated_metrics) == 0:
                for metric in chunk_metrics:
                    if metric.ndim == 0:
                        aggregated_metrics.append(torch.zeros((), device=metric.device))
                    else:
                        aggregated_metrics.append([])

            for i, metric in enumerate(chunk_metrics):
                if metric.ndim == 0:
                    aggregated_metrics[i].add_(metric)
                else:
                    aggregated_metrics[i].append(metric)

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        # One chunk per generation group (at least one chunk overall).
        chunks = max(1, _input.shape[0] // num_generations)
        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
        _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
        _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
        _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks

        for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
            _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
        ):
            # Sequence length may vary between chunks; mark it dynamic so
            # torch.compile does not re-specialize per chunk.
            torch._dynamo.mark_dynamic(input_chunk, 1)
            torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
            if ref_input_chunk is not None:
                torch._dynamo.mark_dynamic(ref_input_chunk, 1)

            accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)

        # Each chunk contributed a mean loss, so average over chunks.
        loss_acc = loss_acc / chunks

        # Per-chunk input grads concatenate back to the full batch.
        grad_input = torch.cat(grad_inputs, dim=0)

        # Gradients were fully computed here; backward only replays them.
        ctx.save_for_backward(grad_input, grad_weight, grad_bias)

        # Finalize metrics: average scalars, concatenate tensor metrics.
        final_metrics = []
        for metric in aggregated_metrics:
            if isinstance(metric, list):
                final_metrics.append(torch.cat(metric, dim=0))
            else:
                final_metrics.append(metric / chunks)

        return loss_acc, tuple(final_metrics)

    @staticmethod
    def _compute_chunk_loss(
        input_chunk,
        weight,
        attention_mask_chunk,
        rewards_chunk,
        ref_input_chunk=None,
        bias=None,
        beta=0.1,
        use_ref_model=False,
        ref_weight=None,
        ref_bias=None,
        rlhf_loss_fn=None,
    ):
        """Compute loss for a single chunk.

        Returns ``(loss, aux_metrics)`` where the first aux metric is the mean
        of the policy logits, followed by whatever ``rlhf_loss_fn`` reports.
        """
        # Policy log-probabilities for this chunk.
        log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)

        # Reference log-probabilities, if a reference model is in use.
        ref_log_probs = None
        if use_ref_model and ref_input_chunk is not None:
            with torch.no_grad():
                ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)

        # Delegate the actual RLHF loss to the provided loss function.
        chunk_loss, chunk_metrics = rlhf_loss_fn(
            log_probs=log_probs,
            attention_mask=attention_mask_chunk,
            rewards=rewards_chunk,
            ref_log_probs=ref_log_probs,
            beta=beta,
        )

        return chunk_loss, (logits_mean, *chunk_metrics)

    @staticmethod
    def chunk_forward(input_chunk, weight, bias=None):
        """Forward pass computation for a single chunk without explicit reshaping.

        Args:
            input_chunk: Hidden states ``[B, T, H]``.
            weight: Projection weight ``[V, H]``.
            bias: Optional bias ``[V]``.

        Returns:
            ``(log_probs, logits, logits_mean)`` with log-probs in float32.
        """
        # Batched matmul: [B, T, H] @ [H, V] -> [B, T, V].
        logits = torch.matmul(input_chunk, weight.t())
        if bias is not None:
            logits = logits + bias  # broadcasts bias over [B, T, V]

        # Softmax in float32 for numerical stability.
        log_probs = F.log_softmax(logits.float(), dim=-1)

        # Monitoring statistic: mean of all logits in the chunk.
        batch_size, seq_len, _ = input_chunk.shape
        logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
        return log_probs, logits, logits_mean

    @staticmethod
    def backward(ctx, grad_output, *grad_metrics):
        """Backward pass for RLHF loss.

        Gradients were precomputed during forward; here they are only scaled
        by the incoming ``grad_output`` when it is not 1.
        """
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if grad_output != 1.0:
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            if grad_bias is not None:
                grad_bias = grad_bias * grad_output

        # One entry per forward() parameter (after ctx), in order.
        return (
            grad_input,
            grad_weight,
            None,  # grad_attention_mask
            None,  # grad_rewards
            grad_bias,
            None,  # grad_loss_fn
            None,  # grad_num_generations
            None,  # grad_beta
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_ref_input
            None,  # grad_ref_weight
            None,  # grad_ref_bias
        )
|
import torch

from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase


class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
    """GRPO (Group Relative Policy Optimization) loss on top of the chunked RLHF base."""

    @staticmethod
    def rlhf_loss_fn(
        log_probs,
        attention_mask,
        rewards,
        ref_log_probs=None,
        beta=0.1,
        **kwargs,
    ):
        """GRPO Loss Function matching GRPOTrainer implementation."""
        # Log-probabilities of the greedily chosen tokens.
        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
            -1
        )  # (batch_size, seq_len)

        # Reference-model log-probs for the same tokens; without a reference
        # model fall back to the detached policy log-probs (zero KL).
        if ref_log_probs is not None:
            with torch.no_grad():
                ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
        else:
            ref_token_logprobs = chosen_token_logprobs.detach()

        # Group-level reward statistics (scalars over this chunk's group).
        mean_grouped_rewards = rewards.mean()
        std_grouped_rewards = rewards.std()

        # Standardized advantages; same epsilon as GRPOTrainer.
        eps = 1e-4
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)

        # Importance-sampling ratio: numerically 1, but carries the policy
        # gradient through the non-detached term.
        ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
        policy_loss = -ratio * advantages.unsqueeze(1)

        # k3 KL estimator: exp(r) - r - 1 with r = ref - policy log-probs.
        kl_div = (
            torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
        )

        # Policy-gradient term plus weighted KL penalty.
        per_token_loss = policy_loss + beta * kl_div

        # Mask out padding and normalize by the total number of valid tokens.
        masked_loss = per_token_loss * attention_mask
        seq_lengths = attention_mask.sum()
        seq_lengths = torch.clamp(seq_lengths, min=1.0)
        loss = masked_loss.sum() / seq_lengths

        # Monitoring metrics reported alongside the loss.
        metrics = (
            chosen_token_logprobs.mean(),  # mean log prob
            chosen_token_logprobs.std(),  # std log prob
            log_probs.mean(),  # mean all log probs
            ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(),  # mean KL div
        )

        return loss, metrics

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        attention_mask,
        rewards,
        bias=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        beta=0.1,
        compiled=True,
        use_ref_model=True,
        num_generations=1,
    ):
        # Delegate to the chunked RLHF base with the GRPO loss plugged in.
        return LigerFusedLinearRLHFBase.forward(
            ctx=ctx,
            _input=_input,
            weight=weight,
            attention_mask=attention_mask,
            loss_fn=LigerFusedLinearGRPOFunction.rlhf_loss_fn,
            rewards=rewards,
            bias=bias,
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            beta=beta,
            compiled=compiled,
            use_ref_model=use_ref_model,
            num_generations=num_generations,
        )

    @staticmethod
    def backward(ctx, grad_output, *grad_metrics):
        """Backward pass for GRPO loss.

        Args:
            grad_output: Gradient of the loss (scalar)
            grad_metrics: Gradients of the metrics (not used in backward computation)
        """
        grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
        # Reuse the base gradients and pad Nones to match this forward's
        # parameter list (after ctx): _input, weight, attention_mask, rewards,
        # bias, ref_input, ref_weight, ref_bias, beta, compiled,
        # use_ref_model, num_generations.
        return (
            *grads[:5],  # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias
            None,  # grad_ref_input
            None,  # grad_ref_weight
            None,  # grad_ref_bias
            None,  # grad_beta
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_num_generations
        )
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class LigerFusedLinearGRPOLoss(torch.nn.Module):
|
|
121
|
+
"""Fused linear layer with GRPO loss."""
|
|
122
|
+
|
|
123
|
+
def __init__(
|
|
124
|
+
self,
|
|
125
|
+
beta: float = 0.1,
|
|
126
|
+
compiled: bool = True,
|
|
127
|
+
use_ref_model: bool = True,
|
|
128
|
+
num_generations: int = 1,
|
|
129
|
+
):
|
|
130
|
+
super().__init__()
|
|
131
|
+
self.beta = beta
|
|
132
|
+
self.compiled = compiled
|
|
133
|
+
self.use_ref_model = use_ref_model
|
|
134
|
+
self.num_generations = num_generations
|
|
135
|
+
|
|
136
|
+
def forward(
|
|
137
|
+
self,
|
|
138
|
+
_input,
|
|
139
|
+
lin_weight,
|
|
140
|
+
attention_mask,
|
|
141
|
+
rewards,
|
|
142
|
+
bias=None,
|
|
143
|
+
ref_input=None,
|
|
144
|
+
ref_weight=None,
|
|
145
|
+
ref_bias=None,
|
|
146
|
+
):
|
|
147
|
+
return LigerFusedLinearGRPOFunction.apply(
|
|
148
|
+
_input,
|
|
149
|
+
lin_weight,
|
|
150
|
+
attention_mask,
|
|
151
|
+
rewards,
|
|
152
|
+
bias,
|
|
153
|
+
ref_input,
|
|
154
|
+
ref_weight,
|
|
155
|
+
ref_bias,
|
|
156
|
+
self.beta,
|
|
157
|
+
self.compiled,
|
|
158
|
+
self.use_ref_model,
|
|
159
|
+
self.num_generations,
|
|
160
|
+
)
|
|
@@ -43,20 +43,20 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
|
|
|
43
43
|
3. Maintain reasonable distance from the reference model
|
|
44
44
|
|
|
45
45
|
Args:
|
|
46
|
-
|
|
47
|
-
|
|
46
|
+
average_log_prob_chunk: Log probabilities for the chunk (batch_size,)
|
|
47
|
+
preference_labels_chunk: Preference labels for the chunk (batch_size,)
|
|
48
48
|
full_target: Non chunked full target tensor
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
beta: Weight for the direct preference loss
|
|
49
|
+
ref_average_log_prob_chunk: Reference log probs for the chunk (batch_size,)
|
|
50
|
+
beta: Weight for the KTO loss
|
|
52
51
|
kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
|
|
53
52
|
Returns:
|
|
54
|
-
Tuple of (loss, chosen_rewards, rejected_rewards):
|
|
55
53
|
- loss: The KTO loss value
|
|
56
|
-
- chosen_rewards: Reward signals for chosen responses (detached)
|
|
57
|
-
- rejected_rewards: Reward signals for rejected responses (detached)
|
|
58
54
|
"""
|
|
59
|
-
|
|
55
|
+
if ref_average_log_prob_chunk is not None:
|
|
56
|
+
logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
|
|
57
|
+
else:
|
|
58
|
+
logratios_chunk = average_log_prob_chunk
|
|
59
|
+
|
|
60
60
|
multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
|
|
61
61
|
if kl is not None:
|
|
62
62
|
losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
|
|
@@ -289,9 +289,9 @@ def cross_entropy_forward(
|
|
|
289
289
|
weight_sum = 0.0
|
|
290
290
|
if weight is not None:
|
|
291
291
|
assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
|
|
292
|
-
assert torch.is_floating_point(
|
|
293
|
-
weight
|
|
294
|
-
)
|
|
292
|
+
assert torch.is_floating_point(weight), (
|
|
293
|
+
f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
|
|
294
|
+
)
|
|
295
295
|
sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
|
|
296
296
|
weight_sum = weight.sum().item()
|
|
297
297
|
# ensure weight is contiguous
|
|
@@ -58,9 +58,9 @@ def fused_linear_cross_entropy_forward(
|
|
|
58
58
|
ce_weight_sum = 0.0
|
|
59
59
|
if ce_weight is not None:
|
|
60
60
|
assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
|
|
61
|
-
assert torch.is_floating_point(
|
|
62
|
-
ce_weight
|
|
63
|
-
)
|
|
61
|
+
assert torch.is_floating_point(ce_weight), (
|
|
62
|
+
f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
|
|
63
|
+
)
|
|
64
64
|
total_sum_non_ignore_ce_weight = (
|
|
65
65
|
torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
|
|
66
66
|
)
|
|
@@ -195,9 +195,9 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
|
|
|
195
195
|
"""
|
|
196
196
|
has_label = False
|
|
197
197
|
if shift_labels is not None:
|
|
198
|
-
assert shift_labels.shape == (
|
|
199
|
-
|
|
200
|
-
)
|
|
198
|
+
assert shift_labels.shape == (teacher_input.shape[0],), (
|
|
199
|
+
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
|
|
200
|
+
)
|
|
201
201
|
shift_labels = shift_labels.contiguous()
|
|
202
202
|
has_label = True
|
|
203
203
|
|
liger_kernel/ops/jsd.py
CHANGED
|
@@ -157,9 +157,9 @@ class LigerJSDFunction(torch.autograd.Function):
|
|
|
157
157
|
"""
|
|
158
158
|
has_label = False
|
|
159
159
|
if shift_labels is not None:
|
|
160
|
-
assert shift_labels.shape == (
|
|
161
|
-
|
|
162
|
-
)
|
|
160
|
+
assert shift_labels.shape == (_input.shape[0],), (
|
|
161
|
+
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
|
|
162
|
+
)
|
|
163
163
|
shift_labels = shift_labels.contiguous()
|
|
164
164
|
has_label = True
|
|
165
165
|
|
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -57,13 +57,14 @@ def _layer_norm_forward_kernel(
|
|
|
57
57
|
B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
|
|
58
58
|
|
|
59
59
|
mean = tl.sum(X_row, axis=0) / n_cols
|
|
60
|
-
|
|
60
|
+
Xmm = tl.where(mask, X_row - mean, 0)
|
|
61
|
+
var = tl.sum(Xmm * Xmm, axis=0) / n_cols
|
|
61
62
|
rstd = rsqrt(var + eps)
|
|
62
63
|
|
|
63
64
|
tl.store(Mean_ptr, mean)
|
|
64
65
|
tl.store(RSTD_ptr, rstd)
|
|
65
66
|
|
|
66
|
-
Y_row =
|
|
67
|
+
Y_row = Xmm * rstd * W_row + B_row
|
|
67
68
|
|
|
68
69
|
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
|
69
70
|
|
|
@@ -147,9 +148,11 @@ def layer_norm_forward(X, W, B, eps):
|
|
|
147
148
|
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
148
149
|
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
149
150
|
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
151
|
+
if X.shape[1] != W.shape[0]:
|
|
152
|
+
raise ValueError(
|
|
153
|
+
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
|
|
154
|
+
f"must match weight size (W.shape[0]={W.shape[0]})"
|
|
155
|
+
)
|
|
153
156
|
|
|
154
157
|
_layer_norm_forward_kernel[(n_rows,)](
|
|
155
158
|
Y,
|
|
@@ -190,11 +193,21 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
190
193
|
|
|
191
194
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
192
195
|
if n_cols > BLOCK_SIZE:
|
|
193
|
-
raise RuntimeError(
|
|
196
|
+
raise RuntimeError(
|
|
197
|
+
f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
|
|
198
|
+
)
|
|
194
199
|
|
|
195
200
|
rows_per_program = math.ceil(n_rows / sm_count)
|
|
196
201
|
grid = (sm_count,)
|
|
197
|
-
triton_dtype =
|
|
202
|
+
triton_dtype = (
|
|
203
|
+
tl.float32
|
|
204
|
+
if X.dtype == torch.float32
|
|
205
|
+
else tl.bfloat16
|
|
206
|
+
if X.dtype == torch.bfloat16
|
|
207
|
+
else tl.float16
|
|
208
|
+
if X.dtype == torch.float16
|
|
209
|
+
else tl.float32 # fallback to float32 for other types
|
|
210
|
+
)
|
|
198
211
|
_layer_norm_backward_kernel[grid](
|
|
199
212
|
X,
|
|
200
213
|
W,
|