liger-kernel-nightly 0.4.2.dev20241209195823__py3-none-any.whl → 0.4.2.dev20241209234352__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cpo_loss.py +16 -10
- liger_kernel/chunked_loss/dpo_loss.py +20 -12
- liger_kernel/chunked_loss/fused_linear_preference.py +181 -164
- liger_kernel/chunked_loss/orpo_loss.py +15 -9
- liger_kernel/chunked_loss/simpo_loss.py +17 -11
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209234352.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209234352.dist-info}/RECORD +11 -11
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209234352.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209234352.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209234352.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209234352.dist-info}/top_level.txt +0 -0
@@ -11,11 +11,25 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
11
11
|
@staticmethod
|
12
12
|
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
|
13
13
|
"""
|
14
|
-
|
14
|
+
Paper: https://arxiv.org/pdf/2401.08417
|
15
|
+
|
16
|
+
Formula:
|
17
|
+
L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
|
18
|
+
|
19
|
+
Where:
|
20
|
+
- π_θ(y|x): Policy (model) probability
|
21
|
+
- y_w: Chosen sequence
|
22
|
+
- y_l: Rejected sequence
|
23
|
+
- σ: Sigmoid function
|
24
|
+
- β: Temperature parameter
|
25
|
+
- E: Expected value over the dataset D
|
26
|
+
- D: Dataset of preferences
|
27
|
+
|
15
28
|
Args:
|
16
29
|
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
|
17
30
|
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
|
18
|
-
|
31
|
+
full_target (torch.Tensor): Non chunked full target tensor
|
32
|
+
beta (float): Weight for the CPO loss
|
19
33
|
"""
|
20
34
|
logits = beta * (chosen_logps - rejected_logps)
|
21
35
|
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
|
@@ -34,12 +48,6 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
34
48
|
compute_nll_loss=True,
|
35
49
|
compiled=True,
|
36
50
|
):
|
37
|
-
"""
|
38
|
-
Fused linear layer with CPO (Contrastive Preference Optimization) loss.
|
39
|
-
Handles both the forward and backward pass of the final linear layer with CPO loss.
|
40
|
-
Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
|
41
|
-
"""
|
42
|
-
|
43
51
|
return LigerFusedLinearPreferenceBase.forward(
|
44
52
|
ctx,
|
45
53
|
_input,
|
@@ -56,9 +64,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
|
|
56
64
|
|
57
65
|
@staticmethod
|
58
66
|
def backward(ctx, *grad_output):
|
59
|
-
# Get gradients for _input, weight, bias, and target from the base class
|
60
67
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
61
|
-
# Return these gradients, followed by None for the remaining inputs
|
62
68
|
return *grads, None, None, None, None, None
|
63
69
|
|
64
70
|
|
@@ -18,14 +18,28 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
18
18
|
beta=0.1,
|
19
19
|
):
|
20
20
|
"""
|
21
|
-
|
21
|
+
Paper: https://arxiv.org/pdf/2305.18290
|
22
|
+
|
23
|
+
Formula:
|
24
|
+
L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]
|
25
|
+
|
26
|
+
Where:
|
27
|
+
- π(y|x): Policy (model) probability
|
28
|
+
- π_ref(y|x): Reference model probability
|
29
|
+
- y_w: Chosen sequence
|
30
|
+
- y_l: Rejected sequence
|
31
|
+
- β: Weight for the direct preference loss
|
32
|
+
- E: Expected value over the dataset
|
33
|
+
|
22
34
|
Args:
|
23
|
-
chosen_logps
|
24
|
-
rejected_logps
|
25
|
-
|
26
|
-
|
27
|
-
|
35
|
+
chosen_logps: Log probabilities of chosen tokens (batch_size,)
|
36
|
+
rejected_logps: Log probabilities of rejected tokens (batch_size,)
|
37
|
+
full_target: Non chunked full target tensor
|
38
|
+
ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
|
39
|
+
ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
|
40
|
+
beta: Weight for the direct preference loss
|
28
41
|
"""
|
42
|
+
|
29
43
|
if ref_chosen_logps is None:
|
30
44
|
ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
|
31
45
|
if ref_rejected_logps is None:
|
@@ -53,10 +67,6 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
53
67
|
compiled=True,
|
54
68
|
use_ref_model=True,
|
55
69
|
):
|
56
|
-
"""
|
57
|
-
Fused linear layer with DPO (Direct Preference Optimization) loss.
|
58
|
-
Handles both the forward and backward pass of the final linear layer with DPO loss.
|
59
|
-
"""
|
60
70
|
return LigerFusedLinearPreferenceBase.forward(
|
61
71
|
ctx=ctx,
|
62
72
|
_input=_input,
|
@@ -75,9 +85,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
|
|
75
85
|
|
76
86
|
@staticmethod
|
77
87
|
def backward(ctx, *grad_output):
|
78
|
-
# Get gradients for _input, weight, bias, and target from the base class
|
79
88
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
80
|
-
# Return these gradients, followed by None for the remaining inputs
|
81
89
|
return *grads, None, None, None, None, None, None, None
|
82
90
|
|
83
91
|
|
@@ -8,159 +8,12 @@ from torch.nn import functional as F
|
|
8
8
|
class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
9
9
|
|
10
10
|
@abstractmethod
|
11
|
-
def preference_loss_fn(
|
11
|
+
def preference_loss_fn(*args, **kwargs):
|
12
12
|
"""
|
13
|
-
|
14
|
-
Args:
|
15
|
-
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
|
16
|
-
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
|
17
|
-
beta (float): Weight for the odds ratio loss.
|
13
|
+
To be extended by subclasses.
|
18
14
|
"""
|
19
15
|
raise NotImplementedError("Preference loss function must be implemented.")
|
20
16
|
|
21
|
-
@staticmethod
|
22
|
-
def chunk_forward(
|
23
|
-
input_chunk,
|
24
|
-
weight,
|
25
|
-
target_chunk,
|
26
|
-
bias=None,
|
27
|
-
ignore_index=-100,
|
28
|
-
compute_nll_loss=True,
|
29
|
-
):
|
30
|
-
len_chosen_chunk = target_chunk.shape[0] // 2
|
31
|
-
logits_chunk = input_chunk @ weight.t()
|
32
|
-
if bias is not None:
|
33
|
-
logits_chunk = logits_chunk + bias
|
34
|
-
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
|
35
|
-
|
36
|
-
chosen_nll_loss = 0.0
|
37
|
-
if compute_nll_loss:
|
38
|
-
chosen_nll_loss = F.nll_loss(
|
39
|
-
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
|
40
|
-
target_chunk[:len_chosen_chunk].view(-1),
|
41
|
-
reduction="sum",
|
42
|
-
ignore_index=ignore_index,
|
43
|
-
)
|
44
|
-
|
45
|
-
loss_mask = target_chunk != ignore_index
|
46
|
-
label_chunk = torch.where(loss_mask, target_chunk, 0)
|
47
|
-
|
48
|
-
per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
|
49
|
-
-1
|
50
|
-
)
|
51
|
-
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
52
|
-
|
53
|
-
chosen_logps = average_log_prob[:len_chosen_chunk]
|
54
|
-
rejected_logps = average_log_prob[len_chosen_chunk:]
|
55
|
-
|
56
|
-
chosen_logits = logits_chunk[:len_chosen_chunk]
|
57
|
-
rejected_logits = logits_chunk[len_chosen_chunk:]
|
58
|
-
|
59
|
-
return (
|
60
|
-
chosen_logps,
|
61
|
-
rejected_logps,
|
62
|
-
chosen_logits,
|
63
|
-
rejected_logits,
|
64
|
-
chosen_nll_loss,
|
65
|
-
)
|
66
|
-
|
67
|
-
@staticmethod
|
68
|
-
def _compute_loss(
|
69
|
-
input_chunk,
|
70
|
-
weight,
|
71
|
-
target_chunk,
|
72
|
-
bias=None,
|
73
|
-
preference_loss_fn=None,
|
74
|
-
full_target=None,
|
75
|
-
ignore_index=-100,
|
76
|
-
alpha=1.0,
|
77
|
-
beta=0.1,
|
78
|
-
compute_nll_loss=True,
|
79
|
-
use_ref_model=False,
|
80
|
-
ref_weight=None,
|
81
|
-
ref_bias=None,
|
82
|
-
**loss_kwargs,
|
83
|
-
):
|
84
|
-
"""
|
85
|
-
Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
|
86
|
-
Args:
|
87
|
-
preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
|
88
|
-
input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
|
89
|
-
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
|
90
|
-
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
|
91
|
-
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
|
92
|
-
full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
|
93
|
-
ignore_index (int): Index to ignore for loss computation.
|
94
|
-
alpha (float): Weight for the NLL loss.
|
95
|
-
beta (float): Weight for the odds ratio loss.
|
96
|
-
compute_nll_loss (bool): Whether to compute NLL loss.
|
97
|
-
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
98
|
-
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
99
|
-
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
100
|
-
loss_kwargs (dict): Additional arguments for the loss function.
|
101
|
-
"""
|
102
|
-
(
|
103
|
-
chosen_logps,
|
104
|
-
rejected_logps,
|
105
|
-
chosen_logits,
|
106
|
-
rejected_logits,
|
107
|
-
chosen_nll_loss,
|
108
|
-
) = LigerFusedLinearPreferenceBase.chunk_forward(
|
109
|
-
input_chunk,
|
110
|
-
weight,
|
111
|
-
target_chunk,
|
112
|
-
bias=bias,
|
113
|
-
ignore_index=ignore_index,
|
114
|
-
compute_nll_loss=compute_nll_loss,
|
115
|
-
)
|
116
|
-
chosen_nll_loss = (
|
117
|
-
chosen_nll_loss
|
118
|
-
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
|
119
|
-
)
|
120
|
-
chosen_logits_mean = chosen_logits.sum() / (
|
121
|
-
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
|
122
|
-
)
|
123
|
-
rejected_logits_mean = rejected_logits.sum() / (
|
124
|
-
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
|
125
|
-
)
|
126
|
-
|
127
|
-
if use_ref_model:
|
128
|
-
with torch.no_grad():
|
129
|
-
(
|
130
|
-
ref_chosen_logps,
|
131
|
-
ref_rejected_logps,
|
132
|
-
ref_chosen_logits,
|
133
|
-
ref_rejected_logits,
|
134
|
-
ref_chosen_nll_loss,
|
135
|
-
) = LigerFusedLinearPreferenceBase.chunk_forward(
|
136
|
-
input_chunk,
|
137
|
-
ref_weight,
|
138
|
-
target_chunk,
|
139
|
-
ref_bias,
|
140
|
-
ignore_index=ignore_index,
|
141
|
-
compute_nll_loss=False, # We don't need NLL loss for the reference model
|
142
|
-
)
|
143
|
-
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
|
144
|
-
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
|
145
|
-
|
146
|
-
preference_loss_outputs = preference_loss_fn(
|
147
|
-
chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
|
148
|
-
)
|
149
|
-
if isinstance(preference_loss_outputs, tuple):
|
150
|
-
preference_loss, *aux_outputs = preference_loss_outputs
|
151
|
-
else:
|
152
|
-
preference_loss, aux_outputs = preference_loss_outputs, []
|
153
|
-
|
154
|
-
loss = alpha * chosen_nll_loss - preference_loss
|
155
|
-
return_vars = (
|
156
|
-
chosen_logps,
|
157
|
-
rejected_logps,
|
158
|
-
chosen_logits_mean,
|
159
|
-
rejected_logits_mean,
|
160
|
-
chosen_nll_loss,
|
161
|
-
)
|
162
|
-
return loss, (*return_vars, *aux_outputs)
|
163
|
-
|
164
17
|
@staticmethod
|
165
18
|
def forward(
|
166
19
|
ctx,
|
@@ -176,6 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
176
29
|
compute_nll_loss=True,
|
177
30
|
compiled=True,
|
178
31
|
use_ref_model=False,
|
32
|
+
# TODO: ref input
|
179
33
|
ref_weight=None,
|
180
34
|
ref_bias=None,
|
181
35
|
**loss_kwargs,
|
@@ -184,6 +38,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
184
38
|
Base class for fused linear layer with preference loss.
|
185
39
|
Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
|
186
40
|
|
41
|
+
The mental model is:
|
42
|
+
|
43
|
+
forward()
|
44
|
+
├── Loop over chunks
|
45
|
+
└── compute_loss()
|
46
|
+
├── chunk_forward() # Compute logits and log probs
|
47
|
+
└── preference_loss_fn() # Calculate preference loss
|
48
|
+
|
187
49
|
Args:
|
188
50
|
_input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
|
189
51
|
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
|
@@ -191,10 +53,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
191
53
|
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
|
192
54
|
loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
|
193
55
|
chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
|
194
|
-
compute_nll_loss (bool): Whether to compute NLL loss.
|
195
56
|
ignore_index (int): Index to ignore for loss computation.
|
196
57
|
alpha (float): Weight for the NLL loss.
|
197
|
-
beta (float): Weight for the
|
58
|
+
beta (float): Weight for the preference loss.
|
198
59
|
compute_nll_loss (bool): Whether to compute NLL loss.
|
199
60
|
compiled (bool): Whether to use torch compile for chunk accumulation.
|
200
61
|
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
@@ -205,11 +66,16 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
205
66
|
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
|
206
67
|
CHUNK_SIZE = chunk_size
|
207
68
|
|
69
|
+
# Gradients to be accumulated
|
208
70
|
grad_weight = torch.zeros_like(weight)
|
209
71
|
grad_chosen_inputs = []
|
210
72
|
grad_rejected_inputs = []
|
211
73
|
grad_bias = torch.zeros_like(bias) if bias is not None else None
|
74
|
+
|
75
|
+
# Loss to be accumulated
|
212
76
|
loss_acc = torch.zeros((), device=_input.device)
|
77
|
+
|
78
|
+
# Metrics to be recorded
|
213
79
|
policy_chosen_logps = []
|
214
80
|
policy_rejected_logps = []
|
215
81
|
policy_chosen_logits_mean = torch.zeros((), device=_input.device)
|
@@ -217,7 +83,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
217
83
|
policy_nll_loss = torch.zeros((), device=_input.device)
|
218
84
|
aggregated_aux_outputs = [] # aggregated aux outputs from all chunks
|
219
85
|
|
220
|
-
|
86
|
+
compute_loss = partial(
|
221
87
|
LigerFusedLinearPreferenceBase._compute_loss,
|
222
88
|
preference_loss_fn=loss_fn,
|
223
89
|
ignore_index=ignore_index,
|
@@ -231,14 +97,17 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
231
97
|
**loss_kwargs,
|
232
98
|
)
|
233
99
|
|
234
|
-
def
|
100
|
+
def fused_fwd_bwd(input_chunk, target_chunk):
|
101
|
+
"""
|
102
|
+
Fused forward and backward pass for a chunk of input and target.
|
103
|
+
"""
|
235
104
|
if bias is not None:
|
236
105
|
return torch.func.grad_and_value(
|
237
|
-
|
106
|
+
compute_loss, argnums=(0, 1, 3), has_aux=True
|
238
107
|
)(input_chunk, weight, target_chunk, bias)
|
239
108
|
else:
|
240
109
|
return torch.func.grad_and_value(
|
241
|
-
|
110
|
+
compute_loss, argnums=(0, 1), has_aux=True
|
242
111
|
)(input_chunk, weight, target_chunk)
|
243
112
|
|
244
113
|
def accumulate_chunk(input_chunk, target_chunk):
|
@@ -253,7 +122,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
253
122
|
chunk_nll_loss,
|
254
123
|
*aux_outputs,
|
255
124
|
),
|
256
|
-
) =
|
125
|
+
) = fused_fwd_bwd(input_chunk, target_chunk)
|
257
126
|
grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
|
258
127
|
else:
|
259
128
|
(chunk_grad_input, chunk_grad_weight), (
|
@@ -266,16 +135,26 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
266
135
|
chunk_nll_loss,
|
267
136
|
*aux_outputs,
|
268
137
|
),
|
269
|
-
) =
|
138
|
+
) = fused_fwd_bwd(input_chunk, target_chunk)
|
270
139
|
|
140
|
+
# Accumulate gradients
|
271
141
|
grad_weight.add_(chunk_grad_weight)
|
142
|
+
grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
|
143
|
+
grad_rejected_inputs.append(
|
144
|
+
chunk_grad_input[chosen_target_chunk.shape[0] :]
|
145
|
+
)
|
146
|
+
|
147
|
+
# Accumulate loss
|
272
148
|
loss_acc.add_(chunk_loss)
|
149
|
+
|
150
|
+
# Accumulate metrics
|
273
151
|
policy_chosen_logps.append(chunk_chosen_logps)
|
274
152
|
policy_rejected_logps.append(chunk_rejected_logps)
|
275
153
|
policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
|
276
154
|
policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
|
277
155
|
policy_nll_loss.add_(chunk_nll_loss)
|
278
156
|
|
157
|
+
# aux_outputs
|
279
158
|
# Initialize storage for aux_outputs
|
280
159
|
if len(aggregated_aux_outputs) == 0:
|
281
160
|
for aux in aux_outputs:
|
@@ -293,10 +172,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
293
172
|
else:
|
294
173
|
aggregated_aux_outputs[i].append(aux)
|
295
174
|
|
296
|
-
return chunk_grad_input
|
297
|
-
|
298
175
|
if compiled:
|
299
|
-
|
176
|
+
fused_fwd_bwd = torch.compile(fused_fwd_bwd)
|
300
177
|
|
301
178
|
len_chosen = target.shape[0] // 2
|
302
179
|
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
|
@@ -327,10 +204,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
327
204
|
torch._dynamo.mark_dynamic(target, 1)
|
328
205
|
|
329
206
|
# accumulate loss, gradients, and metrics
|
330
|
-
|
331
|
-
|
332
|
-
grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
|
333
|
-
grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])
|
207
|
+
accumulate_chunk(input_chunk, target_chunk)
|
334
208
|
|
335
209
|
# combine grad_chosen_inputs and grad_rejected_inputs
|
336
210
|
grad_inputs = grad_chosen_inputs + grad_rejected_inputs
|
@@ -367,3 +241,146 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
367
241
|
grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
|
368
242
|
|
369
243
|
return grad_input, grad_weight, None, grad_bias, None, None, None
|
244
|
+
|
245
|
+
@staticmethod
|
246
|
+
def chunk_forward(
|
247
|
+
input_chunk,
|
248
|
+
weight,
|
249
|
+
target_chunk,
|
250
|
+
bias=None,
|
251
|
+
ignore_index=-100,
|
252
|
+
compute_nll_loss=True,
|
253
|
+
):
|
254
|
+
len_chosen_chunk = target_chunk.shape[0] // 2
|
255
|
+
logits_chunk = input_chunk @ weight.t()
|
256
|
+
if bias is not None:
|
257
|
+
logits_chunk = logits_chunk + bias
|
258
|
+
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
|
259
|
+
|
260
|
+
chosen_nll_loss = 0.0
|
261
|
+
if compute_nll_loss:
|
262
|
+
chosen_nll_loss = F.nll_loss(
|
263
|
+
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
|
264
|
+
target_chunk[:len_chosen_chunk].view(-1),
|
265
|
+
reduction="sum",
|
266
|
+
ignore_index=ignore_index,
|
267
|
+
)
|
268
|
+
|
269
|
+
loss_mask = target_chunk != ignore_index
|
270
|
+
label_chunk = torch.where(loss_mask, target_chunk, 0)
|
271
|
+
|
272
|
+
per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
|
273
|
+
-1
|
274
|
+
)
|
275
|
+
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
276
|
+
|
277
|
+
chosen_logps = average_log_prob[:len_chosen_chunk]
|
278
|
+
rejected_logps = average_log_prob[len_chosen_chunk:]
|
279
|
+
|
280
|
+
chosen_logits = logits_chunk[:len_chosen_chunk]
|
281
|
+
rejected_logits = logits_chunk[len_chosen_chunk:]
|
282
|
+
|
283
|
+
return (
|
284
|
+
chosen_logps,
|
285
|
+
rejected_logps,
|
286
|
+
chosen_logits,
|
287
|
+
rejected_logits,
|
288
|
+
chosen_nll_loss,
|
289
|
+
)
|
290
|
+
|
291
|
+
@staticmethod
|
292
|
+
def _compute_loss(
|
293
|
+
input_chunk,
|
294
|
+
weight,
|
295
|
+
target_chunk,
|
296
|
+
bias=None,
|
297
|
+
preference_loss_fn=None,
|
298
|
+
full_target=None,
|
299
|
+
ignore_index=-100,
|
300
|
+
alpha=1.0,
|
301
|
+
beta=0.1,
|
302
|
+
compute_nll_loss=True,
|
303
|
+
use_ref_model=False,
|
304
|
+
ref_weight=None,
|
305
|
+
ref_bias=None,
|
306
|
+
**loss_kwargs,
|
307
|
+
):
|
308
|
+
"""
|
309
|
+
Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
|
310
|
+
Args:
|
311
|
+
preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
|
312
|
+
input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
|
313
|
+
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
|
314
|
+
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
|
315
|
+
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
|
316
|
+
full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
|
317
|
+
ignore_index (int): Index to ignore for loss computation.
|
318
|
+
alpha (float): Weight for the NLL loss.
|
319
|
+
beta (float): Weight for the preference loss.
|
320
|
+
compute_nll_loss (bool): Whether to compute NLL loss.
|
321
|
+
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
322
|
+
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
323
|
+
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
324
|
+
loss_kwargs (dict): Additional arguments for the loss function.
|
325
|
+
"""
|
326
|
+
(
|
327
|
+
chosen_logps,
|
328
|
+
rejected_logps,
|
329
|
+
chosen_logits,
|
330
|
+
rejected_logits,
|
331
|
+
chosen_nll_loss,
|
332
|
+
) = LigerFusedLinearPreferenceBase.chunk_forward(
|
333
|
+
input_chunk,
|
334
|
+
weight,
|
335
|
+
target_chunk,
|
336
|
+
bias=bias,
|
337
|
+
ignore_index=ignore_index,
|
338
|
+
compute_nll_loss=compute_nll_loss,
|
339
|
+
)
|
340
|
+
chosen_nll_loss = (
|
341
|
+
chosen_nll_loss
|
342
|
+
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
|
343
|
+
)
|
344
|
+
chosen_logits_mean = chosen_logits.sum() / (
|
345
|
+
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
|
346
|
+
)
|
347
|
+
rejected_logits_mean = rejected_logits.sum() / (
|
348
|
+
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
|
349
|
+
)
|
350
|
+
|
351
|
+
if use_ref_model:
|
352
|
+
with torch.no_grad():
|
353
|
+
(
|
354
|
+
ref_chosen_logps,
|
355
|
+
ref_rejected_logps,
|
356
|
+
ref_chosen_logits,
|
357
|
+
ref_rejected_logits,
|
358
|
+
ref_chosen_nll_loss,
|
359
|
+
) = LigerFusedLinearPreferenceBase.chunk_forward(
|
360
|
+
input_chunk,
|
361
|
+
ref_weight,
|
362
|
+
target_chunk,
|
363
|
+
ref_bias,
|
364
|
+
ignore_index=ignore_index,
|
365
|
+
compute_nll_loss=False, # We don't need NLL loss for the reference model
|
366
|
+
)
|
367
|
+
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
|
368
|
+
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
|
369
|
+
|
370
|
+
preference_loss_outputs = preference_loss_fn(
|
371
|
+
chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
|
372
|
+
)
|
373
|
+
if isinstance(preference_loss_outputs, tuple):
|
374
|
+
preference_loss, *aux_outputs = preference_loss_outputs
|
375
|
+
else:
|
376
|
+
preference_loss, aux_outputs = preference_loss_outputs, []
|
377
|
+
|
378
|
+
loss = alpha * chosen_nll_loss - preference_loss
|
379
|
+
return_vars = (
|
380
|
+
chosen_logps,
|
381
|
+
rejected_logps,
|
382
|
+
chosen_logits_mean,
|
383
|
+
rejected_logits_mean,
|
384
|
+
chosen_nll_loss,
|
385
|
+
)
|
386
|
+
return loss, (*return_vars, *aux_outputs)
|
@@ -11,10 +11,24 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
11
11
|
@staticmethod
|
12
12
|
def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
|
13
13
|
"""
|
14
|
-
|
14
|
+
Paper: https://arxiv.org/pdf/2403.07691
|
15
|
+
|
16
|
+
Formula:
|
17
|
+
Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
|
18
|
+
where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
|
19
|
+
|
20
|
+
Where:
|
21
|
+
- P_θ(y|x): Policy (model) probability
|
22
|
+
- y_w: Chosen sequence
|
23
|
+
- y_l: Rejected sequence
|
24
|
+
- σ: Sigmoid function
|
25
|
+
- β: Weight for the odds ratio loss
|
26
|
+
- odds_θ: Odds function for the policy
|
27
|
+
|
15
28
|
Args:
|
16
29
|
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
|
17
30
|
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
|
31
|
+
full_target (torch.Tensor): Non chunked full target tensor
|
18
32
|
beta (float): Weight for the odds ratio loss.
|
19
33
|
"""
|
20
34
|
log_odds = (chosen_logps - rejected_logps) - (
|
@@ -44,12 +58,6 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
44
58
|
compute_nll_loss=True,
|
45
59
|
compiled=True,
|
46
60
|
):
|
47
|
-
"""
|
48
|
-
Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
|
49
|
-
Handles both the forward and backward pass of the final linear layer with ORPO loss.
|
50
|
-
Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
|
51
|
-
"""
|
52
|
-
|
53
61
|
return LigerFusedLinearPreferenceBase.forward(
|
54
62
|
ctx=ctx,
|
55
63
|
_input=_input,
|
@@ -65,9 +73,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
|
|
65
73
|
|
66
74
|
@staticmethod
|
67
75
|
def backward(ctx, *grad_output):
|
68
|
-
# Get gradients for _input, weight, bias, and target from the base class
|
69
76
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
70
|
-
# Return these gradients, followed by None for the remaining inputs
|
71
77
|
return *grads, None, None, None, None
|
72
78
|
|
73
79
|
|
@@ -13,12 +13,26 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
13
13
|
chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
|
14
14
|
):
|
15
15
|
"""
|
16
|
-
|
16
|
+
Paper: https://arxiv.org/pdf/2405.14734
|
17
|
+
|
18
|
+
Formula:
|
19
|
+
L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
|
20
|
+
|
21
|
+
Where:
|
22
|
+
- π_θ(y|x): Policy (model) probability
|
23
|
+
- y_w: Chosen sequence
|
24
|
+
- y_l: Rejected sequence
|
25
|
+
- |y_w|, |y_l|: Sequence lengths
|
26
|
+
- σ: Sigmoid function
|
27
|
+
- β: beta weight
|
28
|
+
- γ: gamma margin term
|
29
|
+
|
17
30
|
Args:
|
18
31
|
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
|
19
32
|
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
|
20
|
-
|
21
|
-
|
33
|
+
full_target: Non chunked full target tensor
|
34
|
+
beta (float): beta weight
|
35
|
+
gamma (float): gamma margin term
|
22
36
|
"""
|
23
37
|
logits = beta * (chosen_logps - rejected_logps) - gamma
|
24
38
|
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
|
@@ -38,12 +52,6 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
38
52
|
compiled=True,
|
39
53
|
gamma=0.5,
|
40
54
|
):
|
41
|
-
"""
|
42
|
-
Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734
|
43
|
-
Handles both the forward and backward pass of the final linear layer with SimPO loss.
|
44
|
-
Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
|
45
|
-
"""
|
46
|
-
|
47
55
|
return LigerFusedLinearPreferenceBase.forward(
|
48
56
|
ctx,
|
49
57
|
_input,
|
@@ -61,9 +69,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
|
|
61
69
|
|
62
70
|
@staticmethod
|
63
71
|
def backward(ctx, *grad_output):
|
64
|
-
# Get gradients for _input, weight, bias, and target from the base class
|
65
72
|
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
|
66
|
-
# Return these gradients, followed by None for the remaining inputs
|
67
73
|
return *grads, None, None, None, None, None, None
|
68
74
|
|
69
75
|
|
@@ -2,13 +2,13 @@ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
liger_kernel/env_report.py,sha256=1ETxx6HW4bKMK5aa5xaFzEmx0Ibc_kNryL_gXBVyyrI,1374
|
3
3
|
liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
|
4
4
|
liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
|
5
|
-
liger_kernel/chunked_loss/cpo_loss.py,sha256=
|
6
|
-
liger_kernel/chunked_loss/dpo_loss.py,sha256=
|
5
|
+
liger_kernel/chunked_loss/cpo_loss.py,sha256=Qu1Ul2A12sp6CqIT-atPbHWFb_LLtINEA9mOpIRx_0g,3097
|
6
|
+
liger_kernel/chunked_loss/dpo_loss.py,sha256=H9_RRhclckHYM2sd75tgbnf8IxC_PU2JCALbgtPQvwc,4222
|
7
7
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
8
8
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
|
9
|
-
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=
|
10
|
-
liger_kernel/chunked_loss/orpo_loss.py,sha256=
|
11
|
-
liger_kernel/chunked_loss/simpo_loss.py,sha256=
|
9
|
+
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vlWfaaIECWvCQhY9PM7zRI0vKThIrydMf6P44bXn1EE,15114
|
10
|
+
liger_kernel/chunked_loss/orpo_loss.py,sha256=ZuKGjbkIYzV4UzvupNdq6vyxCp7-BztQkUt8ZnFvKos,3531
|
11
|
+
liger_kernel/chunked_loss/simpo_loss.py,sha256=Wa4LOlDG9PbJkOOkKg8hbKvnKgg7OTBz6-qIkwPK1yw,3275
|
12
12
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
13
|
liger_kernel/ops/cross_entropy.py,sha256=VqaYB9Zirc51eZ28OmjEZRrrV9UysRjS_vhIftB9sKo,15753
|
14
14
|
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=Tnw4gyAYVVdnCOqhOuLEzbUQ3goOTnoAfk3pqSIM5ac,9301
|
@@ -56,9 +56,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
|
|
56
56
|
liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
|
57
57
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
58
58
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
59
|
-
liger_kernel_nightly-0.4.2.
|
60
|
-
liger_kernel_nightly-0.4.2.
|
61
|
-
liger_kernel_nightly-0.4.2.
|
62
|
-
liger_kernel_nightly-0.4.2.
|
63
|
-
liger_kernel_nightly-0.4.2.
|
64
|
-
liger_kernel_nightly-0.4.2.
|
59
|
+
liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
60
|
+
liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/METADATA,sha256=DXgBwRWN509ykIXn_83UuDRiwhZ-1RQPv4ubuieBXBA,22801
|
61
|
+
liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
62
|
+
liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
63
|
+
liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
64
|
+
liger_kernel_nightly-0.4.2.dev20241209234352.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|