liger-kernel-nightly 0.5.10.dev20250630172023__py3-none-any.whl → 0.5.10.dev20250702150221__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,4 @@
1
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss # noqa:F401
1
2
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
2
3
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
3
4
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
@@ -0,0 +1,127 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
5
+
6
+
7
class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
    """Distillation function whose soft loss is a cosine distance on logits.

    The chunking, linear projection and hard-loss handling are inherited from
    ``LigerFusedLinearDistillationBase``; this subclass supplies
    ``distillation_loss_fn`` and adapts the ``forward`` / ``backward``
    signatures to its own argument list.
    """

    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
        """
        Compute the summed cosine-distance loss between student and teacher logits.

        Both inputs are L2-normalised along their last dimension, the cosine
        similarity is taken along that same dimension, and
        ``beta * (1 - similarity)`` is summed over every remaining position.

        Args:
            student_logits (torch.Tensor): Student logits. The similarity is
                reduced over the last dimension.
                NOTE(review): presumably (batch_size * seq_len, vocab_size) -
                confirm against the base class caller.
            teacher_logits (torch.Tensor): Teacher logits, compared with
                ``student_logits`` along the last dimension.
            beta (float): Multiplier applied to each ``1 - similarity`` term.
                Default: ``1.0``.
        Returns:
            torch.Tensor: Sum of the scaled cosine distances.
        """
        unit_student = F.normalize(student_logits, p=2, dim=-1)
        unit_teacher = F.normalize(teacher_logits, p=2, dim=-1)

        similarity = F.cosine_similarity(unit_student, unit_teacher, dim=-1)
        return (beta * (1 - similarity)).sum()

    @classmethod
    def forward(
        cls,
        ctx,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        true_labels: torch.LongTensor,
        student_bias: torch.Tensor,
        teacher_bias: torch.Tensor,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
        chunk_size: int = 1024,
    ):
        """
        Delegate to the base-class ``forward`` with this class as the loss provider.

        Every argument is passed through by keyword; ``true_labels`` is handed
        to the base class under the name ``target``. The meaning of the loss
        weights, ``temperature``, ``compiled`` and ``chunk_size`` is defined by
        ``LigerFusedLinearDistillationBase.forward``, not here.

        Returns:
            Whatever ``LigerFusedLinearDistillationBase.forward`` returns.
        """
        return super().forward(
            cls=cls,
            ctx=ctx,
            # Tensors for the student and teacher projections.
            student_input=student_input,
            student_weight=student_weight,
            student_bias=student_bias,
            teacher_input=teacher_input,
            teacher_weight=teacher_weight,
            teacher_bias=teacher_bias,
            # The base class calls the label tensor ``target``.
            target=true_labels,
            # Scalar options, forwarded unchanged.
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            beta=beta,
            ignore_index=ignore_index,
            temperature=temperature,
            compiled=compiled,
            chunk_size=chunk_size,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return one gradient slot per ``forward`` argument after ``ctx``.

        The first six slots are taken from the base-class ``backward``; the
        remaining eight arguments (``teacher_bias``, ``weight_hard_loss``,
        ``weight_soft_loss``, ``beta``, ``ignore_index``, ``temperature``,
        ``compiled``, ``chunk_size``) receive ``None``.
        """
        leading_grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]

        # teacher_bias, weight_hard_loss, weight_soft_loss, beta,
        # ignore_index, temperature, compiled, chunk_size
        no_grad_slots = (None,) * 8
        return (*leading_grads, *no_grad_slots)
79
+
80
+
81
class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
    """``torch.nn.Module`` front end for ``LigerFusedLinearCosineSimilarityFunction``.

    The constructor stores the scalar options; ``forward`` passes them,
    together with the tensors it receives, positionally to
    ``LigerFusedLinearCosineSimilarityFunction.apply``.
    """

    def __init__(
        self,
        weight_hard_loss: float = 0.5,
        weight_soft_loss: float = 0.5,
        beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
        compiled: bool = True,
        chunk_size: int = 1024,
    ):
        """
        Store the loss configuration.

        Args:
            weight_hard_loss (float): Forwarded as ``weight_hard_loss``. Default: ``0.5``.
            weight_soft_loss (float): Forwarded as ``weight_soft_loss``. Default: ``0.5``.
            beta (float): Forwarded as ``beta``. Default: ``0.5``.
            ignore_index (int): Forwarded as ``ignore_index``. Default: ``-100``.
            temperature (float): Forwarded as ``temperature``; must not be 0. Default: ``1.0``.
            compiled (bool): Forwarded as ``compiled``. Default: ``True``.
            chunk_size (int): Forwarded as ``chunk_size``. Default: ``1024``.

        Raises:
            AssertionError: If ``temperature`` equals 0.
        """
        super().__init__()
        assert temperature != 0, "Temperature cannot be 0."
        self.weight_hard_loss = weight_hard_loss
        self.weight_soft_loss = weight_soft_loss
        self.ignore_index = ignore_index
        self.temperature = temperature
        self.compiled = compiled
        self.beta = beta
        self.chunk_size = chunk_size

    def forward(
        self,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        true_labels: torch.LongTensor,
        student_bias: torch.Tensor = None,
        teacher_bias: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Apply the fused cosine-similarity distillation function.

        Args:
            student_input (torch.Tensor): Student-side input tensor.
            student_weight (torch.Tensor): Student-side weight tensor.
            teacher_input (torch.Tensor): Teacher-side input tensor.
            teacher_weight (torch.Tensor): Teacher-side weight tensor.
            true_labels (torch.LongTensor): Label tensor.
            student_bias (torch.Tensor): Optional student bias. Default: ``None``.
            teacher_bias (torch.Tensor): Optional teacher bias. Default: ``None``.

        Returns:
            torch.Tensor: Result of ``LigerFusedLinearCosineSimilarityFunction.apply``.
        """
        # Positional order must match LigerFusedLinearCosineSimilarityFunction.forward.
        tensor_args = (
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            true_labels,
            student_bias,
            teacher_bias,
        )
        option_args = (
            self.weight_hard_loss,
            self.weight_soft_loss,
            self.beta,
            self.ignore_index,
            self.temperature,
            self.compiled,
            self.chunk_size,
        )
        return LigerFusedLinearCosineSimilarityFunction.apply(*tensor_args, *option_args)
@@ -1,3 +1,4 @@
1
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
1
2
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2
3
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3
4
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
9
10
  liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
10
11
  liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
11
12
  liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
13
+ liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
12
14
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
13
15
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
14
16
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.10.dev20250630172023
3
+ Version: 0.5.10.dev20250702150221
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -2,10 +2,11 @@ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
3
3
  liger_kernel/utils.py,sha256=BQleeZWHSZPNuPcYcoZTOp1kcNEZONZilPP5-AmjgWI,2024
4
4
  liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
5
- liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
5
+ liger_kernel/chunked_loss/__init__.py,sha256=J5_jNnzZ4gZmA38W5f_4oab7xMoNk1Xy-yh3X_Xlf-s,714
6
+ liger_kernel/chunked_loss/cosine_similarity_loss.py,sha256=pZ07OQ6RI-c8uk96tDRlUXdt31-da7yWhfwircZlKRw,4198
6
7
  liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
7
8
  liger_kernel/chunked_loss/dpo_loss.py,sha256=tapMiNdI8_ufW55iG0Ud4dmiW39gu1DzlvtoOCHrdGg,6259
8
- liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
9
+ liger_kernel/chunked_loss/functional.py,sha256=-XPDbLml9dHmvoSU2VNTUrBDFehuzvuAGPikVetBMtI,1132
9
10
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
10
11
  liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
11
12
  liger_kernel/chunked_loss/fused_linear_preference.py,sha256=FIH85uUXAOgYx5Ax8MjFhJHVu-2pKtY7wSegd0zSyyY,18336
@@ -88,9 +89,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
88
89
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
89
90
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
90
91
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
91
- liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
92
- liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/METADATA,sha256=R9S054XUfsyrq9HECn8SHjNLRdXF6KxS6vP1w_fuqjI,24536
93
- liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
94
- liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
95
- liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
96
- liger_kernel_nightly-0.5.10.dev20250630172023.dist-info/RECORD,,
92
+ liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
93
+ liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/METADATA,sha256=CoPcolC_DjZu7v28Cqy2kQoE65U6f5Rx1EKf55y9NxU,24536
94
+ liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
95
+ liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
96
+ liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
97
+ liger_kernel_nightly-0.5.10.dev20250702150221.dist-info/RECORD,,