liger-kernel-nightly 0.3.1.dev20241102065152__py3-none-any.whl → 0.3.1.dev20241104210835__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic.
- liger_kernel/ops/cross_entropy.py +3 -3
- liger_kernel/ops/fused_linear_cross_entropy.py +10 -5
- liger_kernel/ops/fused_linear_jsd.py +8 -3
- liger_kernel/ops/kl_div.py +2 -2
- liger_kernel/ops/utils.py +5 -1
- liger_kernel/transformers/model/llama.py +0 -1
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/METADATA +10 -2
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/RECORD +17 -17
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py
CHANGED

@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import element_mul_kernel, is_hip
 
 
 @triton.jit
@@ -194,7 +194,7 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
         BLOCK_SIZE=BLOCK_SIZE,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
-        num_warps=32,
+        num_warps=32 if not is_hip() else 16,
     )
 
     loss = torch.sum(loss_1d)
@@ -219,7 +219,7 @@ def cross_entropy_backward(_input, grad_output):
            grad_output,
            V,
            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
        )
 
    return _input
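
Every hunk in this release applies the same one-line change: kernel launches that hard-coded num_warps=32 now pass num_warps=32 if not is_hip() else 16. The sketch below illustrates that launch pattern with a toy Triton kernel; the kernel, tensor sizes, and block size are illustrative assumptions, not the real liger_cross_entropy_kernel, and is_hip mirrors the helper added to liger_kernel/ops/utils.py later in this diff.

# Minimal sketch of the launch pattern changed in every hunk above (toy kernel,
# not the actual liger kernel signatures).
import torch
import triton
import triton.language as tl


def is_hip() -> bool:
    # Same check as the helper added in liger_kernel/ops/utils.py.
    return torch.version.hip is not None


@triton.jit
def double_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Toy elementwise kernel standing in for the real cross-entropy kernel.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)


x = torch.randn(1 << 16, device="cuda")  # "cuda" is also the device string on ROCm builds of PyTorch
out = torch.empty_like(x)
BLOCK_SIZE = 4096
grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)
double_kernel[grid](
    x,
    out,
    x.numel(),
    BLOCK_SIZE=BLOCK_SIZE,
    num_warps=32 if not is_hip() else 16,  # the one-line change applied throughout this release
)
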
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED

@@ -2,7 +2,12 @@ import torch
 import triton
 
 from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
-from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel
+from liger_kernel.ops.utils import (
+    amp_custom_bwd,
+    amp_custom_fwd,
+    element_mul_kernel,
+    is_hip,
+)
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -88,7 +93,7 @@ def fused_linear_cross_entropy_forward(
            label_smoothing=label_smoothing,
            reduction=reduction,
            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
        )
 
        # gradient of logits_chunk is computed in-place by the above triton kernel.
@@ -153,7 +158,7 @@ def fused_linear_cross_entropy_backward(
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
        )
 
        # handle grad_weight
@@ -167,7 +172,7 @@ def fused_linear_cross_entropy_backward(
                grad_output,
                H,
                BLOCK_SIZE=BLOCK_SIZE,
-                num_warps=32,
+                num_warps=32 if not is_hip() else 16,
            )
 
        if grad_bias is not None:
@@ -180,7 +185,7 @@ def fused_linear_cross_entropy_backward(
                grad_output,
                1,
                BLOCK_SIZE=BLOCK_SIZE,
-                num_warps=32,
+                num_warps=32 if not is_hip() else 16,
            )
    return grad_input, grad_weight, grad_bias
 
liger_kernel/ops/fused_linear_jsd.py
CHANGED

@@ -4,7 +4,12 @@ import torch
 import triton
 
 from liger_kernel.ops.jsd import _jsd_kernel
-from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel
+from liger_kernel.ops.utils import (
+    amp_custom_bwd,
+    amp_custom_fwd,
+    element_mul_kernel,
+    is_hip,
+)
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -147,7 +152,7 @@ def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
        )
 
        # handle grad_weight
@@ -161,7 +166,7 @@ def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
                grad_output,
                H,
                BLOCK_SIZE=BLOCK_SIZE,
-                num_warps=32,
+                num_warps=32 if not is_hip() else 16,
            )
 
    return grad_input, grad_weight
liger_kernel/ops/kl_div.py
CHANGED

@@ -4,13 +4,13 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import ensure_contiguous, is_hip
 
 
 def get_num_warps(BLOCK_SIZE):
     num_warps = 4
     if BLOCK_SIZE >= 32768:
-        num_warps = 32
+        num_warps = 32 if not is_hip() else 16
     elif BLOCK_SIZE >= 8192:
         num_warps = 16
     elif BLOCK_SIZE >= 2048:
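
For large block sizes the patched helper now returns half as many warps on ROCm. A quick illustration, assuming the patched module above is importable; the commented values follow directly from the branches shown in the hunk.

# Exercises the patched get_num_warps from the hunk above.
from liger_kernel.ops.kl_div import get_num_warps

# CUDA build of PyTorch (torch.version.hip is None):
#   get_num_warps(65536) -> 32, get_num_warps(16384) -> 16
# ROCm build of PyTorch (torch.version.hip is set):
#   get_num_warps(65536) -> 16, get_num_warps(16384) -> 16
print(get_num_warps(65536), get_num_warps(16384))
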
liger_kernel/ops/utils.py
CHANGED

@@ -21,6 +21,10 @@ import triton.language as tl
 from packaging.version import Version
 
 
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
 def ensure_contiguous(fn):
     @functools.wraps(fn)
     def wrapper(ctx, *args, **kwargs):

@@ -47,7 +51,7 @@ def calculate_settings(n)
 
     num_warps = 4
     if BLOCK_SIZE >= 32768:
-        num_warps = 32
+        num_warps = 32 if not is_hip() else 16
     elif BLOCK_SIZE >= 8192:
         num_warps = 16
     elif BLOCK_SIZE >= 2048:
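
is_hip() keys off torch.version.hip, which is a version string on ROCm builds of PyTorch and None otherwise. The diff does not state why 16 is chosen on ROCm; a plausible reading is that Triton counts num_warps in 64-thread AMD wavefronts, so 16 wavefronts already reach 1024 threads per block, the same occupancy that 32 warps of 32 threads reach on NVIDIA hardware. A self-contained sketch of the selection logic follows; the value in the BLOCK_SIZE >= 2048 branch is an assumption, since that branch body is outside the hunk shown above.

# Standalone sketch of the HIP-aware warp selection applied in calculate_settings()
# and get_num_warps(). Thresholds copy the hunks above; the wavefront comment is an
# inference, not taken from the diff.
import torch


def is_hip() -> bool:
    # torch.version.hip is e.g. "6.0.x" on ROCm wheels and None on CUDA wheels.
    return torch.version.hip is not None


def pick_num_warps(BLOCK_SIZE: int) -> int:
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        # 16 wavefronts x 64 threads = 1024 threads, matching 32 warps x 32 threads on CUDA.
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8  # assumed value; this branch body is not shown in the hunk
    return num_warps
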
{liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.3.1.dev20241102065152
+Version: 0.3.1.dev20241104210835
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -163,11 +163,18 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
 
 ## Installation
 
-### Dependencies
+### Dependencies
+
+#### CUDA
 
 - `torch >= 2.1.2`
 - `triton >= 2.3.0`
 
+#### ROCm
+
+- `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
+- `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
+
 ### Optional Dependencies
 
 - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
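
The new ROCm rows imply the kernels can now target AMD GPUs when PyTorch itself is a ROCm build. Before following either row, a quick check of which backend the installed torch wheel was built for, using only standard torch attributes:

# Report whether the installed PyTorch wheel targets CUDA or ROCm.
import torch

print("torch:", torch.__version__)
print("CUDA build:", torch.version.cuda)  # version string on CUDA wheels, None on ROCm wheels
print("ROCm build:", torch.version.hip)   # version string on ROCm wheels, None on CUDA wheels
print("GPU visible:", torch.cuda.is_available())  # True on both backends when a device is present
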
@@ -197,6 +204,7 @@ pip install -e .
 pip install -e .[transformers]
 ```
 
+
 ## Getting Started
 
 There are a couple of ways to apply Liger kernels, depending on the level of customization required.
{liger_kernel_nightly-0.3.1.dev20241102065152.dist-info → liger_kernel_nightly-0.3.1.dev20241104210835.dist-info}/RECORD
CHANGED

@@ -1,16 +1,16 @@
 liger_kernel/env_report.py,sha256=LFUJ6UMkFFGPBYXBlqHFGy4bhsemEpSI-_1edSazlHI,1130
 liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/ops/cross_entropy.py,sha256=
-liger_kernel/ops/fused_linear_cross_entropy.py,sha256=
-liger_kernel/ops/fused_linear_jsd.py,sha256=
+liger_kernel/ops/cross_entropy.py,sha256=23Di7l0T20OBj8K3-0PYEA5FCJrrbiKs3xMGyLlzbtg,11248
+liger_kernel/ops/fused_linear_cross_entropy.py,sha256=M-cF4BO-vvso2BIdk7-Q2FleeFPhqSQwZR1EirPC4OE,9456
+liger_kernel/ops/fused_linear_jsd.py,sha256=5D_obamh08lGGTMyh85kBJD_aNjPhOYf4-TmCZ6m4s4,9626
 liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
 liger_kernel/ops/jsd.py,sha256=anWfdioucxZy4JQfTvbHBR-IQrZKeH-gBF1MHwwTuTQ,5781
-liger_kernel/ops/kl_div.py,sha256=
+liger_kernel/ops/kl_div.py,sha256=03FNXfvCb6M-56hhFepAFV9p6brArPR6KOKkdGD34mw,8374
 liger_kernel/ops/layer_norm.py,sha256=unGMYMOPqtkM9aTrokhcqgPmsV2AUN7Yzv86isVB9OI,7422
 liger_kernel/ops/rms_norm.py,sha256=9S9wyZLmzNyJlBxV4vbv4p5es7bGP-m_5wK9JC6JIdA,10911
 liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
 liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
-liger_kernel/ops/utils.py,sha256=
+liger_kernel/ops/utils.py,sha256=3JSF--O7KT5Wa5BuO70M4h0XetxoZ_e9IoW9GRlxlBg,3777
 liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh0urIZphs5DTc1mY,13358
 liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
@@ -31,7 +31,7 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
 liger_kernel/transformers/experimental/embedding.py,sha256=HpckiAMKM8-SRxKDcGTqortVxnjhwpZsfsp9lfjqfeM,895
 liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/transformers/model/gemma.py,sha256=EcdkGbSj_qroTDFl0Sc_HLyDyY0xcDhwrgkM_wkXnw8,4987
-liger_kernel/transformers/model/llama.py,sha256=
+liger_kernel/transformers/model/llama.py,sha256=RinsgC_eR-YNvZd2SHPQxZ4eyR3uViaTFCM3SvI5nks,10426
 liger_kernel/transformers/model/mistral.py,sha256=_MQJrDntlxBO5cJwgTjr2rk2nNd5FAXVnzcTg_PEekQ,5079
 liger_kernel/transformers/model/mixtral.py,sha256=51FghRY8aGBWat7KSgTeFDqdStDiXY3dEJepByNhEOE,5847
 liger_kernel/transformers/model/mllama.py,sha256=S00P0pJrGHOWBx170TPYZbQ0djv0__m8Dqv1FvKZUyE,5926
@@ -40,14 +40,14 @@ liger_kernel/transformers/model/qwen2.py,sha256=3inWFXGHYT7wA10OR6bq3mDUBrr10AS5
 liger_kernel/transformers/model/qwen2_vl.py,sha256=ymsm9aQpSUiSU12GY8FO608p9dSHOz4TCnNI1htX5bk,6975
 liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
 liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
-liger_kernel_nightly-0.3.1.
-liger_kernel_nightly-0.3.1.
-liger_kernel_nightly-0.3.1.
-liger_kernel_nightly-0.3.1.
-liger_kernel_nightly-0.3.1.
-liger_kernel_nightly-0.3.1.
-liger_kernel_nightly-0.3.1.
-liger_kernel_nightly-0.3.1.
-liger_kernel_nightly-0.3.1.
-liger_kernel_nightly-0.3.1.
-liger_kernel_nightly-0.3.1.
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/LICENSE-Apache-2.0,sha256=NRaCIsL9eblGS35gk4WKTC0usNYnR_mgRHJTKqz2_UE,11348
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/LICENSE-MIT-AutoAWQ,sha256=pfiOyInrAPY3xQbvV1i-gOqNZK7QEyIepT1IbqOYYYo,1067
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/LICENSE-MIT-Efficient-Cross-Entropy,sha256=PaC9HqyFYTy-ClS0H8Zfa2motJuTppjECXmjHwJcaOk,1063
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/LICENSE-MIT-llmc,sha256=kyFLt_XUcXS88CuxQt5-PjOcLjpJP2m-T4gtqZf3GLc,1071
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/LICENSE-MIT-triton,sha256=wL6W8IwsKiyHtzXubg8TCXhRZuo8S83EPdqXffYtqWg,1131
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/METADATA,sha256=KLe3u0yMc9Dipf9wsCM2DXabjlK1X-cgfqnJe5z-Lmk,27901
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.3.1.dev20241104210835.dist-info/RECORD,,