liger-kernel-nightly 0.6.3.dev20251027181634__py3-none-any.whl → 0.6.3.dev20251028065948__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/jsd_loss.py +18 -5
- {liger_kernel_nightly-0.6.3.dev20251027181634.dist-info → liger_kernel_nightly-0.6.3.dev20251028065948.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.3.dev20251027181634.dist-info → liger_kernel_nightly-0.6.3.dev20251028065948.dist-info}/RECORD +9 -9
- {liger_kernel_nightly-0.6.3.dev20251027181634.dist-info → liger_kernel_nightly-0.6.3.dev20251028065948.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.3.dev20251027181634.dist-info → liger_kernel_nightly-0.6.3.dev20251028065948.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.3.dev20251027181634.dist-info → liger_kernel_nightly-0.6.3.dev20251028065948.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.3.dev20251027181634.dist-info → liger_kernel_nightly-0.6.3.dev20251028065948.dist-info}/top_level.txt +0 -0
The chunked-loss changes add an optional return_soft_hard_loss flag to the distillation losses: when enabled, forward returns the combined loss together with its soft (distillation) and hard (label) components, and the backward methods accept the extra gradient arguments that the additional outputs introduce.

liger_kernel/chunked_loss/cosine_similarity_loss.py

@@ -1,3 +1,6 @@
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
@@ -41,7 +44,8 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
-    ):
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return super().forward(
             cls=cls,
             ctx=ctx,
@@ -59,11 +63,12 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
 
         return (
             *grads,
@@ -75,6 +80,7 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )
 
 
@@ -88,6 +94,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -98,6 +105,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
 
     def forward(
         self,
@@ -108,7 +116,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return LigerFusedLinearCosineSimilarityFunction.apply(
             student_input,
             student_weight,
@@ -124,4 +132,5 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
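For orientation, a minimal usage sketch of the new flag on the cosine-similarity module. It assumes the forward argument order used by the other chunked distillation losses (student_input, student_weight, teacher_input, teacher_weight, true_labels) and the remaining constructor arguments keeping their defaults; shapes and values are illustrative only, not taken from the diff:

import torch

from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss

# Illustrative shapes (batch*seq, hidden, vocab).
BT, H, V = 8, 64, 128
student_input = torch.randn(BT, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(BT, H)
teacher_weight = torch.randn(V, H)
true_labels = torch.randint(0, V, (BT,))

loss_fn = LigerFusedLinearCosineSimilarityLoss(return_soft_hard_loss=True)

# With the flag enabled, forward returns (combined_loss, soft_loss, hard_loss)
# instead of a single scalar; only the combined loss is meant to drive gradients.
loss, soft_loss, hard_loss = loss_fn(
    student_input, student_weight, teacher_input, teacher_weight, true_labels
)
loss.backward()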
liger_kernel/chunked_loss/fused_linear_distillation.py

@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union
 
 import torch
 
@@ -157,8 +159,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,6 +183,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
@@ -187,6 +191,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
 
         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             )
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            if return_soft_hard_loss:
+                soft_loss_acc.add_(chunk_soft_loss)
+                hard_loss_acc.add_(chunk_hard_loss)
             return chunk_grad_input
 
         if compiled:
@@ -268,10 +277,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             grad_weight,
             grad_bias,
         )
+        if return_soft_hard_loss:
+            return loss_acc, soft_loss_acc, hard_loss_acc
         return loss_acc
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             grad_input = grad_input * grad_output
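The backward signature change in the base class follows directly from the new multi-output forward: torch.autograd passes one incoming gradient per forward output, so a backward that only consumes the gradient of the combined loss must accept and ignore the extras. A self-contained sketch of that pattern (generic PyTorch, not liger code):

import torch


class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Second output is reported for logging only; gradients flow through the first.
        return x * 2.0, x * 3.0

    @staticmethod
    def backward(ctx, grad_main, *extra_grads):
        # One gradient arrives per forward output; the trailing ones are absorbed and ignored.
        return grad_main * 2.0


x = torch.ones(3, requires_grad=True)
main, aux = TwoOutputs.apply(x)
main.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])

The same reasoning explains why the cosine and JSD subclasses forward *args to the base backward, while the new return_soft_hard_loss input gets its own None entry in the returned gradient tuple (one gradient slot per forward input).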
liger_kernel/chunked_loss/jsd_loss.py

@@ -1,5 +1,8 @@
 import math
 
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
@@ -56,6 +59,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -72,8 +76,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
         """
         return super().forward(
             cls=cls,
@@ -92,11 +97,12 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )
 
     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
 
         return (
             *grads,
@@ -108,6 +114,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )
 
 
@@ -125,6 +132,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Args:
@@ -135,6 +143,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
            chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -145,6 +154,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
 
     def forward(
         self,
@@ -155,7 +165,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Compute the JSD distillation loss.
 
@@ -167,7 +177,9 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             true_labels (torch.LongTensor): Target labels tensor
 
         Returns:
-            torch.Tensor
+            torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                If return_soft_hard_loss is False: Computed combined loss
+                If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
         """
         return LigerFusedLinearJSDFunction.apply(
             student_input,
@@ -184,4 +196,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
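The JSD module follows the same pattern. Reusing the tensors from the cosine sketch above (same assumptions about the forward argument order), the flag changes only what forward returns:

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss

# Default behaviour: a single combined scalar loss.
jsd = LigerFusedLinearJSDLoss()
loss = jsd(student_input, student_weight, teacher_input, teacher_weight, true_labels)

# With the flag: the combined loss plus its soft (distillation) and hard (label) parts,
# e.g. for logging them separately during training.
jsd = LigerFusedLinearJSDLoss(return_soft_hard_loss=True)
loss, soft_loss, hard_loss = jsd(student_input, student_weight, teacher_input, teacher_weight, true_labels)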
{liger_kernel_nightly-0.6.3.dev20251027181634.dist-info → liger_kernel_nightly-0.6.3.dev20251028065948.dist-info}/RECORD

@@ -3,16 +3,16 @@ liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,17
 liger_kernel/utils.py,sha256=BQleeZWHSZPNuPcYcoZTOp1kcNEZONZilPP5-AmjgWI,2024
 liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
 liger_kernel/chunked_loss/__init__.py,sha256=J5_jNnzZ4gZmA38W5f_4oab7xMoNk1Xy-yh3X_Xlf-s,714
-liger_kernel/chunked_loss/cosine_similarity_loss.py,sha256=
+liger_kernel/chunked_loss/cosine_similarity_loss.py,sha256=x2nprTHPraU8Ya2NMZtaDk9r-s-1NKJwCTrzQIdmg-8,4680
 liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
 liger_kernel/chunked_loss/dpo_loss.py,sha256=I83khNs3QQjuhr8U3NIOAACkbse6DNiBV-TulPZ0lXw,9006
 liger_kernel/chunked_loss/functional.py,sha256=-XPDbLml9dHmvoSU2VNTUrBDFehuzvuAGPikVetBMtI,1132
-liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=
+liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=yRtolfFGfKB-SxGQQyF68GYXd11Zlvh1InLdGeWNFIE,12652
 liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=ZjpNP5VC-tXXIKb4AckkQ3iWWQeej-JoG4StJq3N0wg,13650
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=FIH85uUXAOgYx5Ax8MjFhJHVu-2pKtY7wSegd0zSyyY,18336
 liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
 liger_kernel/chunked_loss/grpo_loss.py,sha256=SkZuKoW8K94UbWR-OtfopsQkuQ8tFOr_90AGR6_Mhes,12844
-liger_kernel/chunked_loss/jsd_loss.py,sha256=
+liger_kernel/chunked_loss/jsd_loss.py,sha256=G0RghPYYelyZ6DOEiwS8we9TT5MY2iHpiFqzZ2Xy87g,8038
 liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
 liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
 liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
@@ -103,9 +103,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.6.3.
-liger_kernel_nightly-0.6.3.
-liger_kernel_nightly-0.6.3.
-liger_kernel_nightly-0.6.3.
-liger_kernel_nightly-0.6.3.
-liger_kernel_nightly-0.6.3.
+liger_kernel_nightly-0.6.3.dev20251028065948.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.6.3.dev20251028065948.dist-info/METADATA,sha256=2Y-q-3hxi7UILSX1Yn7BTGAqoAhQTpb8mUAyAxagTTQ,24777
+liger_kernel_nightly-0.6.3.dev20251028065948.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.6.3.dev20251028065948.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.6.3.dev20251028065948.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.6.3.dev20251028065948.dist-info/RECORD,,
The LICENSE, NOTICE, WHEEL, and top_level.txt files are unchanged apart from the renamed dist-info directory.