liger-kernel-nightly 0.5.2.dev20241229131950__py3-none-any.whl → 0.5.2.dev20250101082227__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- liger_kernel/chunked_loss/fused_linear_distillation.py +12 -7
- {liger_kernel_nightly-0.5.2.dev20241229131950.dist-info → liger_kernel_nightly-0.5.2.dev20250101082227.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20241229131950.dist-info → liger_kernel_nightly-0.5.2.dev20250101082227.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.2.dev20241229131950.dist-info → liger_kernel_nightly-0.5.2.dev20250101082227.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229131950.dist-info → liger_kernel_nightly-0.5.2.dev20250101082227.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229131950.dist-info → liger_kernel_nightly-0.5.2.dev20250101082227.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241229131950.dist-info → liger_kernel_nightly-0.5.2.dev20250101082227.dist-info}/top_level.txt +0 -0
@@ -8,12 +8,15 @@ from torch.nn import functional as F
|
|
8
8
|
|
9
9
|
class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
10
10
|
@abstractmethod
|
11
|
-
def distillation_loss_fn(
|
11
|
+
def distillation_loss_fn(
|
12
|
+
student_logits,
|
13
|
+
teacher_logits,
|
14
|
+
):
|
12
15
|
"""
|
13
16
|
Compute distillation loss.
|
14
17
|
Args:
|
15
|
-
student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
|
16
|
-
teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
|
18
|
+
student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
|
19
|
+
teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
|
17
20
|
"""
|
18
21
|
raise NotImplementedError("Distillation loss function must be implemented.")
|
19
22
|
|
@@ -65,7 +68,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
65
68
|
distillation_loss_fn=None,
|
66
69
|
full_target=None,
|
67
70
|
ignore_index=-100,
|
68
|
-
temperature=1.0,
|
69
71
|
weight_hard_loss=0.5,
|
70
72
|
weight_soft_loss=0.5,
|
71
73
|
compute_ce_loss=True,
|
@@ -107,7 +109,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
107
109
|
|
108
110
|
hard_loss /= full_target.shape[0]
|
109
111
|
|
110
|
-
soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
|
112
|
+
soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
|
111
113
|
soft_loss /= full_target.shape[0]
|
112
114
|
|
113
115
|
loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
|
@@ -147,10 +149,11 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
147
149
|
teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
|
148
150
|
loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
|
149
151
|
chunk_size (int): Size of a chunk.
|
150
|
-
compute_ce_loss (bool): Whether to compute CE loss.
|
151
152
|
ignore_index (int): Index to ignore for loss computation.
|
152
153
|
weight_hard_loss (float): Weight for hard/task loss.
|
153
154
|
weight_soft_loss (float): Weight for soft/distillation loss.
|
155
|
+
compute_ce_loss (bool): Whether to compute CE loss.
|
156
|
+
temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
|
154
157
|
compiled (bool): Whether to use torch compile for chunk accumulation.
|
155
158
|
loss_kwargs (dict): Other possible arguments that a loss function might need
|
156
159
|
"""
|
@@ -168,7 +171,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
168
171
|
weight_hard_loss=weight_hard_loss,
|
169
172
|
weight_soft_loss=weight_soft_loss,
|
170
173
|
compute_ce_loss=compute_ce_loss,
|
171
|
-
temperature=temperature,
|
172
174
|
**loss_kwargs,
|
173
175
|
)
|
174
176
|
|
@@ -223,6 +225,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
|
|
223
225
|
if compiled:
|
224
226
|
accumulate_chunk = torch.compile(accumulate_chunk)
|
225
227
|
|
228
|
+
student_input /= temperature
|
229
|
+
teacher_input /= temperature
|
230
|
+
|
226
231
|
num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
|
227
232
|
_student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
|
228
233
|
_teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
|
@@ -6,7 +6,7 @@ liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMH
|
|
6
6
|
liger_kernel/chunked_loss/cpo_loss.py,sha256=L4Nk38Xh5Yfhah3Vsc_sN_Q75FWt1LA-xNNXzsK8iPM,3516
|
7
7
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
|
8
8
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
9
|
-
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=
|
9
|
+
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
|
10
10
|
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJPKK1abFpTKrajSyAx8nJ3bc,16134
|
11
11
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
|
12
12
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
|
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
58
58
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
|
59
59
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
60
60
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
61
|
-
liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
-
liger_kernel_nightly-0.5.2.
|
63
|
-
liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
-
liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
-
liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
-
liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/RECORD,,
|
61
|
+
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
+
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/METADATA,sha256=gNuR5mtVV7fQsT0qPLr3_Ok2WLKHgbC2FidkcY1q6OA,21055
|
63
|
+
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
+
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
+
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
+
liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|