liger-kernel-nightly 0.5.2.dev20250108072837__py3-none-any.whl → 0.5.2.dev20250108073340__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -65,6 +65,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
65
65
  beta=beta,
66
66
  label_smoothing=label_smoothing,
67
67
  compute_nll_loss=compute_nll_loss,
68
+ average_log_prob=False,
68
69
  compiled=compiled,
69
70
  )
70
71
 
@@ -32,6 +32,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
32
32
  ref_input=None,
33
33
  ref_weight=None,
34
34
  ref_bias=None,
35
+ average_log_prob=True,
35
36
  **loss_kwargs,
36
37
  ):
37
38
  """
@@ -61,6 +62,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
61
62
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
62
63
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
63
64
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
65
+ average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
64
66
  loss_kwargs (dict): Other possible arguments that a loss function might need
65
67
  """
66
68
  # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -94,6 +96,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
94
96
  use_ref_model=use_ref_model,
95
97
  ref_weight=ref_weight,
96
98
  ref_bias=ref_bias,
99
+ average_log_prob=average_log_prob,
97
100
  **loss_kwargs,
98
101
  )
99
102
 
@@ -265,6 +268,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
265
268
  bias=None,
266
269
  ignore_index=-100,
267
270
  compute_nll_loss=True,
271
+ average_log_prob=True,
268
272
  ):
269
273
  len_chosen_chunk = target_chunk.shape[0] // 2
270
274
  logits_chunk = input_chunk @ weight.t()
@@ -285,10 +289,13 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
285
289
  label_chunk = torch.where(loss_mask, target_chunk, 0)
286
290
 
287
291
  per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
288
- average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
292
+ if average_log_prob:
293
+ log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
294
+ else:
295
+ log_prob = (per_token_logps * loss_mask).sum(-1)
289
296
 
290
- chosen_logps = average_log_prob[:len_chosen_chunk]
291
- rejected_logps = average_log_prob[len_chosen_chunk:]
297
+ chosen_logps = log_prob[:len_chosen_chunk]
298
+ rejected_logps = log_prob[len_chosen_chunk:]
292
299
 
293
300
  chosen_logits = logits_chunk[:len_chosen_chunk]
294
301
  rejected_logits = logits_chunk[len_chosen_chunk:]
@@ -317,6 +324,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
317
324
  ref_input_chunk=None,
318
325
  ref_weight=None,
319
326
  ref_bias=None,
327
+ average_log_prob=True,
320
328
  **loss_kwargs,
321
329
  ):
322
330
  """
@@ -335,6 +343,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
335
343
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
336
344
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
337
345
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
346
+ average_log_prob (bool): Whether to average the per-token log probabilities (True) or sum them (False).
338
347
  loss_kwargs (dict): Additional arguments for the loss function.
339
348
  """
340
349
  (
@@ -350,6 +359,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
350
359
  bias=bias,
351
360
  ignore_index=ignore_index,
352
361
  compute_nll_loss=compute_nll_loss,
362
+ average_log_prob=average_log_prob,
353
363
  )
354
364
  chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
355
365
  chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
@@ -372,6 +382,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
372
382
  ref_bias,
373
383
  ignore_index=ignore_index,
374
384
  compute_nll_loss=False, # We don't need NLL loss for the reference model
385
+ average_log_prob=average_log_prob,
375
386
  )
376
387
  loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
377
388
  loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20250108072837
3
+ Version: 0.5.2.dev20250108073340
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -3,11 +3,11 @@ liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,17
3
3
  liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
4
4
  liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
5
5
  liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
6
- liger_kernel/chunked_loss/cpo_loss.py,sha256=MCR4TzuBoJEaU0IJ7dIreLacQeXLKETV5CegNjhCD9M,3646
6
+ liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
7
7
  liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
8
8
  liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
9
9
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
10
- liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJPKK1abFpTKrajSyAx8nJ3bc,16134
10
+ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=eQCZmQ3xOL3jpZ7RhOfx_pqR9sNEX6RHx8DtIgyXEHc,16656
11
11
  liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
12
12
  liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
13
13
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
58
58
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
59
59
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
60
60
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
61
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/METADATA,sha256=HwmQEBRYnwwbdkzuW53_qsmTSSbi8qu20cVOHsq6B_s,21055
63
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/RECORD,,
61
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/METADATA,sha256=m2Zrd4xffCEa6qCxyFCCH6l1WJuk7V6eZ28Pt2_dtHc,21055
63
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/RECORD,,