liger-kernel-nightly 0.5.2.dev20250122005057__py3-none-any.whl → 0.5.2.dev20250124002122__py3-none-any.whl
- liger_kernel/chunked_loss/README.md +1 -1
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- {liger_kernel_nightly-0.5.2.dev20250122005057.dist-info → liger_kernel_nightly-0.5.2.dev20250124002122.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20250122005057.dist-info → liger_kernel_nightly-0.5.2.dev20250124002122.dist-info}/RECORD +11 -9
- {liger_kernel_nightly-0.5.2.dev20250122005057.dist-info → liger_kernel_nightly-0.5.2.dev20250124002122.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250122005057.dist-info → liger_kernel_nightly-0.5.2.dev20250124002122.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250122005057.dist-info → liger_kernel_nightly-0.5.2.dev20250124002122.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20250122005057.dist-info → liger_kernel_nightly-0.5.2.dev20250124002122.dist-info}/top_level.txt +0 -0

liger_kernel/chunked_loss/README.md

```diff
@@ -1,6 +1,6 @@
 # Liger FlexChunkLoss: Alignment and Distillation loss
 
-Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
+Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
 
 ### User interface
 
```

liger_kernel/chunked_loss/__init__.py

```diff
@@ -1,4 +1,5 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss  # noqa: F401
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss  # noqa: F401
```

liger_kernel/chunked_loss/functional.py

```diff
@@ -1,5 +1,6 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 
@@ -7,3 +8,4 @@ liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
+liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
```
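
The headline change in this nightly is the KTO (Kahneman-Tversky Optimization) chunked loss: a new unpaired-preference base class plus a `LigerFusedLinearKTOLoss` module built on it, exported here through the functional alias `liger_fused_linear_kto`. The alias is simply `LigerFusedLinearKTOFunction.apply`, which the module itself calls with all arguments positional in the order of `LigerFusedLinearKTOFunction.forward`. A minimal sketch of calling it directly follows; the shapes, tensor names, and hyperparameter values are purely illustrative and not taken from the package:

```python
import torch

from liger_kernel.chunked_loss.functional import liger_fused_linear_kto

# Toy dimensions: batch, sequence length, hidden size, vocab size.
B, T, H, V = 4, 8, 32, 64
hidden = torch.randn(B, T, H, requires_grad=True)   # policy last hidden states
lm_head_weight = torch.randn(V, H)                  # policy lm_head weight
ref_hidden = torch.randn(B, T, H)                   # reference-model hidden states
ref_lm_head_weight = torch.randn(V, H)              # reference lm_head weight
target = torch.randint(0, V, (B, T))                # label token ids
preference_labels = torch.tensor([True, False, True, False])  # chosen vs rejected

# Positional order follows LigerFusedLinearKTOFunction.forward:
# (_input, weight, target, preference_labels, bias, ref_input, ref_weight,
#  ref_bias, kl, ignore_index, beta, compiled, use_ref_model)
loss = liger_fused_linear_kto(
    hidden,
    lm_head_weight,
    target,
    preference_labels,
    None,                # bias
    ref_hidden,          # ref_input
    ref_lm_head_weight,  # ref_weight
    None,                # ref_bias
    None,                # kl
    -100,                # ignore_index
    0.1,                 # beta
    False,               # compiled: skip torch.compile for this toy run
    True,                # use_ref_model
)
loss.backward()
```

For most users, the `LigerFusedLinearKTOLoss` module introduced at the end of this diff is the more convenient entry point, since it accepts keyword arguments.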

liger_kernel/chunked_loss/fused_linear_unpaired_preference.py (new file)

```diff
@@ -0,0 +1,246 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+
+from torch.nn import functional as F
+
+
+class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
+    @abstractmethod
+    def preference_loss_fn(*args, **kwargs):
+        """
+        To be extended by subclasses.
+        """
+        raise NotImplementedError("Preference loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        preference_labels,
+        bias=None,
+        loss_fn=None,
+        chunk_size=1,
+        ignore_index=-100,
+        compiled=True,
+        use_ref_model=False,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Base class for fused linear layer with unpaired preference loss like KTO
+        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
+
+        The mental model is:
+
+        forward()
+        ├── Loop over chunks
+        └── compute_loss()
+            ├── chunk_forward()  # Compute logits and log probs
+            └── prefer_loss()  # Calculate preference loss
+
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the preference loss.
+            compiled (bool): Whether to use torch compile for chunk accumulation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples.
+                Shape: (batch_size,).
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Other possible arguments that a loss function might need
+        """
+        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
+        CHUNK_SIZE = chunk_size
+
+        # Gradients to be accumulated
+        grad_inputs = []
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+        # Loss to be accumulated
+        loss_acc = torch.zeros((), device=_input.device)
+
+        compute_loss = partial(
+            LigerFusedLinearUnpairedPreferenceBase._compute_loss,
+            preference_loss_fn=loss_fn,
+            full_target=target,
+            ignore_index=ignore_index,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            **loss_kwargs,
+        )
+
+        def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk):
+            """
+            Fused forward and backward pass for a chunk of input and target.
+            """
+            argnums = (0, 1, 4) if bias is not None else (0, 1)
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=False)(
+                input_chunk,
+                weight,
+                target_chunk,
+                preference_labels_chunk,
+                bias,
+                ref_input_chunk=ref_input_chunk,
+            )
+
+        def accumulate_chunk(
+            input_chunk,
+            target_chunk,
+            preference_labels_chunk=None,
+            ref_input_chunk=None,
+        ):
+            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd(
+                input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk
+            )
+            if bias is not None:
+                grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient
+
+            # Accumulate gradients
+            grad_weight.add_(chunk_grad_weight)
+            grad_inputs.append(chunk_grad_input)
+
+            # Accumulate loss
+            loss_acc.add_(chunk_loss)
+
+        if compiled:
+            fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+        # When not paired, use labels to separate chosen and rejected
+        assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss"
+
+        chunks = max(1, _input.shape[0] // CHUNK_SIZE)
+        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+        _target_chunks = torch.chunk(target, chunks=chunks, dim=0)
+        _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0)
+
+        if use_ref_model:
+            _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0)
+
+        for (
+            input_chunk,
+            target_chunk,
+            ref_input_chunk,
+            preference_labels_chunk,
+        ) in zip(
+            _input_chunks,
+            _target_chunks,
+            (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)),
+            _preference_labels_chunks,
+        ):
+            # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(target_chunk, 1)
+            torch._dynamo.mark_dynamic(target, 1)
+            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+            torch._dynamo.mark_dynamic(preference_labels_chunk, 1)
+
+            # accumulate loss, gradients, and metrics
+            accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
+
+        ctx.save_for_backward(
+            torch.cat(grad_inputs, dim=0),
+            grad_weight,
+            grad_bias,
+        )
+        return loss_acc
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
+            grad_input = grad_input * grad_output[0][0]
+            grad_weight = grad_weight * grad_output[0][0]
+            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
+
+        return grad_input, grad_weight, None, None, grad_bias
+
+    @staticmethod
+    def chunk_forward(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        ignore_index=-100,
+    ):
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        loss_mask_chunk = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
+
+        per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
+        average_log_prob_chunk = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
+
+        return average_log_prob_chunk
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        preference_labels_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        use_ref_model=False,
+        ref_input_chunk=None,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+        )
+
+        if use_ref_model:
+            with torch.no_grad():
+                ref_average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+                    ref_input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                )
+            loss_kwargs["ref_average_log_prob_chunk"] = ref_average_log_prob_chunk
+
+        preference_loss_chunk = preference_loss_fn(
+            average_log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+        )
+
+        return preference_loss_chunk
```
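
The base class is designed as an extension point: `preference_loss_fn` is the single abstract hook, while chunking, the fused per-chunk forward/backward via `torch.func.grad_and_value`, optional `torch.compile`, and gradient accumulation are all inherited. As a rough sketch of how a custom unpaired loss could plug in (the class name `MyUnpairedLossFunction`, the toy log-sigmoid objective, and the `beta` pass-through below are invented for illustration and are not part of the package), a subclass only has to supply the loss plus thin `forward`/`backward` wrappers:

```python
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_unpaired_preference import (
    LigerFusedLinearUnpairedPreferenceBase,
)


class MyUnpairedLossFunction(LigerFusedLinearUnpairedPreferenceBase):
    @staticmethod
    def preference_loss_fn(average_log_prob_chunk, preference_labels_chunk, full_target, beta=0.1):
        # Toy objective: push chosen (True) sequence log-probs up and
        # rejected (False) ones down, normalized by the full batch size.
        sign = torch.where(preference_labels_chunk, 1.0, -1.0)
        losses = -F.logsigmoid(beta * sign * average_log_prob_chunk)
        return losses.sum() / full_target.shape[0]

    @staticmethod
    def forward(ctx, _input, weight, target, preference_labels, bias=None, beta=0.1):
        # Extra keyword arguments (here: beta) are routed to preference_loss_fn
        # through **loss_kwargs in the base class.
        return LigerFusedLinearUnpairedPreferenceBase.forward(
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            preference_labels=preference_labels,
            bias=bias,
            loss_fn=MyUnpairedLossFunction.preference_loss_fn,
            beta=beta,
        )

    @staticmethod
    def backward(ctx, *grad_output):
        # The base backward yields grads for (_input, weight, target,
        # preference_labels, bias); the trailing None covers beta.
        grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5]
        return (*grads, None)
```

The `kto_loss.py` file added next in this diff follows exactly this shape, with the KTO objective as the loss hook and a small `torch.nn.Module` wrapper on top.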

liger_kernel/chunked_loss/kto_loss.py (new file)

```diff
@@ -0,0 +1,172 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase
+
+
+class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
+    @staticmethod
+    def preference_loss_fn(
+        average_log_prob_chunk,
+        preference_labels_chunk,
+        full_target,
+        ref_average_log_prob_chunk=None,
+        beta=0.1,
+        kl=None,
+    ):
+        """
+        Implements the Kahneman-Tversky Optimization (KTO) loss function.
+        Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization"
+        https://arxiv.org/abs/2402.01306
+
+        KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory)
+        from behavioral economics, which models how humans make decisions under uncertainty.
+        The loss function is asymmetric, treating gains and losses differently, similar to
+        human decision-making patterns.
+
+        Formula:
+        When y is chosen:
+        L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y))
+        When y is rejected:
+        L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)]))
+
+        Where:
+        - σ: Sigmoid function
+        - β: Temperature parameter controlling the strength of the preference signal
+        - π(x): Policy (current model)
+        - π₀(x): Reference policy (reference model)
+        - KL(π||π₀)_y: KL divergence estimated using the rejected response y
+
+        The loss encourages the model to:
+        1. Assign higher probability to chosen responses
+        2. Assign lower probability to rejected responses
+        3. Maintain reasonable distance from the reference model
+
+        Args:
+            chosen_logps: Log probabilities of chosen tokens (batch_size,)
+            rejected_logps: Log probabilities of rejected tokens (batch_size,)
+            full_target: Non chunked full target tensor
+            ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
+            ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
+            beta: Weight for the direct preference loss
+            kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
+        Returns:
+            Tuple of (loss, chosen_rewards, rejected_rewards):
+            - loss: The KTO loss value
+            - chosen_rewards: Reward signals for chosen responses (detached)
+            - rejected_rewards: Reward signals for rejected responses (detached)
+        """
+        logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
+        multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
+        if kl is not None:
+            losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
+        else:
+            losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
+
+        return losses.sum() / (full_target.shape[0])
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        preference_labels,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        kl=None,
+        ignore_index=-100,
+        beta=0.1,
+        compiled=True,
+        use_ref_model=True,
+    ):
+        return LigerFusedLinearUnpairedPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            preference_labels=preference_labels,
+            bias=bias,
+            loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
+            ignore_index=ignore_index,
+            beta=beta,
+            compiled=compiled,
+            use_ref_model=use_ref_model,
+            ref_input=ref_input,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            kl=kl,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5]
+        return (
+            *grads,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+class LigerFusedLinearKTOLoss(torch.nn.Module):
+    """
+    Fused linear layer with Kahneman-Tversky Optimization (KTO) loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        compiled: bool = True,
+        use_ref_model: bool = False,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss calculation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use compiled operations
+            use_ref_model (bool): Whether to use a reference model for the DPO loss.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.compiled = compiled
+        self.use_ref_model = use_ref_model
+
+    def forward(
+        self,
+        _input,
+        lin_weight,
+        target,
+        bias=None,
+        preference_labels=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        kl=None,
+    ):
+        return LigerFusedLinearKTOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            preference_labels,
+            bias,
+            ref_input,
+            ref_weight,
+            ref_bias,
+            kl,
+            self.ignore_index,
+            self.beta,
+            self.compiled,
+            self.use_ref_model,
+        )
```
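
Putting it together, the new module can be driven roughly as in the sketch below; shapes, tensor names, and hyperparameters are illustrative only. Note that `LigerFusedLinearKTOLoss` defaults to `use_ref_model=False`, while `preference_loss_fn` above subtracts `ref_average_log_prob_chunk`, so this sketch enables the reference-model path and supplies reference hidden states and weights.

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss

# Toy dimensions: batch, sequence length, hidden size, vocab size.
B, T, H, V = 4, 8, 32, 64
hidden = torch.randn(B, T, H, requires_grad=True)       # policy last hidden states
lm_head_weight = torch.randn(V, H, requires_grad=True)  # policy lm_head weight
ref_hidden = torch.randn(B, T, H)                        # reference-model hidden states
ref_lm_head_weight = torch.randn(V, H)                   # reference lm_head weight
target = torch.randint(0, V, (B, T))                     # label token ids
preference_labels = torch.tensor([True, False, True, False])  # chosen vs rejected

kto_loss = LigerFusedLinearKTOLoss(
    beta=0.1,
    use_ref_model=True,   # KTO compares the policy against a reference policy
    compiled=False,       # skip torch.compile for this toy run
)
loss = kto_loss(
    hidden,
    lm_head_weight,
    target,
    preference_labels=preference_labels,
    ref_input=ref_hidden,
    ref_weight=ref_lm_head_weight,
)
loss.backward()
print(loss.item(), hidden.grad.shape)
```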

{liger_kernel_nightly-0.5.2.dev20250122005057.dist-info → liger_kernel_nightly-0.5.2.dev20250124002122.dist-info}/RECORD

```diff
@@ -1,13 +1,15 @@
 liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
 liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
-liger_kernel/chunked_loss/README.md,sha256=
-liger_kernel/chunked_loss/__init__.py,sha256=
+liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
+liger_kernel/chunked_loss/__init__.py,sha256=CI6hBI7VldTX748c7F6F8YpHTn1q4gv5-lMXf273oXQ,431
 liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
 liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
-liger_kernel/chunked_loss/functional.py,sha256=
+liger_kernel/chunked_loss/functional.py,sha256=dO0DYMPTBxwPtEUQ1DUV2zCmZ6i-k3B7COeR3-IwA6M,683
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=idK9V9NivoVITqVpiG0fEGUHSvinYWkn9-EYXZjR-KQ,18356
+liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=ZqYlXXhIphkJPxOS7iI70avgrr6x0skEtgpckZTYau0,9819
+liger_kernel/chunked_loss/kto_loss.py,sha256=eVNW6HVCAm32shpfhbRlk92Flnjd7G32v0gK9DUUSOQ,5655
 liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
 liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -58,9 +60,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
+liger_kernel_nightly-0.5.2.dev20250124002122.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20250124002122.dist-info/METADATA,sha256=XkhmLkKGR1Tuel5f-4SxOwiE2AP0jrWAmkN8jrQcB_U,21140
+liger_kernel_nightly-0.5.2.dev20250124002122.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20250124002122.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20250124002122.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20250124002122.dist-info/RECORD,,
```