liger-kernel-nightly 0.4.2.dev20241207202822__py3-none-any.whl → 0.4.2.dev20241209060920__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/fused_linear_distillation.py +250 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +101 -101
- {liger_kernel_nightly-0.4.2.dev20241207202822.dist-info → liger_kernel_nightly-0.4.2.dev20241209060920.dist-info}/METADATA +11 -4
- {liger_kernel_nightly-0.4.2.dev20241207202822.dist-info → liger_kernel_nightly-0.4.2.dev20241209060920.dist-info}/RECORD +8 -7
- {liger_kernel_nightly-0.4.2.dev20241207202822.dist-info → liger_kernel_nightly-0.4.2.dev20241209060920.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241207202822.dist-info → liger_kernel_nightly-0.4.2.dev20241209060920.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241207202822.dist-info → liger_kernel_nightly-0.4.2.dev20241209060920.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.4.2.dev20241207202822.dist-info → liger_kernel_nightly-0.4.2.dev20241209060920.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,250 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from functools import partial
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from torch.nn import functional as F
|
6
|
+
|
7
|
+
|
8
|
+
class LigerFusedLinearDistillationBase(torch.autograd.Function):
    """
    Base autograd Function for a chunked, fused "linear head + distillation loss".

    Student and teacher hidden states are projected through their output heads
    one chunk of rows at a time. For each chunk, a hard (cross-entropy) loss
    against the targets and a soft (distillation) loss between student and
    teacher logits are combined. The student-side gradients are computed eagerly
    inside ``forward`` with ``torch.func.grad_and_value``; ``backward`` only
    rescales those stored gradients. Subclasses provide the concrete
    distillation loss.
    """

    # NOTE: torch.autograd.Function does not use ABCMeta, so @abstractmethod is
    # not enforced at class creation; it only documents the subclass contract.
    @staticmethod
    @abstractmethod
    def distillation_loss_fn(student_logits, teacher_logits, temperature):
        """
        Compute distillation loss.

        Args:
            student_logits (torch.Tensor): Raw logits of student tokens for one chunk. Shape: (chunk_size, vocab_size).
            teacher_logits (torch.Tensor): Raw logits of teacher tokens for one chunk. Shape: (chunk_size, vocab_size).
            temperature (float): Temperature passed through unchanged from ``forward``.
        """
        raise NotImplementedError("Distillation loss function must be implemented.")

    @staticmethod
    def chunk_forward(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        ignore_index=-100,
        compute_ce_loss=True,
    ):
        """
        Project one chunk through the student and teacher heads and compute the summed CE loss.

        Returns:
            tuple: ``(student_logits_chunk, teacher_logits_chunk, ce_loss)`` where ``ce_loss``
            is the *summed* (not averaged) negative log-likelihood of ``target_chunk`` under the
            student (positions equal to ``ignore_index`` are skipped), or ``0.0`` when
            ``compute_ce_loss`` is False.
        """
        # Student: gradients flow through this projection.
        student_logits_chunk = student_input_chunk @ student_weight.t()
        if student_bias is not None:
            student_logits_chunk += student_bias

        # Teacher: computed without recording an autograd graph.
        with torch.no_grad():
            teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
            if teacher_bias is not None:
                teacher_logits_chunk += teacher_bias

        # The hard/task loss. The log-probabilities are only needed here, so they
        # are not computed at all when the CE term is disabled.
        ce_loss = 0.0
        if compute_ce_loss:
            # Upcast to float32 before log_softmax (presumably for numerical
            # stability with reduced-precision inputs).
            student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)
            ce_loss = F.nll_loss(
                student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
                target_chunk.view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )

        return student_logits_chunk, teacher_logits_chunk, ce_loss

    @staticmethod
    def _compute_loss(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        distillation_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        temperature=1.0,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        compute_ce_loss=True,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using a knowledge distillation loss function.

        Both the hard and the soft loss are divided by ``full_target.shape[0]`` (the number
        of rows across *all* chunks), so summing the per-chunk losses yields the batch loss.

        Args:
            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            distillation_loss_fn (callable): Called as ``distillation_loss_fn(student_logits, teacher_logits, temperature)``.
            full_target (torch.Tensor): Full (unchunked) target tensor. Shape: (batch_size * seq_len,).
            ignore_index (int): Index to ignore for loss computation.
            temperature (float): Passed through to ``distillation_loss_fn``.
            weight_hard_loss (float): Weight for hard loss.
            weight_soft_loss (float): Weight for soft loss.
            compute_ce_loss (bool): Whether to compute CE loss.
            loss_kwargs (dict): Additional arguments; currently accepted but not used.

        Returns:
            tuple: ``(loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk))``.
        """
        student_logits_chunk, teacher_logits_chunk, hard_loss = (
            LigerFusedLinearDistillationBase.chunk_forward(
                student_input_chunk,
                student_weight,
                teacher_input_chunk,
                teacher_weight,
                target_chunk,
                student_bias=student_bias,
                teacher_bias=teacher_bias,
                ignore_index=ignore_index,
                compute_ce_loss=compute_ce_loss,
            )
        )

        # NOTE(review): this divides by every row, including rows whose target is
        # ignore_index, unlike F.cross_entropy(reduction="mean") which divides by
        # the non-ignored count; confirm this normalization is intended.
        hard_loss /= full_target.shape[0]

        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
        soft_loss /= full_target.shape[0]

        loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
        return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)

    @staticmethod
    def forward(
        ctx,
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        target,
        student_bias=None,
        teacher_bias=None,
        loss_fn=None,
        chunk_size=1024,
        ignore_index=-100,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        compute_ce_loss=True,
        temperature=1.0,
        compiled=True,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with distillation loss.
        Only need to compute gradients for student model.

        Args:
            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size).
            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size).
            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size).
            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size).
            target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len).
            student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Distillation loss, called as ``loss_fn(student_logits, teacher_logits, temperature)`` per chunk.
            chunk_size (int): Size of a chunk.
            ignore_index (int): Index to ignore for loss computation.
            weight_hard_loss (float): Weight for hard/task loss.
            weight_soft_loss (float): Weight for soft/distillation loss.
            compute_ce_loss (bool): Whether to compute CE loss.
            temperature (float): Temperature passed to ``loss_fn``.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            loss_kwargs (dict): Forwarded to ``_compute_loss`` (which currently ignores them).

        Returns:
            torch.Tensor: Scalar loss summed over all chunks.
        """
        CHUNK_SIZE = chunk_size
        grad_weight = torch.zeros_like(student_weight)
        grad_inputs = []
        grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
        loss_acc = torch.zeros((), device=student_input.device)

        loss_func_to_call = partial(
            LigerFusedLinearDistillationBase._compute_loss,
            distillation_loss_fn=loss_fn,
            full_target=target,
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            compute_ce_loss=compute_ce_loss,
            temperature=temperature,
            **loss_kwargs,
        )

        # Differentiate w.r.t. student_input_chunk (0), student_weight (1) and,
        # when present, student_bias (positional index 5 of _compute_loss).
        argnums = (0, 1, 5) if student_bias is not None else (0, 1)

        def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
            """Run one chunk, accumulate weight/bias grads and loss in place, return its input grad."""
            grads, (chunk_loss, _chunk_aux) = torch.func.grad_and_value(
                loss_func_to_call, argnums=argnums, has_aux=True
            )(
                student_input_chunk,
                student_weight,
                teacher_input_chunk,
                teacher_weight,
                target_chunk,
                student_bias,
                teacher_bias,
            )
            chunk_grad_input, chunk_grad_weight = grads[0], grads[1]
            if student_bias is not None:
                grad_bias.add_(grads[2])
            grad_weight.add_(chunk_grad_weight)
            loss_acc.add_(chunk_loss)
            return chunk_grad_input

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
        _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
        _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
        _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0)

        for student_input_chunk, teacher_input_chunk, target_chunk in zip(
            _student_input_chunks, _teacher_input_chunks, _target_chunks
        ):
            grad_input = accumulate_chunk(
                student_input_chunk, teacher_input_chunk, target_chunk
            )
            grad_inputs.append(grad_input)

        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        return loss_acc

    @staticmethod
    def backward(ctx, grad_output):
        """
        Scale the gradients precomputed in ``forward`` by ``grad_output``.

        Returns one entry per leading ``forward`` input, in order:
        ``(student_input, student_weight, teacher_input, teacher_weight, target, student_bias)``.
        The teacher is not trained and the target is discrete, so those slots are None.
        Subclasses whose ``forward`` takes more inputs must append a None for each extra one.
        """
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            grad_bias = grad_bias * grad_output if grad_bias is not None else None

        # grad_bias must land in the student_bias slot (index 5), not in the
        # teacher_weight slot (index 3).
        return grad_input, grad_weight, None, None, None, grad_bias
