liger-kernel-nightly 0.5.10.dev20250630172023__py3-none-any.whl → 0.5.10.dev20250702150221__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- {liger_kernel_nightly-0.5.10.dev20250630172023.dist-info → liger_kernel_nightly-0.5.10.dev20250702150221.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250630172023.dist-info → liger_kernel_nightly-0.5.10.dev20250702150221.dist-info}/RECORD +9 -8
- {liger_kernel_nightly-0.5.10.dev20250630172023.dist-info → liger_kernel_nightly-0.5.10.dev20250702150221.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250630172023.dist-info → liger_kernel_nightly-0.5.10.dev20250702150221.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250630172023.dist-info → liger_kernel_nightly-0.5.10.dev20250702150221.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250630172023.dist-info → liger_kernel_nightly-0.5.10.dev20250702150221.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
|
|
1
|
+
from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss # noqa:F401
|
1
2
|
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
|
2
3
|
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
|
3
4
|
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
|
@@ -0,0 +1,127 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn.functional as F
|
3
|
+
|
4
|
+
from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
|
5
|
+
|
6
|
+
|
7
|
+
class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
    """Fused-linear distillation function whose soft loss is ``1 - cosine similarity``.

    Only the soft-loss term and the argument plumbing live here; the forward
    and backward computations themselves are delegated to
    ``LigerFusedLinearDistillationBase``.
    """

    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
        """
        Compute the cosine similarity distillation loss.

        Args:
            student_logits (torch.Tensor): Student logits. Similarity is taken along the last dimension
                (presumably shape (batch_size * seq_len, vocab_size) -- TODO confirm against the base class).
            teacher_logits (torch.Tensor): Teacher logits, laid out like ``student_logits``.
            beta (float): Multiplier applied to ``1 - cosine_similarity``. Default: ``1.0``.
        Returns:
            torch.Tensor: Scalar tensor holding ``beta * (1 - cosine_similarity)`` summed over every position.
        """
        # L2-normalise along the last dimension before measuring similarity.
        student_unit = F.normalize(student_logits, p=2, dim=-1)
        teacher_unit = F.normalize(teacher_logits, p=2, dim=-1)

        similarity = F.cosine_similarity(student_unit, teacher_unit, dim=-1)
        # A similarity of 1 contributes zero loss; lower similarity contributes more.
        return (beta * (1 - similarity)).sum()

    @classmethod
    def forward(
        cls,
        ctx,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        true_labels: torch.LongTensor,
        student_bias: torch.Tensor,
        teacher_bias: torch.Tensor,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
        chunk_size: int = 1024,
    ):
        """
        Forward every argument, by keyword, to ``LigerFusedLinearDistillationBase.forward``.

        ``true_labels`` is handed to the base class under the name ``target``; all other
        arguments keep their names. Note that ``beta`` defaults to ``0.5`` here, whereas
        ``distillation_loss_fn`` defaults it to ``1.0``.

        Returns:
            Whatever the base class ``forward`` returns.
        """
        return super().forward(
            cls=cls,
            ctx=ctx,
            # Student / teacher projections.
            student_input=student_input,
            student_weight=student_weight,
            student_bias=student_bias,
            teacher_input=teacher_input,
            teacher_weight=teacher_weight,
            teacher_bias=teacher_bias,
            # The base class calls the labels ``target``.
            target=true_labels,
            # Loss configuration.
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            beta=beta,
            ignore_index=ignore_index,
            temperature=temperature,
            # Execution configuration.
            compiled=compiled,
            chunk_size=chunk_size,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return one gradient slot per ``forward`` input (14 in total).

        The first six slots come from the base class ``backward``; the remaining
        eight inputs receive ``None``.
        """
        # Only the first six entries of the base result are kept; they fill the
        # slots of forward's first six inputs (student_input ... student_bias).
        leading_grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]

        # Remaining forward inputs, in order: teacher_bias, weight_hard_loss,
        # weight_soft_loss, beta, ignore_index, temperature, compiled, chunk_size.
        no_grad_slots = (None,) * 8

        return (*leading_grads, *no_grad_slots)
|
79
|
+
|
80
|
+
|
81
|
+
class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
    """
    ``torch.nn.Module`` wrapper around ``LigerFusedLinearCosineSimilarityFunction``.

    The constructor stores the loss configuration; ``forward`` passes it, together
    with the per-call tensors, to ``LigerFusedLinearCosineSimilarityFunction.apply``.
    """

    def __init__(
        self,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
        chunk_size: int = 1024,
    ):
        """
        Args:
            weight_hard_loss (float): Stored and forwarded unchanged
                (presumably the weight of the label-based loss -- TODO confirm against the base class).
            weight_soft_loss (float): Stored and forwarded unchanged
                (presumably the weight of the distillation loss -- TODO confirm against the base class).
            beta (float): Forwarded as ``beta``; the cosine loss multiplies ``1 - cosine_similarity`` by it.
            ignore_index (int): Stored and forwarded unchanged
                (presumably the label value to skip -- TODO confirm against the base class).
            temperature (float): Stored and forwarded unchanged. Must not be 0.
            compiled (bool): Stored and forwarded unchanged.
            chunk_size (int): Stored and forwarded unchanged.

        Raises:
            AssertionError: If ``temperature`` is 0.
        """
        super().__init__()
        # Raised explicitly instead of via an ``assert`` statement so the check is
        # not stripped under ``python -O``. AssertionError and the message are kept
        # so callers see the same exception as before.
        if temperature == 0:
            raise AssertionError("Temperature cannot be 0.")
        self.weight_hard_loss = weight_hard_loss
        self.weight_soft_loss = weight_soft_loss
        self.ignore_index = ignore_index
        self.temperature = temperature
        self.compiled = compiled
        self.beta = beta
        self.chunk_size = chunk_size

    def forward(
        self,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        true_labels: torch.LongTensor,
        student_bias: torch.Tensor = None,
        teacher_bias: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Compute the loss by calling ``LigerFusedLinearCosineSimilarityFunction.apply``.

        Args:
            student_input (torch.Tensor): Student-side input tensor.
            student_weight (torch.Tensor): Student-side weight tensor.
            teacher_input (torch.Tensor): Teacher-side input tensor.
            teacher_weight (torch.Tensor): Teacher-side weight tensor.
            true_labels (torch.LongTensor): Label tensor.
            student_bias (torch.Tensor, optional): Student-side bias; ``None`` is passed through as-is.
            teacher_bias (torch.Tensor, optional): Teacher-side bias; ``None`` is passed through as-is.

        Returns:
            torch.Tensor: The value returned by the autograd function.
        """
        # ``apply`` takes positional arguments; the order must match the
        # signature of LigerFusedLinearCosineSimilarityFunction.forward.
        return LigerFusedLinearCosineSimilarityFunction.apply(
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            true_labels,
            student_bias,
            teacher_bias,
            self.weight_hard_loss,
            self.weight_soft_loss,
            self.beta,
            self.ignore_index,
            self.temperature,
            self.compiled,
            self.chunk_size,
        )
|
@@ -1,3 +1,4 @@
|
|
1
|
+
from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
|
1
2
|
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
|
2
3
|
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
|
3
4
|
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
|
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
|
|
9
10
|
liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
|
10
11
|
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
|
11
12
|
liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
|
13
|
+
liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
|
12
14
|
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
|
13
15
|
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
|
14
16
|
liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
|
@@ -2,10 +2,11 @@ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
|
3
3
|
liger_kernel/utils.py,sha256=BQleeZWHSZPNuPcYcoZTOp1kcNEZONZilPP5-AmjgWI,2024
|
4
4
|
liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
|
5
|
-
liger_kernel/chunked_loss/__init__.py,sha256=
|
5
|
+
liger_kernel/chunked_loss/__init__.py,sha256=J5_jNnzZ4gZmA38W5f_4oab7xMoNk1Xy-yh3X_Xlf-s,714
|
6
|
+
liger_kernel/chunked_loss/cosine_similarity_loss.py,sha256=pZ07OQ6RI-c8uk96tDRlUXdt31-da7yWhfwircZlKRw,4198
|
6
7
|
liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
|
7
8
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=tapMiNdI8_ufW55iG0Ud4dmiW39gu1DzlvtoOCHrdGg,6259
|
8
|
-
liger_kernel/chunked_loss/functional.py,sha256
|
9
|
+
liger_kernel/chunked_loss/functional.py,sha256=-XPDbLml9dHmvoSU2VNTUrBDFehuzvuAGPikVetBMtI,1132
|
9
10
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
|
10
11
|
liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
|
11
12
|
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=FIH85uUXAOgYx5Ax8MjFhJHVu-2pKtY7wSegd0zSyyY,18336
|
@@ -88,9 +89,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
88
89
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
89
90
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
90
91
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
91
|
-
liger_kernel_nightly-0.5.10.
|
92
|
-
liger_kernel_nightly-0.5.10.
|
93
|
-
liger_kernel_nightly-0.5.10.
|
94
|
-
liger_kernel_nightly-0.5.10.
|
95
|
-
liger_kernel_nightly-0.5.10.
|
96
|
-
liger_kernel_nightly-0.5.10.
|
92
|
+
liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
93
|
+
liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/METADATA,sha256=CoPcolC_DjZu7v28Cqy2kQoE65U6f5Rx1EKf55y9NxU,24536
|
94
|
+
liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
95
|
+
liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
96
|
+
liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
97
|
+
liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|