liger-kernel-nightly 0.4.2.dev20241209195823__py3-none-any.whl → 0.4.2.dev20241209224333__py3-none-any.whl
This diff compares two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/fused_linear_preference.py +181 -164
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209224333.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209224333.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209224333.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209224333.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209224333.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209224333.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_preference.py

@@ -8,159 +8,12 @@ from torch.nn import functional as F
 class LigerFusedLinearPreferenceBase(torch.autograd.Function):
 
     @abstractmethod
-    def preference_loss_fn(
+    def preference_loss_fn(*args, **kwargs):
         """
-
-        Args:
-            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-            beta (float): Weight for the odds ratio loss.
+        To be extended by subclasses.
         """
         raise NotImplementedError("Preference loss function must be implemented.")
 
-    @staticmethod
-    def chunk_forward(
-        input_chunk,
-        weight,
-        target_chunk,
-        bias=None,
-        ignore_index=-100,
-        compute_nll_loss=True,
-    ):
-        len_chosen_chunk = target_chunk.shape[0] // 2
-        logits_chunk = input_chunk @ weight.t()
-        if bias is not None:
-            logits_chunk = logits_chunk + bias
-        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
-
-        chosen_nll_loss = 0.0
-        if compute_nll_loss:
-            chosen_nll_loss = F.nll_loss(
-                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                target_chunk[:len_chosen_chunk].view(-1),
-                reduction="sum",
-                ignore_index=ignore_index,
-            )
-
-        loss_mask = target_chunk != ignore_index
-        label_chunk = torch.where(loss_mask, target_chunk, 0)
-
-        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-            -1
-        )
-        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-
-        chosen_logps = average_log_prob[:len_chosen_chunk]
-        rejected_logps = average_log_prob[len_chosen_chunk:]
-
-        chosen_logits = logits_chunk[:len_chosen_chunk]
-        rejected_logits = logits_chunk[len_chosen_chunk:]
-
-        return (
-            chosen_logps,
-            rejected_logps,
-            chosen_logits,
-            rejected_logits,
-            chosen_nll_loss,
-        )
-
-    @staticmethod
-    def _compute_loss(
-        input_chunk,
-        weight,
-        target_chunk,
-        bias=None,
-        preference_loss_fn=None,
-        full_target=None,
-        ignore_index=-100,
-        alpha=1.0,
-        beta=0.1,
-        compute_nll_loss=True,
-        use_ref_model=False,
-        ref_weight=None,
-        ref_bias=None,
-        **loss_kwargs,
-    ):
-        """
-        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
-        Args:
-            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
-            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
-            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
-            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
-            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
-            ignore_index (int): Index to ignore for loss computation.
-            alpha (float): Weight for the NLL loss.
-            beta (float): Weight for the odds ratio loss.
-            compute_nll_loss (bool): Whether to compute NLL loss.
-            use_ref_model (bool): Whether to use a reference model for the alignment loss.
-            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
-            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
-            loss_kwargs (dict): Additional arguments for the loss function.
-        """
-        (
-            chosen_logps,
-            rejected_logps,
-            chosen_logits,
-            rejected_logits,
-            chosen_nll_loss,
-        ) = LigerFusedLinearPreferenceBase.chunk_forward(
-            input_chunk,
-            weight,
-            target_chunk,
-            bias=bias,
-            ignore_index=ignore_index,
-            compute_nll_loss=compute_nll_loss,
-        )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-        )
-        chosen_logits_mean = chosen_logits.sum() / (
-            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-        )
-        rejected_logits_mean = rejected_logits.sum() / (
-            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-        )
-
-        if use_ref_model:
-            with torch.no_grad():
-                (
-                    ref_chosen_logps,
-                    ref_rejected_logps,
-                    ref_chosen_logits,
-                    ref_rejected_logits,
-                    ref_chosen_nll_loss,
-                ) = LigerFusedLinearPreferenceBase.chunk_forward(
-                    input_chunk,
-                    ref_weight,
-                    target_chunk,
-                    ref_bias,
-                    ignore_index=ignore_index,
-                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
-                )
-                loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
-                loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
-
-        preference_loss_outputs = preference_loss_fn(
-            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
-        )
-        if isinstance(preference_loss_outputs, tuple):
-            preference_loss, *aux_outputs = preference_loss_outputs
-        else:
-            preference_loss, aux_outputs = preference_loss_outputs, []
-
-        loss = alpha * chosen_nll_loss - preference_loss
-        return_vars = (
-            chosen_logps,
-            rejected_logps,
-            chosen_logits_mean,
-            rejected_logits_mean,
-            chosen_nll_loss,
-        )
-        return loss, (*return_vars, *aux_outputs)
-
     @staticmethod
     def forward(
         ctx,
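The abstract `preference_loss_fn` now accepts `*args, **kwargs` and defers the concrete signature to subclasses. As a rough sketch of that contract (hypothetical, not a loss shipped in this wheel), a DPO-flavored implementation could be written against the call site visible in `_compute_loss`, which calls `preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs)` and then subtracts the result in `loss = alpha * chosen_nll_loss - preference_loss`:

```python
import torch
import torch.nn.functional as F


def dpo_style_preference_loss(
    chosen_logps,            # (batch_size,) avg log probs of chosen tokens
    rejected_logps,          # (batch_size,) avg log probs of rejected tokens
    full_target,             # (batch_size, seq_len), used only for normalization
    beta=0.1,                # temperature on the preference margin
    ref_chosen_logps=None,   # injected via loss_kwargs when use_ref_model=True
    ref_rejected_logps=None,
):
    if ref_chosen_logps is None:
        ref_chosen_logps = torch.zeros_like(chosen_logps)
    if ref_rejected_logps is None:
        ref_rejected_logps = torch.zeros_like(rejected_logps)
    # Policy-vs-reference margin between chosen and rejected completions.
    margin = (chosen_logps - ref_chosen_logps) - (rejected_logps - ref_rejected_logps)
    # The base class subtracts this value from the weighted NLL term,
    # so return the quantity to be maximized.
    return F.logsigmoid(beta * margin).sum() / (full_target.shape[0] // 2)
```

Returning the term to be maximized keeps the sign convention of `_compute_loss` consistent, since the base class subtracts the preference term from the weighted NLL.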
@@ -176,6 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         compute_nll_loss=True,
         compiled=True,
         use_ref_model=False,
+        # TODO: ref input
         ref_weight=None,
         ref_bias=None,
         **loss_kwargs,
@@ -184,6 +38,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         Base class for fused linear layer with preference loss.
         Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
 
+        The mental model is:
+
+        forward()
+        ├── Loop over chunks
+        └── compute_loss()
+            ├── chunk_forward()  # Compute logits and log probs
+            └── prefer_loss()  # Calculate preference loss
+
         Args:
             _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
@@ -191,10 +53,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
-            compute_nll_loss (bool): Whether to compute NLL loss.
             ignore_index (int): Index to ignore for loss computation.
             alpha (float): Weight for the NLL loss.
-            beta (float): Weight for the
+            beta (float): Weight for the preference loss.
             compute_nll_loss (bool): Whether to compute NLL loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
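For intuition, the stacked layout the docstring above describes could be constructed like this (a hypothetical setup; all names and shapes are illustrative, not Liger API):

```python
import torch

batch, seq_len, hidden, vocab = 2, 16, 64, 100

chosen = torch.randn(batch, seq_len, hidden)      # hidden states of chosen responses
rejected = torch.randn(batch, seq_len, hidden)    # hidden states of rejected responses
_input = torch.cat([chosen, rejected], dim=0)     # (2 * batch, seq_len, hidden)

chosen_target = torch.randint(0, vocab, (batch, seq_len))
rejected_target = torch.randint(0, vocab, (batch, seq_len))
target = torch.cat([chosen_target, rejected_target], dim=0)  # (2 * batch, seq_len)

weight = torch.randn(vocab, hidden)  # lm_head weight: (vocab_size, hidden_size)
```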
@@ -205,11 +66,16 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
         CHUNK_SIZE = chunk_size
 
+        # Gradients to be accumulated
         grad_weight = torch.zeros_like(weight)
         grad_chosen_inputs = []
         grad_rejected_inputs = []
         grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+        # Loss to be accumulated
         loss_acc = torch.zeros((), device=_input.device)
+
+        # Metrics to be recorded
         policy_chosen_logps = []
         policy_rejected_logps = []
         policy_chosen_logits_mean = torch.zeros((), device=_input.device)
@@ -217,7 +83,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         policy_nll_loss = torch.zeros((), device=_input.device)
         aggregated_aux_outputs = []  # aggregated aux outputs from all chunks
 
-
+        compute_loss = partial(
             LigerFusedLinearPreferenceBase._compute_loss,
             preference_loss_fn=loss_fn,
             ignore_index=ignore_index,
@@ -231,14 +97,17 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             **loss_kwargs,
         )
 
-        def
+        def fused_fwd_bwd(input_chunk, target_chunk):
+            """
+            Fused forward and backward pass for a chunk of input and target.
+            """
             if bias is not None:
                 return torch.func.grad_and_value(
-
+                    compute_loss, argnums=(0, 1, 3), has_aux=True
                 )(input_chunk, weight, target_chunk, bias)
             else:
                 return torch.func.grad_and_value(
-
+                    compute_loss, argnums=(0, 1), has_aux=True
                 )(input_chunk, weight, target_chunk)
 
         def accumulate_chunk(input_chunk, target_chunk):
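The heart of the refactor is `fused_fwd_bwd`: `torch.func.grad_and_value` turns the per-chunk loss into a single call that returns gradients and the loss together, and the real code additionally passes `has_aux=True` so metrics ride along. A minimal, self-contained sketch of the same pattern, with a toy MSE loss standing in for `_compute_loss` (names and shapes are illustrative only):

```python
import torch


def chunk_loss(input_chunk, weight, target_chunk):
    # Stand-in for _compute_loss: a toy linear projection plus MSE.
    logits = input_chunk @ weight.t()            # (chunk, out_features)
    return ((logits - target_chunk) ** 2).sum()


def fused_fwd_bwd(input_chunk, weight, target_chunk):
    # argnums=(0, 1): differentiate w.r.t. input_chunk and weight, mirroring
    # the code above (which adds argnum 3 for the bias and has_aux=True).
    return torch.func.grad_and_value(chunk_loss, argnums=(0, 1))(
        input_chunk, weight, target_chunk
    )


fused_fwd_bwd = torch.compile(fused_fwd_bwd)

weight = torch.randn(4, 8)
grad_weight = torch.zeros_like(weight)  # gradient buffer, filled chunk by chunk
loss_acc = torch.zeros(())              # scalar loss accumulator

inputs = torch.randn(3, 2, 8)           # 3 chunks of batch size 2
targets = torch.randn(3, 2, 4)
for input_chunk, target_chunk in zip(inputs, targets):
    (grad_input, grad_w), loss = fused_fwd_bwd(input_chunk, weight, target_chunk)
    grad_weight.add_(grad_w)
    loss_acc.add_(loss)
```

Accumulating into preallocated buffers chunk by chunk is what keeps peak memory bounded: each chunk's full-vocabulary logits can be freed before the next chunk is materialized.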
@@ -253,7 +122,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                         chunk_nll_loss,
                         *aux_outputs,
                     ),
-                ) =
+                ) = fused_fwd_bwd(input_chunk, target_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
                 (chunk_grad_input, chunk_grad_weight), (
@@ -266,16 +135,26 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                         chunk_nll_loss,
                         *aux_outputs,
                     ),
-                ) =
+                ) = fused_fwd_bwd(input_chunk, target_chunk)
 
+            # Accumulate gradients
             grad_weight.add_(chunk_grad_weight)
+            grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
+            grad_rejected_inputs.append(
+                chunk_grad_input[chosen_target_chunk.shape[0] :]
+            )
+
+            # Accumulate loss
             loss_acc.add_(chunk_loss)
+
+            # Accumulate metrics
             policy_chosen_logps.append(chunk_chosen_logps)
             policy_rejected_logps.append(chunk_rejected_logps)
             policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
             policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
             policy_nll_loss.add_(chunk_nll_loss)
 
+            # aux_outputs
             # Initialize storage for aux_outputs
             if len(aggregated_aux_outputs) == 0:
                 for aux in aux_outputs:
@@ -293,10 +172,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 else:
                     aggregated_aux_outputs[i].append(aux)
 
-            return chunk_grad_input
-
         if compiled:
-
+            fused_fwd_bwd = torch.compile(fused_fwd_bwd)
 
         len_chosen = target.shape[0] // 2
         chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
@@ -327,10 +204,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             torch._dynamo.mark_dynamic(target, 1)
 
             # accumulate loss, gradients, and metrics
-
-
-            grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
-            grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])
+            accumulate_chunk(input_chunk, target_chunk)
 
         # combine grad_chosen_inputs and grad_rejected_inputs
         grad_inputs = grad_chosen_inputs + grad_rejected_inputs
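A detail worth noting in the surrounding context: `torch._dynamo.mark_dynamic(target, 1)` marks the sequence-length dimension as dynamic so `torch.compile` does not recompile for each new shape. A toy illustration of the idea (assumed behavior, independent of the Liger code):

```python
import torch


def chunk_sum(x):
    return (x * 2).sum()


compiled_chunk_sum = torch.compile(chunk_sum)

for seq_len in (120, 128, 136):
    x = torch.randn(4, seq_len)
    # Without this, each new seq_len can trigger a fresh compilation;
    # marking dim 1 dynamic lets one trace handle a symbolic size.
    torch._dynamo.mark_dynamic(x, 1)
    compiled_chunk_sum(x)
```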
@@ -367,3 +241,146 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
 
         return grad_input, grad_weight, None, grad_bias, None, None, None
+
+    @staticmethod
+    def chunk_forward(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        ignore_index=-100,
+        compute_nll_loss=True,
+    ):
+        len_chosen_chunk = target_chunk.shape[0] // 2
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        chosen_nll_loss = 0.0
+        if compute_nll_loss:
+            chosen_nll_loss = F.nll_loss(
+                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                target_chunk[:len_chosen_chunk].view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+
+        loss_mask = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+            -1
+        )
+        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+        chosen_logps = average_log_prob[:len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk:]
+
+        chosen_logits = logits_chunk[:len_chosen_chunk]
+        rejected_logits = logits_chunk[len_chosen_chunk:]
+
+        return (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        )
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        alpha=1.0,
+        beta=0.1,
+        compute_nll_loss=True,
+        use_ref_model=False,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            alpha (float): Weight for the NLL loss.
+            beta (float): Weight for the preference loss.
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        ) = LigerFusedLinearPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+            compute_nll_loss=compute_nll_loss,
+        )
+        chosen_nll_loss = (
+            chosen_nll_loss
+            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        )
+        chosen_logits_mean = chosen_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
+        rejected_logits_mean = rejected_logits.sum() / (
+            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+        )
+
+        if use_ref_model:
+            with torch.no_grad():
+                (
+                    ref_chosen_logps,
+                    ref_rejected_logps,
+                    ref_chosen_logits,
+                    ref_rejected_logits,
+                    ref_chosen_nll_loss,
+                ) = LigerFusedLinearPreferenceBase.chunk_forward(
+                    input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                    compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                )
+                loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
+                loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
+
+        preference_loss_outputs = preference_loss_fn(
+            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
+        )
+        if isinstance(preference_loss_outputs, tuple):
+            preference_loss, *aux_outputs = preference_loss_outputs
+        else:
+            preference_loss, aux_outputs = preference_loss_outputs, []
+
+        loss = alpha * chosen_nll_loss - preference_loss
+        return_vars = (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits_mean,
+            rejected_logits_mean,
+            chosen_nll_loss,
+        )
+        return loss, (*return_vars, *aux_outputs)
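The log-probability gathering that `chunk_forward` performs is compact enough to miss; here is a standalone toy rendering of those few lines (hypothetical shapes):

```python
import torch
import torch.nn.functional as F

# For every position, pick out the log-probability of the target token,
# mask padding, and average over the valid tokens of each sequence.
ignore_index = -100
logits = torch.randn(4, 6, 10)            # (batch, seq_len, vocab_size)
targets = torch.randint(0, 10, (4, 6))    # (batch, seq_len)
targets[:, -2:] = ignore_index            # pretend the tail is padding

log_probs = F.log_softmax(logits.float(), dim=-1)
loss_mask = targets != ignore_index
# Clamp ignored positions to a valid index so gather stays in range;
# the mask removes their contribution afterwards.
labels = torch.where(loss_mask, targets, 0)
per_token_logps = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
print(average_log_prob.shape)  # torch.Size([4]): one avg log prob per sequence
```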
{liger_kernel_nightly-0.4.2.dev20241209195823.dist-info → liger_kernel_nightly-0.4.2.dev20241209224333.dist-info}/RECORD

@@ -6,7 +6,7 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=P20txjErLCSfSfToFT8pnuVPqFU4Bbybt3z
 liger_kernel/chunked_loss/dpo_loss.py,sha256=NZyM4ju56MBVrUTI_7-jGMx5pWWDYzwx7ALoMj1G8Ec,4276
 liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
-liger_kernel/chunked_loss/fused_linear_preference.py,sha256=
+liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vlWfaaIECWvCQhY9PM7zRI0vKThIrydMf6P44bXn1EE,15114
 liger_kernel/chunked_loss/orpo_loss.py,sha256=GGwc3pLGGJzb_P_C7IogcA1EfdAcM1uktfKPmI1z2jk,3523
 liger_kernel/chunked_loss/simpo_loss.py,sha256=FtURWbXGjoAKyiVYF7fkMv8Us7uk3UrSg21pWOFk11Y,3385
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,9 +56,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
 liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
 liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
 liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
-liger_kernel_nightly-0.4.2.
-liger_kernel_nightly-0.4.2.
-liger_kernel_nightly-0.4.2.
-liger_kernel_nightly-0.4.2.
-liger_kernel_nightly-0.4.2.
-liger_kernel_nightly-0.4.2.
+liger_kernel_nightly-0.4.2.dev20241209224333.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.4.2.dev20241209224333.dist-info/METADATA,sha256=zK3s59xLwgMnS0ImjXWqXz5jVbQHJ5UvV_v5e0xSbbk,22801
+liger_kernel_nightly-0.4.2.dev20241209224333.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.4.2.dev20241209224333.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.4.2.dev20241209224333.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.4.2.dev20241209224333.dist-info/RECORD,,