liger-kernel-nightly 0.5.2.dev20250130024630-py3-none-any.whl → 0.5.2.dev20250130172806-py3-none-any.whl

liger_kernel/chunked_loss/__init__.py
@@ -1,5 +1,6 @@
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
  from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
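
The new export makes the JSD distillation loss importable from the package root of chunked_loss. A minimal usage sketch (untested; the shapes and random weights below are made up for illustration):

import torch
from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss

# Hypothetical sizes: 8 tokens, student hidden 16, teacher hidden 32, vocab 100.
jsd_loss = LigerFusedLinearJSDLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)

student_hidden = torch.randn(8, 16, requires_grad=True)     # pre-lm_head activations
student_lm_head = torch.randn(100, 16, requires_grad=True)  # (vocab, hidden_student)
teacher_hidden = torch.randn(8, 32)
teacher_lm_head = torch.randn(100, 32)                      # (vocab, hidden_teacher)
labels = torch.randint(0, 100, (8,))

loss = jsd_loss(student_hidden, student_lm_head, teacher_hidden, teacher_lm_head, labels)
loss.backward()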

liger_kernel/chunked_loss/functional.py
@@ -1,11 +1,13 @@
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+ from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
  from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

  liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
  liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
+ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply

liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -17,6 +17,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
          Args:
              student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
              teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+         Returns:
+             torch.Tensor: Sum of distillation losses for the chunk. The class will handle
+             converting this to mean loss by dividing by the full batch size * sequence length in _compute_loss.
          """
          raise NotImplementedError("Distillation loss function must be implemented.")

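The added Returns note pins down the chunking contract: each distillation_loss_fn call returns an unnormalized sum over its chunk, and _compute_loss divides once by the full token count. A standalone toy sketch of why summing per chunk and normalizing once reproduces the global mean (not Liger code):

import torch

per_token_loss = torch.randn(12).abs()  # stand-in for any per-token loss

# Sum within each chunk, accumulate, then normalize once by the full length.
total = sum(chunk.sum() for chunk in torch.chunk(per_token_loss, chunks=4, dim=0))
assert torch.allclose(total / per_token_loss.shape[0], per_token_loss.mean())
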
@@ -71,10 +74,11 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
          weight_hard_loss=0.5,
          weight_soft_loss=0.5,
          compute_ce_loss=True,
+         temperature=1,
          **loss_kwargs,
      ):
          """
-         Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function.
+         Compute the total loss for a chunk of input and target, while using a knowledge distillation loss function.
          Args:
              distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
              student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
@@ -84,11 +88,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
              student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
              teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-             full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
+             full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
              ignore_index (int): Index to ignore for loss computation.
              weight_hard_loss (float): Weight for hard loss.
              weight_soft_loss (float): Weight for soft loss.
              compute_ce_loss (bool): Whether to compute CE loss.
+             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
              loss_kwargs (dict): Additional arguments for the loss function.
          """
          (
@@ -107,6 +112,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              compute_ce_loss=compute_ce_loss,
          )

+         student_logits_chunk /= temperature
+         teacher_logits_chunk /= temperature
+
          hard_loss /= full_target.shape[0]

          soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
@@ -130,6 +138,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
          ignore_index=-100,
          weight_hard_loss=0.5,
          weight_soft_loss=0.5,
+         beta=0.5,
          compute_ce_loss=True,
          temperature=1.0,
          compiled=True,
@@ -152,6 +161,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              ignore_index (int): Index to ignore for loss computation.
              weight_hard_loss (float): Weight for hard/task loss.
              weight_soft_loss (float): Weight for soft/distillation loss.
+             beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
              compute_ce_loss (bool): Whether to compute CE loss.
              temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
              compiled (bool): Whether to use torch compile for chunk accumulation.
@@ -170,7 +180,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
              ignore_index=ignore_index,
              weight_hard_loss=weight_hard_loss,
              weight_soft_loss=weight_soft_loss,
+             beta=beta,
              compute_ce_loss=compute_ce_loss,
+             temperature=temperature,
              **loss_kwargs,
          )

@@ -225,9 +237,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
          if compiled:
              accumulate_chunk = torch.compile(accumulate_chunk)

-         student_input /= temperature
-         teacher_input /= temperature
-
          num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
          _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
          _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
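
Together with the earlier hunk that divides student_logits_chunk and teacher_logits_chunk by temperature, this removal moves the scaling from the hidden states (before the lm_head projection) to the logits (after it). The two placements only agree when the projection has no bias; softmax temperature is conventionally defined on logits. A standalone sketch of the difference (hypothetical shapes, bias included to show the discrepancy):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
temperature = 2.0
x = torch.randn(4, 16)   # hidden states: (tokens, hidden)
w = torch.randn(32, 16)  # lm_head weight: (vocab, hidden)
b = torch.randn(32)      # optional lm_head bias: (vocab,)

old = F.linear(x / temperature, w, b)  # previous behavior: scale hidden states
new = F.linear(x, w, b) / temperature  # new behavior: scale the logits

assert not torch.allclose(old, new)  # the bias term breaks the equivalence
assert torch.allclose(F.linear(x / temperature, w), F.linear(x, w) / temperature)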

liger_kernel/chunked_loss/jsd_loss.py (new file)
@@ -0,0 +1,154 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
+     @staticmethod
+     def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+         """
+         Compute JSD loss (Jensen-Shannon Divergence Loss).
+         Args:
+             student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+             teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+         Returns:
+             torch.Tensor: Jensen-Shannon Divergence loss
+         """
+         student_log_probs = F.log_softmax(student_logits, dim=-1)
+         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+         # Compute probabilities (only required for mean calculation)
+         mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
+         log_mean_probs = mean_probs.log()
+
+         student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+         teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+
+         # JSD is the weighted average of the KL divergences
+         jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+         return jsd_loss
+
+     @staticmethod
+     def forward(
+         ctx,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+     ):
+         """
+         Fused linear layer with JSD distillation loss.
+         Args:
+             student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
+             student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
+             teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
+             teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
+             true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+             weight_hard_loss (float): Weight for hard loss.
+             weight_soft_loss (float): Weight for soft loss.
+             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+             ignore_index (int): Index to ignore in loss computation
+             temperature (float): Temperature for softening/sharpening distributions
+             compiled (bool): Whether to use torch compile
+         Returns:
+             torch.Tensor: Computed loss
+         """
+         return LigerFusedLinearDistillationBase.forward(
+             ctx=ctx,
+             student_input=student_input,
+             student_weight=student_weight,
+             teacher_input=teacher_input,
+             teacher_weight=teacher_weight,
+             target=true_labels,
+             loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
+             chunk_size=1,
+             weight_hard_loss=weight_hard_loss,
+             weight_soft_loss=weight_soft_loss,
+             beta=beta,
+             ignore_index=ignore_index,
+             temperature=temperature,
+             compiled=compiled,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
+
+         return (*grads, None, None, None, None, None, None, None)
+
+
+ class LigerFusedLinearJSDLoss(torch.nn.Module):
+     """
+     Fused linear layer with JSD distillation loss.
+     """
+
+     def __init__(
+         self,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+     ):
+         """
+         Args:
+             weight_hard_loss (float): Weight for hard loss.
+             weight_soft_loss (float): Weight for soft loss.
+             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+             ignore_index (int): Index to ignore in the loss
+             temperature (float): Temperature for softening distributions
+             compiled (bool): Whether to use torch compile
+         """
+         super().__init__()
+         assert temperature != 0, "Temperature cannot be 0."
+         self.weight_hard_loss = weight_hard_loss
+         self.weight_soft_loss = weight_soft_loss
+         self.ignore_index = ignore_index
+         self.temperature = temperature
+         self.compiled = compiled
+         self.beta = beta
+
+     def forward(
+         self,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+     ) -> torch.Tensor:
+         """
+         Compute the JSD distillation loss.
+
+         Args:
+             student_input (torch.Tensor): Student input tensor
+             student_weight (torch.Tensor): Student weight tensor
+             teacher_input (torch.Tensor): Teacher input tensor
+             teacher_weight (torch.Tensor): Teacher weight tensor
+             true_labels (torch.LongTensor): Target labels tensor
+
+         Returns:
+             torch.Tensor: Computed loss
+         """
+         return LigerFusedLinearJSDFunction.apply(
+             student_input,
+             student_weight,
+             teacher_input,
+             teacher_weight,
+             true_labels,
+             self.weight_hard_loss,
+             self.weight_soft_loss,
+             self.beta,
+             self.ignore_index,
+             self.temperature,
+             self.compiled,
+         )
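
As a sanity check on the loss above: distillation_loss_fn computes the generalized JSD against the mixture m = beta * p_student + (1 - beta) * p_teacher, since F.kl_div(log_m, log_p, log_target=True) with reduction="sum" evaluates KL(p || m). A standalone verification sketch against a direct computation (not part of the package):

import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction

def jsd_reference(student_logits, teacher_logits, beta=0.5):
    # Direct form: m = beta * p_s + (1 - beta) * p_t, then
    # beta * KL(p_t || m) + (1 - beta) * KL(p_s || m), summed over all entries.
    p_s = F.softmax(student_logits, dim=-1)
    p_t = F.softmax(teacher_logits, dim=-1)
    m = beta * p_s + (1 - beta) * p_t
    kl_s = (p_s * (p_s / m).log()).sum()
    kl_t = (p_t * (p_t / m).log()).sum()
    return beta * kl_t + (1 - beta) * kl_s

torch.manual_seed(0)
s, t = torch.randn(6, 10), torch.randn(6, 10)
liger = LigerFusedLinearJSDFunction.distillation_loss_fn(s, t, beta=0.5)
assert torch.allclose(liger, jsd_reference(s, t, beta=0.5), atol=1e-6)

At beta = 0.5 this is the symmetric Jensen-Shannon divergence; at beta = 0 it collapses to KL(student || teacher), and at beta = 1 to KL(teacher || student).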

liger_kernel_nightly-0.5.2.dev20250130172806.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.5.2.dev20250130024630
+ Version: 0.5.2.dev20250130172806
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation

liger_kernel_nightly-0.5.2.dev20250130172806.dist-info/RECORD
@@ -2,13 +2,14 @@ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
  liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
  liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
- liger_kernel/chunked_loss/__init__.py,sha256=CI6hBI7VldTX748c7F6F8YpHTn1q4gv5-lMXf273oXQ,431
+ liger_kernel/chunked_loss/__init__.py,sha256=48m-8IMOAReZbi0HK5aV-KGBp2IsZSwFvdnzTNrS4bk,516
  liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
  liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
- liger_kernel/chunked_loss/functional.py,sha256=dO0DYMPTBxwPtEUQ1DUV2zCmZ6i-k3B7COeR3-IwA6M,683
- liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
+ liger_kernel/chunked_loss/functional.py,sha256=THWWpCnRVhTVfnPnyvQjdBvo1JDtxhwLmtZE_yiBBqM,817
+ liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=5V8rdva89WyHVbmJ8JOmC4DYNOR6ByXfx3qlUieOZkI,11002
  liger_kernel/chunked_loss/fused_linear_preference.py,sha256=idK9V9NivoVITqVpiG0fEGUHSvinYWkn9-EYXZjR-KQ,18356
  liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=ZqYlXXhIphkJPxOS7iI70avgrr6x0skEtgpckZTYau0,9819
+ liger_kernel/chunked_loss/jsd_loss.py,sha256=yRCQdvd3ruTWP4A_BfU8VcZ6LepSUfO0Ob7stGnueQY,6052
  liger_kernel/chunked_loss/kto_loss.py,sha256=eVNW6HVCAm32shpfhbRlk92Flnjd7G32v0gK9DUUSOQ,5655
  liger_kernel/chunked_loss/orpo_loss.py,sha256=yjcrrbVeemLYodoSKT-FMSnaPtyKAZ3aOrvPD6tTY6Y,3617
  liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
@@ -60,9 +61,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.5.2.dev20250130024630.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.5.2.dev20250130024630.dist-info/METADATA,sha256=1iagCXBr_TV35IIu5T8wUpFhYmmrxMbbKkOwPZF7pgc,21205
- liger_kernel_nightly-0.5.2.dev20250130024630.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.5.2.dev20250130024630.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
- liger_kernel_nightly-0.5.2.dev20250130024630.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.5.2.dev20250130024630.dist-info/RECORD,,
+ liger_kernel_nightly-0.5.2.dev20250130172806.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.5.2.dev20250130172806.dist-info/METADATA,sha256=LCLD7LcN4x7h1_LMuYPAtIUrhHGi2eoON-NOBtofCN0,21205
+ liger_kernel_nightly-0.5.2.dev20250130172806.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.5.2.dev20250130172806.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ liger_kernel_nightly-0.5.2.dev20250130172806.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.5.2.dev20250130172806.dist-info/RECORD,,