liger-kernel-nightly 0.3.1.dev20241025165359__py3-none-any.whl → 0.3.1.dev20241030045958__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/ops/fused_linear_cross_entropy.py +4 -4
- liger_kernel/ops/fused_linear_jsd.py +3 -1
- liger_kernel/ops/jsd.py +1 -1
- liger_kernel/ops/utils.py +13 -0
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/RECORD +15 -15
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/WHEEL +1 -1
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE-MIT-Efficient Cross Entropy +0 -0
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,7 @@ import torch
|
|
|
2
2
|
import triton
|
|
3
3
|
|
|
4
4
|
from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
|
|
5
|
-
from liger_kernel.ops.utils import element_mul_kernel
|
|
5
|
+
from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel
|
|
6
6
|
|
|
7
7
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
|
8
8
|
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
|
@@ -19,9 +19,7 @@ def fused_linear_cross_entropy_forward(
|
|
|
19
19
|
label_smoothing=0.0,
|
|
20
20
|
reduction="mean",
|
|
21
21
|
):
|
|
22
|
-
dtype =
|
|
23
|
-
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
|
|
24
|
-
)
|
|
22
|
+
dtype = _input.dtype
|
|
25
23
|
device = _input.device
|
|
26
24
|
|
|
27
25
|
# inputs have shape: BT x H
|
|
@@ -189,6 +187,7 @@ def fused_linear_cross_entropy_backward(
|
|
|
189
187
|
|
|
190
188
|
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
191
189
|
@staticmethod
|
|
190
|
+
@amp_custom_fwd
|
|
192
191
|
def forward(
|
|
193
192
|
ctx,
|
|
194
193
|
_input,
|
|
@@ -228,6 +227,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
228
227
|
return loss
|
|
229
228
|
|
|
230
229
|
@staticmethod
|
|
230
|
+
@amp_custom_bwd
|
|
231
231
|
def backward(ctx, grad_output):
|
|
232
232
|
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
|
|
233
233
|
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
|
|
@@ -92,7 +92,9 @@ def fused_linear_jsd_forward(
|
|
|
92
92
|
dX_ptr=student_prob_chunk,
|
|
93
93
|
dX_stride=student_prob_chunk.stride(-2),
|
|
94
94
|
label_ptr=(
|
|
95
|
-
shift_labels
|
|
95
|
+
shift_labels[start_idx:end_idx]
|
|
96
|
+
if has_label
|
|
97
|
+
else torch.empty(1, device=device)
|
|
96
98
|
), # dummy ptr if no label
|
|
97
99
|
beta=jsd_beta,
|
|
98
100
|
n_non_ignore=n_non_ignore,
|
liger_kernel/ops/jsd.py
CHANGED
liger_kernel/ops/utils.py
CHANGED
|
@@ -12,6 +12,7 @@ Modifications made by Yanning Chen, 2024.
|
|
|
12
12
|
|
|
13
13
|
import functools
|
|
14
14
|
import importlib
|
|
15
|
+
import operator
|
|
15
16
|
from typing import Callable
|
|
16
17
|
|
|
17
18
|
import torch
|
|
@@ -63,6 +64,18 @@ def compare_version(package: str, operator: Callable, target: str):
|
|
|
63
64
|
return operator(pkg_version, Version(target))
|
|
64
65
|
|
|
65
66
|
|
|
67
|
+
def get_amp_custom_fwd_bwd() -> Callable:
|
|
68
|
+
if compare_version("torch", operator.ge, "2.4.0"):
|
|
69
|
+
return (
|
|
70
|
+
functools.partial(torch.amp.custom_fwd, device_type="cuda"),
|
|
71
|
+
functools.partial(torch.amp.custom_bwd, device_type="cuda"),
|
|
72
|
+
)
|
|
73
|
+
return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
|
|
77
|
+
|
|
78
|
+
|
|
66
79
|
torch_to_triton_dtype = {
|
|
67
80
|
torch.float32: tl.float32,
|
|
68
81
|
torch.float16: tl.float16,
|
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
liger_kernel/env_report.py,sha256=LFUJ6UMkFFGPBYXBlqHFGy4bhsemEpSI-_1edSazlHI,1130
|
|
2
2
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
liger_kernel/ops/cross_entropy.py,sha256=OB3nvIONLB_sj9LO6UQv1qLnf861k-pR58RtwgoiyYA,11192
|
|
4
|
-
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
|
|
5
|
-
liger_kernel/ops/fused_linear_jsd.py,sha256=
|
|
4
|
+
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=qg7qBQFLDJClnkUOGhFFHPSW_x7rPvQekbm_4OOYxys,9331
|
|
5
|
+
liger_kernel/ops/fused_linear_jsd.py,sha256=eZ8y4GPtPRE0QcNNMLX8l4gSEMPzA3ZuzknfbAbiREA,9234
|
|
6
6
|
liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
|
|
7
|
-
liger_kernel/ops/jsd.py,sha256=
|
|
7
|
+
liger_kernel/ops/jsd.py,sha256=anWfdioucxZy4JQfTvbHBR-IQrZKeH-gBF1MHwwTuTQ,5781
|
|
8
8
|
liger_kernel/ops/kl_div.py,sha256=qnmtFQwuO3FR7Ovup_DDzpkD1A1LpwOaWlcO6K9ysHk,8342
|
|
9
9
|
liger_kernel/ops/layer_norm.py,sha256=unGMYMOPqtkM9aTrokhcqgPmsV2AUN7Yzv86isVB9OI,7422
|
|
10
10
|
liger_kernel/ops/rms_norm.py,sha256=9S9wyZLmzNyJlBxV4vbv4p5es7bGP-m_5wK9JC6JIdA,10911
|
|
11
11
|
liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
|
|
12
12
|
liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
|
|
13
|
-
liger_kernel/ops/utils.py,sha256=
|
|
13
|
+
liger_kernel/ops/utils.py,sha256=w0QT3ynUK2vYUAgsVvfoENTUu5L-2TuB3IYt8JaXlNA,3688
|
|
14
14
|
liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
|
|
15
15
|
liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh0urIZphs5DTc1mY,13358
|
|
16
16
|
liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
|
|
@@ -40,14 +40,14 @@ liger_kernel/transformers/model/qwen2.py,sha256=3inWFXGHYT7wA10OR6bq3mDUBrr10AS5
|
|
|
40
40
|
liger_kernel/transformers/model/qwen2_vl.py,sha256=ymsm9aQpSUiSU12GY8FO608p9dSHOz4TCnNI1htX5bk,6975
|
|
41
41
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
|
42
42
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
|
43
|
-
liger_kernel_nightly-0.3.1.
|
|
44
|
-
liger_kernel_nightly-0.3.1.
|
|
45
|
-
liger_kernel_nightly-0.3.1.
|
|
46
|
-
liger_kernel_nightly-0.3.1.
|
|
47
|
-
liger_kernel_nightly-0.3.1.
|
|
48
|
-
liger_kernel_nightly-0.3.1.
|
|
49
|
-
liger_kernel_nightly-0.3.1.
|
|
50
|
-
liger_kernel_nightly-0.3.1.
|
|
51
|
-
liger_kernel_nightly-0.3.1.
|
|
52
|
-
liger_kernel_nightly-0.3.1.
|
|
53
|
-
liger_kernel_nightly-0.3.1.
|
|
43
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
44
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE-Apache-2.0,sha256=NRaCIsL9eblGS35gk4WKTC0usNYnR_mgRHJTKqz2_UE,11348
|
|
45
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE-MIT-AutoAWQ,sha256=pfiOyInrAPY3xQbvV1i-gOqNZK7QEyIepT1IbqOYYYo,1067
|
|
46
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE-MIT-Efficient Cross Entropy,sha256=PaC9HqyFYTy-ClS0H8Zfa2motJuTppjECXmjHwJcaOk,1063
|
|
47
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE-MIT-llmc,sha256=kyFLt_XUcXS88CuxQt5-PjOcLjpJP2m-T4gtqZf3GLc,1071
|
|
48
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE-MIT-triton,sha256=wL6W8IwsKiyHtzXubg8TCXhRZuo8S83EPdqXffYtqWg,1131
|
|
49
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/METADATA,sha256=aSchnUTDzM__RwWSMgJNr8BqjwjNzTd-ja3hz2NTBKc,27717
|
|
50
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
51
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
52
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
53
|
+
liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|