liger-kernel-nightly 0.5.2.dev20241223042135__py3-none-any.whl → 0.5.2.dev20241228022953__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- liger_kernel/chunked_loss/cpo_loss.py +0 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +5 -2
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/top_level.txt +0 -0
@@ -127,13 +127,16 @@ def fused_linear_cross_entropy_forward(
|
|
127
127
|
alpha=alpha,
|
128
128
|
)
|
129
129
|
|
130
|
-
|
130
|
+
if reduction == "none":
|
131
|
+
loss = loss_1d
|
132
|
+
else:
|
133
|
+
loss = torch.sum(loss_1d)
|
131
134
|
return loss, grad_input, grad_weight, grad_bias
|
132
135
|
|
133
136
|
|
134
137
|
def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
|
135
138
|
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
|
136
|
-
if torch.
|
139
|
+
if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
137
140
|
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
|
138
141
|
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
|
139
142
|
BT, H = grad_input.shape
|
@@ -3,7 +3,7 @@ liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,17
|
|
3
3
|
liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
|
4
4
|
liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
|
5
5
|
liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
|
6
|
-
liger_kernel/chunked_loss/cpo_loss.py,sha256=
|
6
|
+
liger_kernel/chunked_loss/cpo_loss.py,sha256=L4Nk38Xh5Yfhah3Vsc_sN_Q75FWt1LA-xNNXzsK8iPM,3516
|
7
7
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
|
8
8
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
9
9
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=M-QWvGPnWefYDn6Hr9bPn7diMNP5qrUaeWTb_zdMO4E,10265
|
@@ -12,7 +12,7 @@ liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSV
|
|
12
12
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
|
13
13
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
14
|
liger_kernel/ops/cross_entropy.py,sha256=2OPIkSXeQAIfSCODYK45Jf8xrz7HoGqFHr1MHS_pijE,15895
|
15
|
-
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
|
15
|
+
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=LR0zLL8JYMhk9e22jmBxU4lwEYic3YqMAG3837yaHmM,9418
|
16
16
|
liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
|
17
17
|
liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
|
18
18
|
liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
|
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
58
58
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
|
59
59
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
60
60
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
61
|
-
liger_kernel_nightly-0.5.2.
|
62
|
-
liger_kernel_nightly-0.5.2.
|
63
|
-
liger_kernel_nightly-0.5.2.
|
64
|
-
liger_kernel_nightly-0.5.2.
|
65
|
-
liger_kernel_nightly-0.5.2.
|
66
|
-
liger_kernel_nightly-0.5.2.
|
61
|
+
liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
+
liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/METADATA,sha256=Z5fzI-xpYPtjwawEGwIw-LRJUIeY1VEdDUK9wgklR7w,21055
|
63
|
+
liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
+
liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
+
liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
+
liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|