|
@@ -64,6 +64,103 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
64
64
|
chosen_nll_loss,
|
65
65
|
)
|
66
66
|
|
67
|
+
@staticmethod
def _compute_loss(
    input_chunk,
    weight,
    target_chunk,
    bias=None,
    preference_loss_fn=None,
    full_target=None,
    ignore_index=-100,
    alpha=1.0,
    beta=0.1,
    compute_nll_loss=True,
    use_ref_model=False,
    ref_weight=None,
    ref_bias=None,
    **loss_kwargs,
):
    """
    Combine the chosen-response NLL term with a preference loss for one chunk.

    The first half of ``full_target`` is treated as the chosen responses when
    normalizing the NLL term. When ``use_ref_model`` is set, the reference
    model's chosen/rejected log-probabilities are injected into ``loss_kwargs``
    as ``ref_chosen_logps`` / ``ref_rejected_logps`` before calling
    ``preference_loss_fn``.

    Args:
        input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
        weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
        target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
        bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
        preference_loss_fn (callable): Called as ``preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs)``;
            may return a single loss or a tuple ``(loss, *extra_outputs)``.
        full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
        ignore_index (int): Index to ignore for loss computation.
        alpha (float): Weight for the NLL loss.
        beta (float): Forwarded to ``preference_loss_fn`` as its ``beta`` keyword.
        compute_nll_loss (bool): Whether to compute NLL loss.
        use_ref_model (bool): Whether to use a reference model for the alignment loss.
        ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
        ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
        loss_kwargs (dict): Additional arguments for the loss function.

    Returns:
        tuple: ``(loss, (chosen_logps, rejected_logps, chosen_logits_mean,
        rejected_logits_mean, chosen_nll_loss, *extra_outputs))``.
    """
    policy_outputs = LigerFusedLinearPreferenceBase.chunk_forward(
        input_chunk,
        weight,
        target_chunk,
        bias=bias,
        ignore_index=ignore_index,
        compute_nll_loss=compute_nll_loss,
    )
    chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss = policy_outputs

    half = full_target.shape[0] // 2

    # Normalize the NLL by the count of non-ignored chosen target tokens in the full batch.
    num_chosen_tokens = (full_target[:half] != ignore_index).sum()
    chosen_nll_loss = chosen_nll_loss / num_chosen_tokens

    # Mean logit over (half batch) x (sequence length) x (vocab size).
    logits_denominator = half * input_chunk.shape[1] * weight.shape[0]
    chosen_logits_mean = chosen_logits.sum() / logits_denominator
    rejected_logits_mean = rejected_logits.sum() / logits_denominator

    if use_ref_model:
        # The reference model contributes no gradients and needs no NLL term.
        with torch.no_grad():
            ref_chosen_logps, ref_rejected_logps, _, _, _ = (
                LigerFusedLinearPreferenceBase.chunk_forward(
                    input_chunk,
                    ref_weight,
                    target_chunk,
                    ref_bias,
                    ignore_index=ignore_index,
                    compute_nll_loss=False,
                )
            )
        loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
        loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

    outputs = preference_loss_fn(
        chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
    )
    if isinstance(outputs, tuple):
        preference_loss, *aux_outputs = outputs
    else:
        preference_loss = outputs
        aux_outputs = []

    loss = alpha * chosen_nll_loss - preference_loss
    return loss, (
        chosen_logps,
        rejected_logps,
        chosen_logits_mean,
        rejected_logits_mean,
        chosen_nll_loss,
        *aux_outputs,
    )
