liger-kernel-nightly 0.5.2.dev20241223042135__py3-none-any.whl → 0.5.2.dev20241229035411__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- liger_kernel/chunked_loss/cpo_loss.py +0 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +9 -3
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241229035411.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241229035411.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241229035411.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241229035411.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241229035411.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223042135.dist-info → liger_kernel_nightly-0.5.2.dev20241229035411.dist-info}/top_level.txt +0 -0
@@ -59,6 +59,7 @@ def fused_linear_cross_entropy_forward(
|
|
59
59
|
logits_chunk = _input_chunk @ weight.t() # chunk_size x V
|
60
60
|
if bias is not None:
|
61
61
|
logits_chunk = logits_chunk + bias
|
62
|
+
|
62
63
|
target_chunk = target[start_idx:end_idx] # chunk_size,
|
63
64
|
|
64
65
|
n_rows = logits_chunk.shape[0]
|
@@ -112,7 +113,9 @@ def fused_linear_cross_entropy_forward(
|
|
112
113
|
if grad_weight is not None:
|
113
114
|
torch.addmm(
|
114
115
|
input=grad_weight,
|
115
|
-
mat1=logits_chunk.t()
|
116
|
+
mat1=logits_chunk.t().to(
|
117
|
+
_input_chunk.dtype
|
118
|
+
), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
|
116
119
|
mat2=_input_chunk,
|
117
120
|
out=grad_weight,
|
118
121
|
alpha=alpha,
|
@@ -127,13 +130,16 @@ def fused_linear_cross_entropy_forward(
|
|
127
130
|
alpha=alpha,
|
128
131
|
)
|
129
132
|
|
130
|
-
|
133
|
+
if reduction == "none":
|
134
|
+
loss = loss_1d
|
135
|
+
else:
|
136
|
+
loss = torch.sum(loss_1d)
|
131
137
|
return loss, grad_input, grad_weight, grad_bias
|
132
138
|
|
133
139
|
|
134
140
|
def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
|
135
141
|
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
|
136
|
-
if torch.
|
142
|
+
if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
137
143
|
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
|
138
144
|
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
|
139
145
|
BT, H = grad_input.shape
|
@@ -3,7 +3,7 @@ liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,17
|
|
3
3
|
liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
|
4
4
|
liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
|
5
5
|
liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
|
6
|
-
liger_kernel/chunked_loss/cpo_loss.py,sha256=
|
6
|
+
liger_kernel/chunked_loss/cpo_loss.py,sha256=L4Nk38Xh5Yfhah3Vsc_sN_Q75FWt1LA-xNNXzsK8iPM,3516
|
7
7
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=VYZMOafdvE8xlhvTtwjrz81tIzxR1mHF4lXdsADnIQg,4373
|
8
8
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
9
9
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=M-QWvGPnWefYDn6Hr9bPn7diMNP5qrUaeWTb_zdMO4E,10265
|
@@ -12,7 +12,7 @@ liger_kernel/chunked_loss/orpo_loss.py,sha256=jbZxx-EjPK71A6CSyNzTOAIEQgAUjfvwSV
|
|
12
12
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=ZvDIjT9EQrbwzH2LNZMhv84SPsOHGi_Ywk95vgA0b_o,3736
|
13
13
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
14
|
liger_kernel/ops/cross_entropy.py,sha256=2OPIkSXeQAIfSCODYK45Jf8xrz7HoGqFHr1MHS_pijE,15895
|
15
|
-
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
|
15
|
+
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=H2z-BFd9pGATlEzEeOw4EZwMoWsZtD8ovWJTkHD-9-s,9592
|
16
16
|
liger_kernel/ops/fused_linear_jsd.py,sha256=eKqaADj7LgWfoYqyH03tjrmhNTfJOF1Dhx_bWzBTnTU,9600
|
17
17
|
liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
|
18
18
|
liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
|
@@ -58,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
58
58
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=MId1S_MfA3pPVQA1rkiKxp-jZDNz8VmvZzXC-Kugol4,7662
|
59
59
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
60
60
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
61
|
-
liger_kernel_nightly-0.5.2.
|
62
|
-
liger_kernel_nightly-0.5.2.
|
63
|
-
liger_kernel_nightly-0.5.2.
|
64
|
-
liger_kernel_nightly-0.5.2.
|
65
|
-
liger_kernel_nightly-0.5.2.
|
66
|
-
liger_kernel_nightly-0.5.2.
|
61
|
+
liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
+
liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/METADATA,sha256=bVGSgTflxiXCSgDtaCWRTo93kcV2WSuSFYnfDHI4XIw,21055
|
63
|
+
liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
+
liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
+
liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
+
liger_kernel_nightly-0.5.2.dev20241229035411.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|