liger_kernel_nightly-0.5.2.dev20241223042135-py3-none-any.whl → liger_kernel_nightly-0.5.2.dev20241229035411-py3-none-any.whl

--- a/liger_kernel/chunked_loss/cpo_loss.py
+++ b/liger_kernel/chunked_loss/cpo_loss.py
@@ -33,7 +33,6 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
         loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
             full_target.shape[0] // 2
         )
-
         return loss
 
     @staticmethod
--- a/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -59,6 +59,7 @@ def fused_linear_cross_entropy_forward(
         logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
         if bias is not None:
             logits_chunk = logits_chunk + bias
+
         target_chunk = target[start_idx:end_idx]  # chunk_size,
 
         n_rows = logits_chunk.shape[0]
@@ -112,7 +113,9 @@ def fused_linear_cross_entropy_forward(
         if grad_weight is not None:
             torch.addmm(
                 input=grad_weight,
-                mat1=logits_chunk.t(),
+                mat1=logits_chunk.t().to(
+                    _input_chunk.dtype
+                ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
                 mat2=_input_chunk,
                 out=grad_weight,
                 alpha=alpha,
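
The `mat1` cast in this hunk guards against a dtype mismatch: under `torch.autocast` (and without a bias term, whose float32 addition would already have promoted the logits), the matmul that produces `logits_chunk` runs in reduced precision while `grad_weight` and `_input_chunk` stay in float32, and `torch.addmm` needs its matrices to share a dtype. A minimal sketch of the failure mode and the fix, using made-up shapes rather than the kernel's real ones:

```python
import torch

# Hypothetical sizes: vocab V=8, hidden H=4, a chunk of 16 rows.
grad_weight = torch.zeros(8, 4)                      # float32 accumulator, V x H
_input_chunk = torch.randn(16, 4)                    # float32 activations, chunk x H
logits_chunk = torch.randn(16, 8, dtype=torch.half)  # reduced-precision logits grad, chunk x V

# Passing the half-precision matrix directly would make addmm complain about
# mismatched dtypes; casting mat1 back to the activation dtype keeps the fused
# accumulation grad_weight += alpha * (mat1 @ mat2) valid.
torch.addmm(
    input=grad_weight,
    mat1=logits_chunk.t().to(_input_chunk.dtype),  # V x chunk, cast to float32
    mat2=_input_chunk,                             # chunk x H
    out=grad_weight,
    alpha=1.0,
)
```
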
@@ -127,13 +130,16 @@ def fused_linear_cross_entropy_forward(
                 alpha=alpha,
             )
 
-    loss = torch.sum(loss_1d)
+    if reduction == "none":
+        loss = loss_1d
+    else:
+        loss = torch.sum(loss_1d)
     return loss, grad_input, grad_weight, grad_bias
 
 
 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-    if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+    if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
         # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
         BT, H = grad_input.shape
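
Two behavioral fixes land in the hunk above. First, `reduction="none"` now returns the per-token `loss_1d` instead of unconditionally summing it. Second, the backward's fast-path check uses `torch.equal` rather than `torch.ne`: `torch.ne` is elementwise, so once `grad_output` can be a multi-element tensor (exactly what `reduction="none"` produces), the old `if` test would fail with an ambiguous-boolean error. A small illustration with made-up values:

```python
import torch

# Per-token losses as returned when reduction="none" (values are made up).
loss_1d = torch.tensor([0.7, 1.2, 0.3])
grad_output = torch.ones_like(loss_1d)  # upstream grad for an unreduced loss
one = torch.tensor(1.0, device=grad_output.device)

# torch.ne is elementwise: the result is a boolean tensor, and using it in an
# `if` would raise "Boolean value of Tensor with more than one element is ambiguous".
elementwise = torch.ne(grad_output, one)  # tensor([False, False, False])

# torch.equal compares shapes and values and returns a single Python bool, so the
# "grad_output is exactly the scalar 1.0" fast path works for scalars and vectors alike.
needs_scaling = not torch.equal(grad_output, one)
print(needs_scaling)  # True: shapes differ, so the backward does not skip the multiply
```
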
--- a/liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/METADATA
+++ b/liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.2.dev20241223042135
+Version: 0.5.2.dev20241229035411
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
--- a/liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD
+++ b/liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/RECORD
@@ -3,7 +3,7 @@ liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,17
 liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
 liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
 liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
-liger_kernel/chunked_loss/cpo_loss.py,sha256=H-BU2QC5GzNQ4NnTM6TLgwvo-Eoh5YAE-Q_j1dX_w0g,3517
+liger_kernel/chunked_loss/cpo_loss.py,sha256=L4Nk38Xh5Yfhah3Vsc_sN_Q75FWt1LA-xNNXzsK8iPM,3516
 liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
 liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=M-QWvGPnWefYDn6Hr9bPn7diMNP5qrUaeWTb_zdMO4E,10265
@@ -12,7 +12,7 @@ liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSV
 liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/ops/cross_entropy.py,sha256=2OPIkSXeQAIfSCODYK45Jf8xrz7HoGqFHr1MHS_pijE,15895
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=ObNZjgYlCvigbgKl-FAjHAvk90wiwJ-4Wrf8JUHmlLQ,9346
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=H2z-BFd9pGATlEzEeOw4EZwMoWsZtD8ovWJTkHD-9-s,9592
 liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/METADATA,sha256=diXsKJ9zCLk-w9SCZLWWx-xN0ZP8-W51KrgpISmaxn4,21055
-liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD,,
+liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/METADATA,sha256=bVGSgTflxiXCSgDtaCWRTo93kcV2WSuSFYnfDHI4XIw,21055
+liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/RECORD,,