liger-kernel-nightly 0.5.2.dev20250108072837__py3-none-any.whl → 0.5.2.dev20250108073340__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -65,6 +65,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
65
65
  beta=beta,
66
66
  label_smoothing=label_smoothing,
67
67
  compute_nll_loss=compute_nll_loss,
68
+ average_log_prob=False,
68
69
  compiled=compiled,
69
70
  )
70
71
 
@@ -32,6 +32,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
32
32
  ref_input=None,
33
33
  ref_weight=None,
34
34
  ref_bias=None,
35
+ average_log_prob=True,
35
36
  **loss_kwargs,
36
37
  ):
37
38
  """
@@ -61,6 +62,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
61
62
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
62
63
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
63
64
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
65
+ average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
64
66
  loss_kwargs (dict): Other possible arguments that a loss function might need
65
67
  """
66
68
  # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -94,6 +96,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
94
96
  use_ref_model=use_ref_model,
95
97
  ref_weight=ref_weight,
96
98
  ref_bias=ref_bias,
99
+ average_log_prob=average_log_prob,
97
100
  **loss_kwargs,
98
101
  )
99
102
 
@@ -265,6 +268,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
265
268
  bias=None,
266
269
  ignore_index=-100,
267
270
  compute_nll_loss=True,
271
+ average_log_prob=True,
268
272
  ):
269
273
  len_chosen_chunk = target_chunk.shape[0] // 2
270
274
  logits_chunk = input_chunk @ weight.t()
@@ -285,10 +289,13 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
285
289
  label_chunk = torch.where(loss_mask, target_chunk, 0)
286
290
 
287
291
  per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
288
- average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
292
+ if average_log_prob:
293
+ log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
294
+ else:
295
+ log_prob = (per_token_logps * loss_mask).sum(-1)
289
296
 
290
- chosen_logps = average_log_prob[:len_chosen_chunk]
291
- rejected_logps = average_log_prob[len_chosen_chunk:]
297
+ chosen_logps = log_prob[:len_chosen_chunk]
298
+ rejected_logps = log_prob[len_chosen_chunk:]
292
299
 
293
300
  chosen_logits = logits_chunk[:len_chosen_chunk]
294
301
  rejected_logits = logits_chunk[len_chosen_chunk:]
@@ -317,6 +324,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
317
324
  ref_input_chunk=None,
318
325
  ref_weight=None,
319
326
  ref_bias=None,
327
+ average_log_prob=True,
320
328
  **loss_kwargs,
321
329
  ):
322
330
  """
@@ -335,6 +343,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
335
343
  use_ref_model (bool): Whether to use a reference model for the alignment loss.
336
344
  ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
337
345
  ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
346
+ average_log_prob (bool): Whether to average the log probabilities or to sum them.
338
347
  loss_kwargs (dict): Additional arguments for the loss function.
339
348
  """
340
349
  (
@@ -350,6 +359,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
350
359
  bias=bias,
351
360
  ignore_index=ignore_index,
352
361
  compute_nll_loss=compute_nll_loss,
362
+ average_log_prob=average_log_prob,
353
363
  )
354
364
  chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
355
365
  chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
@@ -372,6 +382,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
372
382
  ref_bias,
373
383
  ignore_index=ignore_index,
374
384
  compute_nll_loss=False, # We don't need NLL loss for the reference model
385
+ average_log_prob=average_log_prob,
375
386
  )
376
387
  loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
377
388
  loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20250108072837
3
+ Version: 0.5.2.dev20250108073340
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -3,11 +3,11 @@ liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,17
3
3
  liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
4
4
  liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
5
5
  liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
6
- liger_kernel/chunked_loss/cpo_loss.py,sha256=MCR4TzuBoJEaU0IJ7dIreLacQeXLKETV5CegNjhCD9M,3646
6
+ liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
7
7
  liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
8
8
  liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
9
9
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=uQtwtu-kaUZJTjNhAnIr3O794oUlUZ98XR5shYtwP5k,10440
10
- liger_kernel/chunked_loss/fused_linear_preference.py,sha256=25sTgvphLKAR0jyJcrsJPKK1abFpTKrajSyAx8nJ3bc,16134
10
+ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=eQCZmQ3xOL3jpZ7RhOfx_pqR9sNEX6RHx8DtIgyXEHc,16656
11
11
  liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSViw6R_pPXQ,3510
12
12
  liger_kernel/chunked_loss/simpo_loss.py,sha256=3TTc7U79Orjgi-Wu81WZkWk5MgsdqKXIOBHgIvDazPw,3865
13
13
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
58
58
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
59
59
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
60
60
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
61
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/METADATA,sha256=HwmQEBRYnwwbdkzuW53_qsmTSSbi8qu20cVOHsq6B_s,21055
63
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
- liger_kernel_nightly-0.5.2.dev20250108072837.dist-info/RECORD,,
61
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
62
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/METADATA,sha256=m2Zrd4xffCEa6qCxyFCCH6l1WJuk7V6eZ28Pt2_dtHc,21055
63
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
64
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
65
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
66
+ liger_kernel_nightly-0.5.2.dev20250108073340.dist-info/RECORD,,