liger-kernel-nightly 0.5.2.dev20241229131950__py3-none-any.whl → 0.5.2.dev20250101082227__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -8,12 +8,15 @@ from torch.nn import functional as F
8
8
 
9
9
  class LigerFusedLinearDistillationBase(torch.autograd.Function):
10
10
  @abstractmethod
11
- def distillation_loss_fn(student_logits, teacher_logits, temperature):
11
+ def distillation_loss_fn(
12
+ student_logits,
13
+ teacher_logits,
14
+ ):
12
15
  """
13
16
  Compute distillation loss.
14
17
  Args:
15
- student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
16
- teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
18
+ student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
19
+ teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
17
20
  """
18
21
  raise NotImplementedError("Distillation loss function must be implemented.")
19
22
 
@@ -65,7 +68,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
65
68
  distillation_loss_fn=None,
66
69
  full_target=None,
67
70
  ignore_index=-100,
68
- temperature=1.0,
69
71
  weight_hard_loss=0.5,
70
72
  weight_soft_loss=0.5,
71
73
  compute_ce_loss=True,
@@ -107,7 +109,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
107
109
 
108
110
  hard_loss /= full_target.shape[0]
109
111
 
110
- soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
112
+ soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
111
113
  soft_loss /= full_target.shape[0]
112
114
 
113
115
  loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -147,10 +149,11 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
147
149
  teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
148
150
  loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
149
151
  chunk_size (int): Size of a chunk.
150
- compute_ce_loss (bool): Whether to compute CE loss.
151
152
  ignore_index (int): Index to ignore for loss computation.
152
153
  weight_hard_loss (float): Weight for hard/task loss.
153
154
  weight_soft_loss (float): Weight for soft/distillation loss.
155
+ compute_ce_loss (bool): Whether to compute CE loss.
156
+ temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
154
157
  compiled (bool): Whether to use torch compile for chunk accumulation.
155
158
  loss_kwargs (dict): Other possible arguments that a loss function might need
156
159
  """
@@ -168,7 +171,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
168
171
  weight_hard_loss=weight_hard_loss,
169
172
  weight_soft_loss=weight_soft_loss,
170
173
  compute_ce_loss=compute_ce_loss,
171
- temperature=temperature,
172
174
  **loss_kwargs,
173
175
  )
174
176
 
@@ -223,6 +225,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
223
225
  if compiled:
224
226
  accumulate_chunk = torch.compile(accumulate_chunk)
225
227
 
228
+ student_input /= temperature
229
+ teacher_input /= temperature
230
+
226
231
  num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
227
232
  _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
228
233
  _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241229131950
3
+ Version: 0.5.2.dev20250101082227
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -6,7 +6,7 @@ liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMH
6
6
  liger_kernel/chunked_loss/cpo_loss.py,sha256=L4Nk38Xh5Yfhah3Vsc_sN_Q75FWt1LA-xNNXzsK8iPM,3516
7
7
  liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
8
8
  liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
9
- liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=M-QWvGPnWefYDn6Hr9bPn7diMNP5qrUaeWTb_zdMO4E,10265
9
+ liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
10
10
  liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJPKK1abFpTKrajSyAx8nJ3bc,16134
11
11
  liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
12
12
  liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
58
58
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
59
59
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
60
60
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
61
- liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
- liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/METADATA,sha256=iOyPsdNf1GL3Z3Ng0CS3xoOq6iiTb8eFXAMwqDT1UZM,21055
63
- liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
- liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
- liger_kernel_nightly-0.5.2.dev20241229131950.dist-info/RECORD,,
61
+ liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
+ liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/METADATA,sha256=gNuR5mtVV7fQsT0qPLr3_Ok2WLKHgbC2FidkcY1q6OA,21055
63
+ liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
+ liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
+ liger_kernel_nightly-0.5.2.dev20250101082227.dist-info/RECORD,,