liger-kernel-nightly 0.5.2.dev20250130024630__py3-none-any.whl → 0.5.2.dev20250130213846__py3-none-any.whl
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/dpo_loss.py +5 -2
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +14 -5
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630.dist-info → liger_kernel_nightly-0.5.2.dev20250130213846.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20250130024630.dist-info → liger_kernel_nightly-0.5.2.dev20250130213846.dist-info}/RECORD +11 -10
- {liger_kernel_nightly-0.5.2.dev20250130024630.dist-info → liger_kernel_nightly-0.5.2.dev20250130213846.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630.dist-info → liger_kernel_nightly-0.5.2.dev20250130213846.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630.dist-info → liger_kernel_nightly-0.5.2.dev20250130213846.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20250130024630.dist-info → liger_kernel_nightly-0.5.2.dev20250130213846.dist-info}/top_level.txt +0 -0

liger_kernel/chunked_loss/__init__.py
@@ -1,5 +1,6 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401

liger_kernel/chunked_loss/dpo_loss.py
@@ -45,9 +45,12 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         chosen_logratios = chosen_logps - ref_chosen_logps
         rejected_logratios = rejected_logps - ref_rejected_logps
 
+        chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
+        rejected_rewards = beta * (rejected_logps - ref_rejected_logps)
+
         logits_diff = beta * (chosen_logratios - rejected_logratios)
         loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
-        return loss
+        return loss, chosen_rewards, rejected_rewards
 
     @staticmethod
     def forward(

@@ -99,7 +102,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         beta: float = 0.1,
         compute_nll_loss: bool = False,
         compiled: bool = True,
-        use_ref_model: bool =
+        use_ref_model: bool = True,
     ):
         """
         Args:
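
The dpo_loss.py hunks above make the chunked DPO loss return beta-scaled chosen/rejected rewards alongside the loss. A minimal standalone sketch of the same arithmetic, with made-up log-probabilities (all tensor values below are illustrative, not from the library):

import torch
import torch.nn.functional as F

beta = 0.1
# Made-up per-sequence summed log-probabilities for two chosen/rejected pairs.
chosen_logps = torch.tensor([-12.0, -9.5])
rejected_logps = torch.tensor([-14.0, -11.0])
ref_chosen_logps = torch.tensor([-12.5, -10.0])
ref_rejected_logps = torch.tensor([-13.5, -10.5])

chosen_logratios = chosen_logps - ref_chosen_logps
rejected_logratios = rejected_logps - ref_rejected_logps

# Rewards are the beta-scaled log-ratios against the reference model.
chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
rejected_rewards = beta * (rejected_logps - ref_rejected_logps)

logits_diff = beta * (chosen_logratios - rejected_logratios)
loss = -F.logsigmoid(logits_diff).sum() / chosen_logps.shape[0]  # the kernel divides by full_target.shape[0] // 2
print(loss, chosen_rewards, rejected_rewards)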

liger_kernel/chunked_loss/functional.py
@@ -1,11 +1,13 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 
 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
+liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply

liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -17,6 +17,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         Args:
             student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
             teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+        Returns:
+            torch.Tensor: Sum of distillation losses for the chunk. The class will handle
+                converting this to mean loss by dividing by the full batch size * sequence length in _compute_loss.
         """
         raise NotImplementedError("Distillation loss function must be implemented.")
 
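
The added Returns note documents the contract for subclasses: distillation_loss_fn should return the loss summed over the chunk, and _compute_loss later divides by the full batch_size * sequence_length. A sketch of a loss function honoring that contract (toy_distillation_loss_fn is a stand-in for illustration, not part of liger_kernel):

import torch.nn.functional as F

def toy_distillation_loss_fn(student_logits, teacher_logits):
    """Stand-in loss following the documented contract: return the loss SUMMED over
    the chunk; the base class divides by batch_size * sequence_length in _compute_loss."""
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    # Summed (not mean) soft cross-entropy over all tokens in the chunk.
    return -(teacher_probs * student_log_probs).sum()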

@@ -71,10 +74,11 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
         compute_ce_loss=True,
+        temperature=1,
         **loss_kwargs,
     ):
         """
-        Compute the total loss for a chunk of input and target, while using an
+        Compute the total loss for a chunk of input and target, while using an knowledge distillation loss function.
         Args:
             distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).

@@ -84,11 +88,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
             student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            full_target (torch.Tensor): Full target tensor. Shape: (
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard loss.
             weight_soft_loss (float): Weight for soft loss.
             compute_ce_loss (bool): Whether to compute CE loss.
+            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             loss_kwargs (dict): Additional arguments for the loss function.
         """
         (

@@ -107,6 +112,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss=compute_ce_loss,
         )
 
+        student_logits_chunk /= temperature
+        teacher_logits_chunk /= temperature
+
         hard_loss /= full_target.shape[0]
 
         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)

@@ -130,6 +138,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         ignore_index=-100,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
+        beta=0.5,
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,

@@ -152,6 +161,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard/task loss.
             weight_soft_loss (float): Weight for soft/distillation loss.
+            beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.

@@ -170,7 +180,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
+            beta=beta,
             compute_ce_loss=compute_ce_loss,
+            temperature=temperature,
             **loss_kwargs,
         )
 

@@ -225,9 +237,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         if compiled:
             accumulate_chunk = torch.compile(accumulate_chunk)
 
-        student_input /= temperature
-        teacher_input /= temperature
-
         num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
         _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
         _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
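
Taken together, the fused_linear_distillation.py hunks move the temperature division from the hidden-state inputs onto the per-chunk logits, after the per-chunk forward that also yields the hard CE loss, so the cross-entropy term operates on raw logits while only the soft distillation term sees the softened distribution. A small standalone illustration with arbitrary toy shapes and values (not library code):

import torch
import torch.nn.functional as F

temperature = 2.0
torch.manual_seed(0)
student_logits = torch.randn(4, 10)   # toy (chunk_size, vocab_size) logits
labels = torch.randint(0, 10, (4,))

# Hard CE loss is taken on the raw logits...
hard_loss = F.cross_entropy(student_logits, labels, reduction="sum")

# ...while the soft distillation term uses the temperature-softened distribution.
soft_student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
print(hard_loss, soft_student_log_probs.exp().sum(dim=-1))  # softened probabilities still sum to 1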

liger_kernel/chunked_loss/jsd_loss.py (new file)
@@ -0,0 +1,154 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+        """
+        Compute JSD loss (Jensen-Shannon Divergence Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        Returns:
+            torch.Tensor: Jensen-Shannon Divergence loss
+        """
+        student_log_probs = F.log_softmax(student_logits, dim=-1)
+        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+        # Compute probabilities (only required for mean calculation)
+        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
+        log_mean_probs = mean_probs.log()
+
+        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+
+        # JSD is the weighted average of the KL divergences
+        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+        return jsd_loss
+
+    @staticmethod
+    def forward(
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Fused linear layer with JSD distillation loss.
+        Args:
+            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
+            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
+            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
+            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
+            true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            ignore_index (int): Index to ignore in loss computation
+            temperature (float): Temperature for softening/sharpening distributions
+            compiled (bool): Whether to use torch compile
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearDistillationBase.forward(
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
+            chunk_size=1,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
+
+        return (*grads, None, None, None, None, None, None, None)
+
+
+class LigerFusedLinearJSDLoss(torch.nn.Module):
+    """
+    Fused linear layer with JSD distillation loss.
+    """
+
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Args:
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            ignore_index (int): Index to ignore in the loss
+            temperature (float): Temperature for softening distributions
+            compiled (bool): Whether to use torch compile
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        """
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+    ) -> torch.Tensor:
+        """
+        Compute the JSD distillation loss.
+
+        Args:
+            student_input (torch.Tensor): Student input tensor
+            student_weight (torch.Tensor): Student weight tensor
+            teacher_input (torch.Tensor): Teacher input tensor
+            teacher_weight (torch.Tensor): Teacher weight tensor
+            true_labels (torch.LongTensor): Target labels tensor
+
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearJSDFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+        )
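
For orientation, a sketch of how the new module can be driven end to end with random tensors. Shapes follow the docstrings above, the hyperparameters are arbitrary, the import path follows the updated __init__.py, and a recent PyTorch build is assumed (the default compiled=True relies on torch.compile):

import torch
from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss

torch.manual_seed(0)
bsz_seq, student_hidden, teacher_hidden, vocab = 8, 32, 64, 128   # arbitrary toy sizes

student_input = torch.randn(bsz_seq, student_hidden, requires_grad=True)
student_weight = torch.randn(vocab, student_hidden, requires_grad=True)
teacher_input = torch.randn(bsz_seq, teacher_hidden)
teacher_weight = torch.randn(vocab, teacher_hidden)
true_labels = torch.randint(0, vocab, (bsz_seq,))

loss_fn = LigerFusedLinearJSDLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5, temperature=2.0)
loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()  # gradients flow back to the student-side tensors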

{liger_kernel_nightly-0.5.2.dev20250130024630.dist-info → liger_kernel_nightly-0.5.2.dev20250130213846.dist-info}/RECORD
@@ -2,13 +2,14 @@ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
 liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
 liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
-liger_kernel/chunked_loss/__init__.py,sha256=
+liger_kernel/chunked_loss/__init__.py,sha256=48m-8IMOAReZbi0HK5aV-KGBp2IsZSwFvdnzTNrS4bk,516
 liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
-liger_kernel/chunked_loss/dpo_loss.py,sha256=
-liger_kernel/chunked_loss/functional.py,sha256=
-liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=
+liger_kernel/chunked_loss/dpo_loss.py,sha256=enFVgqIvWWOamOV3cl_dbq2HsjX7PF2d0kibDNyuCW4,4545
+liger_kernel/chunked_loss/functional.py,sha256=THWWpCnRVhTVfnPnyvQjdBvo1JDtxhwLmtZE_yiBBqM,817
+liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=5V8rdva89WyHVbmJ8JOmC4DYNOR6ByXfx3qlUieOZkI,11002
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=idK9V9NivoVITqVpiG0fEGUHSvinYWkn9-EYXZjR-KQ,18356
 liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=ZqYlXXhIphkJPxOS7iI70avgrr6x0skEtgpckZTYau0,9819
+liger_kernel/chunked_loss/jsd_loss.py,sha256=yRCQdvd3ruTWP4A_BfU8VcZ6LepSUfO0Ob7stGnueQY,6052
 liger_kernel/chunked_loss/kto_loss.py,sha256=eVNW6HVCAm32shpfhbRlk92Flnjd7G32v0gK9DUUSOQ,5655
 liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
 liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865

@@ -60,9 +61,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
-liger_kernel_nightly-0.5.2.
+liger_kernel_nightly-0.5.2.dev20250130213846.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20250130213846.dist-info/METADATA,sha256=piBJYHmJpYyIojWrErmhbQnpy29ILTI03ttIe1ekUZU,21205
+liger_kernel_nightly-0.5.2.dev20250130213846.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20250130213846.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20250130213846.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20250130213846.dist-info/RECORD,,