|
163
|
+
|
67
164
|
@staticmethod
|
68
165
|
def forward(
|
69
166
|
ctx,
|
@@ -134,7 +231,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
134
231
|
**loss_kwargs,
|
135
232
|
)
|
136
233
|
|
137
|
-
def
|
234
|
+
def accumulate_core(input_chunk, target_chunk):
|
138
235
|
if bias is not None:
|
139
236
|
return torch.func.grad_and_value(
|
140
237
|
loss_func_to_call, argnums=(0, 1, 3), has_aux=True
|
@@ -156,7 +253,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
156
253
|
chunk_nll_loss,
|
157
254
|
*aux_outputs,
|
158
255
|
),
|
159
|
-
) =
|
256
|
+
) = accumulate_core(input_chunk, target_chunk)
|
160
257
|
grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
|
161
258
|
else:
|
162
259
|
(chunk_grad_input, chunk_grad_weight), (
|
@@ -169,7 +266,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
169
266
|
chunk_nll_loss,
|
170
267
|
*aux_outputs,
|
171
268
|
),
|
172
|
-
) =
|
269
|
+
) = accumulate_core(input_chunk, target_chunk)
|
173
270
|
|
174
271
|
grad_weight.add_(chunk_grad_weight)
|
175
272
|
loss_acc.add_(chunk_loss)
|
@@ -199,7 +296,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
199
296
|
return chunk_grad_input
|
200
297
|
|
201
298
|
if compiled:
|
202
|
-
|
299
|
+
accumulate_core = torch.compile(accumulate_core)
|
203
300
|
|
204
301
|
len_chosen = target.shape[0] // 2
|
205
302
|
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
|
@@ -270,100 +367,3 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
270
367
|
grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
|
271
368
|
|
272
369
|
return grad_input, grad_weight, None, grad_bias, None, None, None
|
273
|
-
|
274
|
-
@staticmethod
def _compute_loss(
    input_chunk,
    weight,
    target_chunk,
    bias=None,
    preference_loss_fn=None,
    full_target=None,
    ignore_index=-100,
    alpha=1.0,
    beta=0.1,
    compute_nll_loss=True,
    use_ref_model=False,
    ref_weight=None,
    ref_bias=None,
    **loss_kwargs,
):
    """
    Combine the chosen-response NLL term with a preference loss for one chunk.

    The first half of ``full_target`` is treated as the chosen responses when
    normalizing the NLL term. When ``use_ref_model`` is set, the reference
    model's chosen/rejected log-probabilities are injected into ``loss_kwargs``
    as ``ref_chosen_logps`` / ``ref_rejected_logps`` before calling
    ``preference_loss_fn``.

    Args:
        input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
        weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
        target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
        bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
        preference_loss_fn (callable): Called as ``preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs)``;
            may return a single loss or a tuple ``(loss, *extra_outputs)``.
        full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
        ignore_index (int): Index to ignore for loss computation.
        alpha (float): Weight for the NLL loss.
        beta (float): Forwarded to ``preference_loss_fn`` as its ``beta`` keyword.
        compute_nll_loss (bool): Whether to compute NLL loss.
        use_ref_model (bool): Whether to use a reference model for the alignment loss.
        ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
        ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
        loss_kwargs (dict): Additional arguments for the loss function.

    Returns:
        tuple: ``(loss, (chosen_logps, rejected_logps, chosen_logits_mean,
        rejected_logits_mean, chosen_nll_loss, *extra_outputs))``.
    """
    policy_outputs = LigerFusedLinearPreferenceBase.chunk_forward(
        input_chunk,
        weight,
        target_chunk,
        bias=bias,
        ignore_index=ignore_index,
        compute_nll_loss=compute_nll_loss,
    )
    chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss = policy_outputs

    half = full_target.shape[0] // 2

    # Normalize the NLL by the count of non-ignored chosen target tokens in the full batch.
    num_chosen_tokens = (full_target[:half] != ignore_index).sum()
    chosen_nll_loss = chosen_nll_loss / num_chosen_tokens

    # Mean logit over (half batch) x (sequence length) x (vocab size).
    logits_denominator = half * input_chunk.shape[1] * weight.shape[0]
    chosen_logits_mean = chosen_logits.sum() / logits_denominator
    rejected_logits_mean = rejected_logits.sum() / logits_denominator

    if use_ref_model:
        # The reference model contributes no gradients and needs no NLL term.
        with torch.no_grad():
            ref_chosen_logps, ref_rejected_logps, _, _, _ = (
                LigerFusedLinearPreferenceBase.chunk_forward(
                    input_chunk,
                    ref_weight,
                    target_chunk,
                    ref_bias,
                    ignore_index=ignore_index,
                    compute_nll_loss=False,
                )
            )
        loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
        loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

    outputs = preference_loss_fn(
        chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
    )
    if isinstance(outputs, tuple):
        preference_loss, *aux_outputs = outputs
    else:
        preference_loss = outputs
        aux_outputs = []

    loss = alpha * chosen_nll_loss - preference_loss
    return loss, (
        chosen_logps,
        rejected_logps,
        chosen_logits_mean,
        rejected_logits_mean,
        chosen_nll_loss,
        *aux_outputs,
    )
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: liger_kernel_nightly
|
3
|
-
Version: 0.4.2.dev20241207202822
|
3
|
+
Version: 0.4.2.dev20241209060920
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
@@ -85,9 +85,16 @@ Requires-Dist: transformers~=4.0; extra == "transformers"
|
|
85
85
|
</a>
|
86
86
|
</td>
|
87
87
|
<td style="padding: 10px;">
|
88
|
-
<
|
89
|
-
<
|
90
|
-
|
88
|
+
<div style="display: block;">
|
89
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
|
90
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
|
91
|
+
</a>
|
92
|
+
</div>
|
93
|
+
<div style="display: block;">
|
94
|
+
<a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
|
95
|
+
<img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg" alt="Build">
|
96
|
+
</a>
|
97
|
+
</div>
|
91
98
|
</td>
|
92
99
|
</tr>
|
93
100
|
</table>
|
@@ -5,7 +5,8 @@ liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMH
|
|
5
5
|
liger_kernel/chunked_loss/cpo_loss.py,sha256=P20txjErLCSfSfToFT8pnuVPqFU4Bbybt3zRXfGEV-0,3122
|
6
6
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=NZyM4ju56MBVrUTI_7-jGMx5pWWDYzwx7ALoMj1G8Ec,4276
|
7
7
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
8
|
-
liger_kernel/chunked_loss/
|
8
|
+
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=8Dvr5FRQcWEpNYu615GxvAxlmh-0cDnUKLet274nxTQ,10200
|
9
|
+
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=_4MDZMzrNNgm91c6qdLEuXG1M8HyglZioiufv5opJOI,14881
|
9
10
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=GGwc3pLGGJzb_P_C7IogcA1EfdAcM1uktfKPmI1z2jk,3523
|
10
11
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=FtURWbXGjoAKyiVYF7fkMv8Us7uk3UrSg21pWOFk11Y,3385
|
11
12
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -55,9 +56,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
|
|
55
56
|
liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
|
56
57
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
57
58
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
58
|
-
liger_kernel_nightly-0.4.2.
|
59
|
-
liger_kernel_nightly-0.4.2.
|
60
|
-
liger_kernel_nightly-0.4.2.
|
61
|
-
liger_kernel_nightly-0.4.2.
|
62
|
-
liger_kernel_nightly-0.4.2.
|
63
|
-
liger_kernel_nightly-0.4.2.
|
59
|
+
liger_kernel_nightly-0.4.2.dev20241209060920.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
60
|
+
liger_kernel_nightly-0.4.2.dev20241209060920.dist-info/METADATA,sha256=JVpueVyzRkHvUmIZUd28SLap5uhYe0kD_UDc-zkuxSI,22339
|
61
|
+
liger_kernel_nightly-0.4.2.dev20241209060920.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
62
|
+
liger_kernel_nightly-0.4.2.dev20241209060920.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
63
|
+
liger_kernel_nightly-0.4.2.dev20241209060920.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
64
|
+
liger_kernel_nightly-0.4.2.dev20241209060920.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|