liger-kernel 0.0.0__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/cross_entropy.py +4 -33
- liger_kernel/ops/fused_linear_cross_entropy.py +6 -6
- liger_kernel/ops/geglu.py +14 -3
- liger_kernel/ops/rms_norm.py +40 -22
- liger_kernel/ops/swiglu.py +16 -16
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +1 -0
- liger_kernel/transformers/model/llama.py +3 -0
- liger_kernel/transformers/monkey_patch.py +35 -8
- liger_kernel/transformers/trainer_integration.py +45 -0
- liger_kernel/triton/monkey_patch.py +0 -2
- liger_kernel-0.1.0.dist-info/LICENSE +23 -0
- {liger_kernel-0.0.0.dist-info → liger_kernel-0.1.0.dist-info}/METADATA +3 -1
- liger_kernel-0.1.0.dist-info/NOTICE +4 -0
- liger_kernel-0.1.0.dist-info/RECORD +27 -0
- liger_kernel-0.0.0.dist-info/RECORD +0 -24
- {liger_kernel-0.0.0.dist-info → liger_kernel-0.1.0.dist-info}/WHEEL +0 -0
- {liger_kernel-0.0.0.dist-info → liger_kernel-0.1.0.dist-info}/top_level.txt +0 -0
|
@@ -17,7 +17,7 @@ def liger_cross_entropy_kernel(
|
|
|
17
17
|
BLOCK_SIZE: tl.constexpr,
|
|
18
18
|
):
|
|
19
19
|
"""
|
|
20
|
-
This kernel computes both cross entropy loss and the gradient of the
|
|
20
|
+
This kernel computes both cross entropy loss and the gradient of the input.
|
|
21
21
|
We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
|
|
22
22
|
|
|
23
23
|
Parameters:
|
|
@@ -34,7 +34,7 @@ def liger_cross_entropy_kernel(
|
|
|
34
34
|
"""
|
|
35
35
|
|
|
36
36
|
# https://github.com/triton-lang/triton/issues/1058
|
|
37
|
-
#
|
|
37
|
+
# If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
|
|
38
38
|
program_id = tl.program_id(0).to(tl.int64)
|
|
39
39
|
|
|
40
40
|
# 1. Load Y_ptr first because if the target is ignore_index, we can return right away
|
|
@@ -90,13 +90,7 @@ def liger_cross_entropy_kernel(
|
|
|
90
90
|
tl.debug_barrier()
|
|
91
91
|
|
|
92
92
|
# 5. Calculate the loss
|
|
93
|
-
# Old Approach: Problematic LogSoftmax
|
|
94
|
-
# min of bfloat16 and float32 is 1e-38, so we set a value larger than that but small enough
|
|
95
|
-
# This will overflow if X_y * n_non_ignore is too small. Even if we add a tiny epsilon, it will still overflow
|
|
96
|
-
# loss = -tl.log(X_y * n_non_ignore)
|
|
97
93
|
|
|
98
|
-
# New Approach: Safe LogSoftmax
|
|
99
|
-
# Therefore, we propose to use safe logsoftmax by reordering the formula.
|
|
100
94
|
# loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
|
|
101
95
|
# = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
|
|
102
96
|
# sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
|
|
@@ -114,7 +108,7 @@ def liger_cross_entropy_kernel(
|
|
|
114
108
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
|
115
109
|
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
|
116
110
|
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
|
117
|
-
MAX_FUSED_SIZE = 65536 // 2 #
|
|
111
|
+
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
|
|
118
112
|
|
|
119
113
|
|
|
120
114
|
@triton.jit
|
|
@@ -184,28 +178,6 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
|
184
178
|
n_non_ignore = (target != ignore_index).sum().item()
|
|
185
179
|
|
|
186
180
|
# ensure _input and target are contiguous in the last dimension
|
|
187
|
-
# there are examples that are NOT contiguous overall but contiguous in the last dimension
|
|
188
|
-
####################################################################
|
|
189
|
-
# tensor = torch.arange(1, 21).reshape(5, -1)
|
|
190
|
-
# print(tensor)
|
|
191
|
-
# tensor([[ 1, 2, 3, 4],
|
|
192
|
-
# [ 5, 6, 7, 8],
|
|
193
|
-
# [ 9, 10, 11, 12],
|
|
194
|
-
# [13, 14, 15, 16],
|
|
195
|
-
# [17, 18, 19, 20]])
|
|
196
|
-
# print(tensor.is_contiguous())
|
|
197
|
-
# True
|
|
198
|
-
# slice = tensor[::2, :]
|
|
199
|
-
# print(slice)
|
|
200
|
-
# tensor([[ 1, 2, 3, 4],
|
|
201
|
-
# [ 9, 10, 11, 12],
|
|
202
|
-
# [17, 18, 19, 20]])
|
|
203
|
-
# print(slice.is_contiguous())
|
|
204
|
-
# False
|
|
205
|
-
# print(slice.stride())
|
|
206
|
-
# (8, 1)
|
|
207
|
-
# slice is NOT a contiguous tensor but is contiguous in the last dimension, CE kernel can execute because the stride is 8, and each triton program will jump by 8
|
|
208
|
-
####################################################################
|
|
209
181
|
if _input.stride(-1) != 1:
|
|
210
182
|
_input = _input.contiguous()
|
|
211
183
|
if target.stride(-1) != 1:
|
|
@@ -252,10 +224,9 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
|
|
|
252
224
|
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
|
|
253
225
|
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
|
254
226
|
pass
|
|
227
|
+
|
|
255
228
|
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
|
|
256
229
|
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
|
|
257
|
-
# Although the Brew trainer should only perform backward once, it encounters this issue.
|
|
258
|
-
# https://github.com/triton-lang/triton/issues/4004
|
|
259
230
|
else:
|
|
260
231
|
BT, V = _input.shape
|
|
261
232
|
n_rows = BT
|
|
@@ -1,8 +1,3 @@
|
|
|
1
|
-
"""Fusing the last linear layer with cross-entropy loss
|
|
2
|
-
|
|
3
|
-
Reference: https://github.com/mgmalek/efficient_cross_entropy
|
|
4
|
-
"""
|
|
5
|
-
|
|
6
1
|
import torch
|
|
7
2
|
import triton
|
|
8
3
|
|
|
@@ -11,13 +6,16 @@ from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kern
|
|
|
11
6
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
|
12
7
|
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
|
13
8
|
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
|
14
|
-
MAX_FUSED_SIZE = 65536 // 2
|
|
9
|
+
MAX_FUSED_SIZE = 65536 // 2
|
|
15
10
|
|
|
16
11
|
|
|
17
12
|
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
18
13
|
@staticmethod
|
|
19
14
|
def forward(ctx, _input, linear, target, ignore_index):
|
|
20
15
|
"""
|
|
16
|
+
Fusing the last linear layer with cross-entropy loss
|
|
17
|
+
Reference: https://github.com/mgmalek/efficient_cross_entropy
|
|
18
|
+
|
|
21
19
|
Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
|
|
22
20
|
the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
|
|
23
21
|
compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
|
|
@@ -54,6 +52,8 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
|
|
54
52
|
|
|
55
53
|
grad_linear = torch.zeros_like(linear, device=device)
|
|
56
54
|
grad_input = torch.zeros_like(_input, device=device)
|
|
55
|
+
|
|
56
|
+
# we use fp32 for loss accumulator
|
|
57
57
|
loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
|
|
58
58
|
|
|
59
59
|
total_n_non_ignore = (target != ignore_index).sum().item()
|
liger_kernel/ops/geglu.py
CHANGED
|
@@ -1,8 +1,19 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
|
|
1
3
|
import torch
|
|
2
4
|
import triton
|
|
3
5
|
import triton.language as tl
|
|
4
6
|
|
|
5
|
-
from liger_kernel.ops.utils import
|
|
7
|
+
from liger_kernel.ops.utils import (
|
|
8
|
+
calculate_settings,
|
|
9
|
+
compare_version,
|
|
10
|
+
ensure_contiguous,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
if compare_version("triton", operator.ge, "3.0.0"):
|
|
14
|
+
from triton.language.extra.libdevice import tanh
|
|
15
|
+
else:
|
|
16
|
+
from triton.language.math import tanh
|
|
6
17
|
|
|
7
18
|
|
|
8
19
|
@triton.jit
|
|
@@ -26,7 +37,7 @@ def _geglu_tanh_forward_kernel(
|
|
|
26
37
|
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
|
|
27
38
|
a_cubed = a_row * a_row * a_row
|
|
28
39
|
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
|
29
|
-
tanh_result =
|
|
40
|
+
tanh_result = tanh(tanh_arg)
|
|
30
41
|
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
|
31
42
|
c_row = geglu_a * b_row
|
|
32
43
|
tl.store(c + col_offsets, c_row, mask=mask)
|
|
@@ -54,7 +65,7 @@ def _geglu_tanh_backward_kernel(
|
|
|
54
65
|
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
|
|
55
66
|
a_cubed = a_row * a_row * a_row
|
|
56
67
|
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
|
57
|
-
tanh_result =
|
|
68
|
+
tanh_result = tanh(tanh_arg)
|
|
58
69
|
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
|
59
70
|
|
|
60
71
|
db_row = dc_row * geglu_a
|
liger_kernel/ops/rms_norm.py
CHANGED
|
@@ -20,9 +20,12 @@ def _rms_norm_forward(
|
|
|
20
20
|
BLOCK_SIZE: tl.constexpr,
|
|
21
21
|
):
|
|
22
22
|
"""
|
|
23
|
+
y_i = (x_i / (RMS)) * wi, RMS = sqrt(sum(x_i^2) / N)
|
|
24
|
+
|
|
23
25
|
Reference:
|
|
24
26
|
1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
|
25
27
|
2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
|
|
28
|
+
3. https://arxiv.org/pdf/1910.07467
|
|
26
29
|
"""
|
|
27
30
|
|
|
28
31
|
row_idx = tl.program_id(0)
|
|
@@ -36,16 +39,17 @@ def _rms_norm_forward(
|
|
|
36
39
|
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
|
|
37
40
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
|
38
41
|
|
|
39
|
-
|
|
40
|
-
|
|
42
|
+
mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
|
|
43
|
+
inv_rms = tl.math.rsqrt(mean_square + eps)
|
|
41
44
|
|
|
42
|
-
#
|
|
43
|
-
|
|
45
|
+
# We can save time by caching rms with minimal memory overhead
|
|
46
|
+
# because rms is much smaller compared to X_row, as rms is for each row.
|
|
47
|
+
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
|
|
48
|
+
tl.store(r_ptr, inv_rms)
|
|
44
49
|
|
|
45
|
-
|
|
50
|
+
Y_row = X_row * inv_rms * W_row
|
|
46
51
|
|
|
47
|
-
|
|
48
|
-
tl.store(Y_ptr + col_offsets, output, mask=mask)
|
|
52
|
+
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
|
49
53
|
|
|
50
54
|
|
|
51
55
|
@triton.jit
|
|
@@ -65,9 +69,10 @@ def _rms_norm_backward(
|
|
|
65
69
|
BLOCK_SIZE: tl.constexpr,
|
|
66
70
|
):
|
|
67
71
|
"""
|
|
68
|
-
dx = (1 /
|
|
69
|
-
dw = sum(dy * (x /
|
|
72
|
+
dx = (1 / RMS) * [dy * w - (1 / N) * (1 / RMS^2) * ((dy * w) dot x) * x]. * means element-wise multiplication, whileas dot means dot product
|
|
73
|
+
dw = sum(dy * (x / RMS)). summation over BxT dimension
|
|
70
74
|
"""
|
|
75
|
+
|
|
71
76
|
row_idx = tl.program_id(0)
|
|
72
77
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
73
78
|
mask = col_offsets < n_cols
|
|
@@ -81,34 +86,42 @@ def _rms_norm_backward(
|
|
|
81
86
|
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
|
|
82
87
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
|
83
88
|
|
|
84
|
-
# Get
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
normed = X_row * inv_var
|
|
89
|
+
# Get cached rms
|
|
90
|
+
inv_rms_row = tl.load(r_ptr)
|
|
88
91
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
92
|
+
dX_row = (inv_rms_row) * (
|
|
93
|
+
dY_row * W_row
|
|
94
|
+
- (1 / n_cols)
|
|
95
|
+
* inv_rms_row
|
|
96
|
+
* inv_rms_row
|
|
97
|
+
* tl.sum(dY_row * W_row * X_row, axis=0)
|
|
98
|
+
* X_row
|
|
99
|
+
)
|
|
100
|
+
tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
|
|
95
101
|
|
|
96
102
|
# calculate the gradient of W
|
|
97
|
-
|
|
103
|
+
dW_row = dY_row * X_row * inv_rms_row
|
|
104
|
+
tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
|
|
98
105
|
|
|
99
106
|
|
|
100
107
|
class LigerRMSNormFunction(torch.autograd.Function):
|
|
101
108
|
@staticmethod
|
|
102
109
|
@ensure_contiguous
|
|
103
110
|
def forward(ctx, X, W, eps):
|
|
111
|
+
"""
|
|
112
|
+
X: (B, T, H) or (BxT, H)
|
|
113
|
+
W: (H,)
|
|
114
|
+
"""
|
|
115
|
+
|
|
104
116
|
shape = X.shape
|
|
105
117
|
dim = shape[-1]
|
|
106
118
|
X = X.view(-1, dim)
|
|
107
119
|
n_rows, n_cols = X.shape
|
|
108
120
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
109
121
|
|
|
110
|
-
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=
|
|
111
|
-
r
|
|
122
|
+
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
123
|
+
# r is to cache (1/rms) for each row
|
|
124
|
+
r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
112
125
|
|
|
113
126
|
# Check constraints.
|
|
114
127
|
assert (
|
|
@@ -139,6 +152,10 @@ class LigerRMSNormFunction(torch.autograd.Function):
|
|
|
139
152
|
@staticmethod
|
|
140
153
|
@ensure_contiguous
|
|
141
154
|
def backward(ctx, dY):
|
|
155
|
+
"""
|
|
156
|
+
Y: (B, T, H) or (BxT, H)
|
|
157
|
+
"""
|
|
158
|
+
|
|
142
159
|
shape = dY.shape
|
|
143
160
|
dim = shape[-1]
|
|
144
161
|
dY = dY.view(-1, dim)
|
|
@@ -146,6 +163,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
|
|
|
146
163
|
n_rows, n_cols = dY.shape
|
|
147
164
|
dW = torch.zeros_like(X)
|
|
148
165
|
|
|
166
|
+
# Here we use dY to store the value of dX to save memory
|
|
149
167
|
_rms_norm_backward[(n_rows,)](
|
|
150
168
|
dY,
|
|
151
169
|
dY.stride(0),
|
liger_kernel/ops/swiglu.py
CHANGED
|
@@ -12,43 +12,43 @@ def silu(x):
|
|
|
12
12
|
|
|
13
13
|
@triton.jit
|
|
14
14
|
def _swiglu_forward_kernel(
|
|
15
|
-
|
|
15
|
+
a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
|
|
16
16
|
):
|
|
17
17
|
program_id = tl.program_id(0)
|
|
18
18
|
|
|
19
19
|
# locate start index
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
a_ptr += program_id * stride
|
|
21
|
+
b_ptr += program_id * stride
|
|
22
|
+
c_ptr += program_id * stride
|
|
23
23
|
|
|
24
24
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
25
25
|
mask = col_offsets < n_cols
|
|
26
26
|
|
|
27
27
|
# sigmoid requires type float32
|
|
28
|
-
a_row = tl.load(
|
|
29
|
-
b_row = tl.load(
|
|
28
|
+
a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
|
29
|
+
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
|
|
30
30
|
c_row = silu(a_row) * b_row
|
|
31
|
-
tl.store(
|
|
31
|
+
tl.store(c_ptr + col_offsets, c_row, mask=mask)
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
@triton.jit
|
|
35
35
|
def _swiglu_backward_kernel(
|
|
36
|
-
|
|
36
|
+
dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
|
|
37
37
|
):
|
|
38
38
|
program_id = tl.program_id(0)
|
|
39
39
|
|
|
40
40
|
# locate start index
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
41
|
+
dc_ptr += program_id * stride
|
|
42
|
+
a_ptr += program_id * stride
|
|
43
|
+
b_ptr += program_id * stride
|
|
44
44
|
|
|
45
45
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
46
46
|
mask = col_offsets < n_cols
|
|
47
47
|
|
|
48
|
-
dc_row = tl.load(
|
|
48
|
+
dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
|
|
49
49
|
# sigmoid requires type float32
|
|
50
|
-
a_row = tl.load(
|
|
51
|
-
b_row = tl.load(
|
|
50
|
+
a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
|
51
|
+
b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
|
|
52
52
|
|
|
53
53
|
# recomputation to save memory
|
|
54
54
|
sig_a = tl.sigmoid(a_row)
|
|
@@ -56,8 +56,8 @@ def _swiglu_backward_kernel(
|
|
|
56
56
|
db_row = dc_row * silu_a
|
|
57
57
|
da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
|
|
58
58
|
|
|
59
|
-
tl.store(
|
|
60
|
-
tl.store(
|
|
59
|
+
tl.store(a_ptr + col_offsets, da_row, mask=mask)
|
|
60
|
+
tl.store(b_ptr + col_offsets, db_row, mask=mask)
|
|
61
61
|
|
|
62
62
|
|
|
63
63
|
class LigerSiLUMulFunction(torch.autograd.Function):
|
liger_kernel/ops/utils.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
import functools
|
|
2
|
+
import importlib
|
|
3
|
+
from typing import Callable
|
|
2
4
|
|
|
3
5
|
import torch
|
|
4
6
|
import triton
|
|
7
|
+
from packaging.version import Version
|
|
5
8
|
|
|
6
9
|
|
|
7
10
|
def ensure_contiguous(fn):
|
|
@@ -36,3 +39,12 @@ def calculate_settings(n):
|
|
|
36
39
|
elif BLOCK_SIZE >= 2048:
|
|
37
40
|
num_warps = 8
|
|
38
41
|
return BLOCK_SIZE, num_warps
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def compare_version(package: str, operator: Callable, target: str):
|
|
45
|
+
try:
|
|
46
|
+
pkg = importlib.import_module(package)
|
|
47
|
+
except ImportError:
|
|
48
|
+
return False
|
|
49
|
+
pkg_version = Version(pkg.__version__)
|
|
50
|
+
return operator(pkg_version, Version(target))
|
|
@@ -37,6 +37,9 @@ def lce_forward(
|
|
|
37
37
|
cache_position: Optional[torch.LongTensor] = None,
|
|
38
38
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
39
39
|
r"""
|
|
40
|
+
Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
|
|
41
|
+
|
|
42
|
+
|
|
40
43
|
Args:
|
|
41
44
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
42
45
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
@@ -1,27 +1,26 @@
|
|
|
1
1
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
|
2
|
+
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
|
2
3
|
from liger_kernel.transformers.model.llama import lce_forward
|
|
3
4
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
4
5
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
5
6
|
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
|
|
6
7
|
|
|
7
8
|
|
|
8
|
-
# TODO: probably rename utils.py as hf_patcher.py to be more descriptive
|
|
9
9
|
def apply_liger_kernel_to_llama(
|
|
10
10
|
rope: bool = True,
|
|
11
|
-
cross_entropy: bool =
|
|
12
|
-
fused_linear_cross_entropy: bool =
|
|
11
|
+
cross_entropy: bool = False,
|
|
12
|
+
fused_linear_cross_entropy: bool = True,
|
|
13
13
|
rms_norm: bool = True,
|
|
14
14
|
swiglu: bool = True,
|
|
15
15
|
) -> None:
|
|
16
16
|
"""
|
|
17
17
|
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
|
|
18
|
-
to make GPU go burrr.
|
|
19
18
|
|
|
20
19
|
Args:
|
|
21
20
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
22
|
-
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is
|
|
21
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
23
22
|
fused_linear_cross_entropy (bool):
|
|
24
|
-
Whether to apply Liger's fused lienar cross entropy loss. Default is
|
|
23
|
+
Whether to apply Liger's fused lienar cross entropy loss. Default is True.
|
|
25
24
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
26
25
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
27
26
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
@@ -54,7 +53,6 @@ def apply_liger_kernel_to_mistral(
|
|
|
54
53
|
) -> None:
|
|
55
54
|
"""
|
|
56
55
|
Apply Liger kernels to replace original implementation in HuggingFace Mistral models
|
|
57
|
-
to make GPU go burrr.
|
|
58
56
|
|
|
59
57
|
Args:
|
|
60
58
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
@@ -83,12 +81,12 @@ def apply_liger_kernel_to_mixtral(
|
|
|
83
81
|
) -> None:
|
|
84
82
|
"""
|
|
85
83
|
Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
|
|
86
|
-
to make GPU go burrr.
|
|
87
84
|
|
|
88
85
|
Args:
|
|
89
86
|
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
90
87
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
|
|
91
88
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
89
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
92
90
|
"""
|
|
93
91
|
|
|
94
92
|
from transformers.models.mixtral import modeling_mixtral
|
|
@@ -101,3 +99,32 @@ def apply_liger_kernel_to_mixtral(
|
|
|
101
99
|
modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
102
100
|
if swiglu:
|
|
103
101
|
modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def apply_liger_kernel_to_gemma(
|
|
105
|
+
rope: bool = True,
|
|
106
|
+
cross_entropy: bool = True,
|
|
107
|
+
rms_norm: bool = True,
|
|
108
|
+
geglu: bool = True,
|
|
109
|
+
) -> None:
|
|
110
|
+
"""
|
|
111
|
+
Apply Liger kernels to replace original implementation in HuggingFace Gemma2 models
|
|
112
|
+
to make GPU go burrr.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
116
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
|
|
117
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
118
|
+
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
|
|
119
|
+
"""
|
|
120
|
+
# TODO(yundai424): add convergence test for gemma
|
|
121
|
+
from transformers.models.gemma import modeling_gemma
|
|
122
|
+
|
|
123
|
+
if rope:
|
|
124
|
+
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
125
|
+
if rms_norm:
|
|
126
|
+
modeling_gemma.GemmaRMSNorm = LigerRMSNorm
|
|
127
|
+
if cross_entropy:
|
|
128
|
+
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
129
|
+
if geglu:
|
|
130
|
+
modeling_gemma.GemmaMLP = LigerGEGLUMLP
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from liger_kernel.transformers.monkey_patch import (
|
|
4
|
+
apply_liger_kernel_to_gemma,
|
|
5
|
+
apply_liger_kernel_to_llama,
|
|
6
|
+
apply_liger_kernel_to_mistral,
|
|
7
|
+
apply_liger_kernel_to_mixtral,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
|
|
13
|
+
MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
14
|
+
"gemma": apply_liger_kernel_to_gemma,
|
|
15
|
+
"llama": apply_liger_kernel_to_llama,
|
|
16
|
+
"mistral": apply_liger_kernel_to_mistral,
|
|
17
|
+
"mixtral": apply_liger_kernel_to_mixtral,
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _apply_liger_kernel(model_type: str = "", **kwargs) -> None:
|
|
22
|
+
"""
|
|
23
|
+
Applies Liger kernels based on the specified model type. The custom
|
|
24
|
+
kernels for the specified model type will be applied with the provided
|
|
25
|
+
keyword arguments, otherwise the default configuration will be used.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
- model_type: the model types as defined in transformers/models/auto/modeling_auto.py
|
|
29
|
+
and specified in the model's config.json
|
|
30
|
+
- kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
if not model_type:
|
|
34
|
+
logger.info("Model type was not provided. No Liger kernels will be applied.")
|
|
35
|
+
return
|
|
36
|
+
|
|
37
|
+
if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
|
|
38
|
+
logger.info(
|
|
39
|
+
f"There are currently no Liger kernels supported for model type: {model_type}."
|
|
40
|
+
)
|
|
41
|
+
return
|
|
42
|
+
|
|
43
|
+
logger.info(f"Applying Liger kernels for model type: {model_type}.")
|
|
44
|
+
# Apply the default combination of liger kernels available for the model
|
|
45
|
+
MODEL_TYPE_TO_APPLY_LIGER_FN[model_type](**kwargs)
|
|
@@ -1,12 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import random
|
|
3
3
|
|
|
4
|
-
from overrides import override
|
|
5
4
|
from triton.runtime.cache import FileCacheManager
|
|
6
5
|
|
|
7
6
|
|
|
8
7
|
class LigerTritonFileCacheManager(FileCacheManager):
|
|
9
|
-
@override
|
|
10
8
|
def put(self, data, filename, binary=True) -> str:
|
|
11
9
|
if not self.cache_dir:
|
|
12
10
|
raise RuntimeError("Could not create or locate cache dir")
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
BSD 2-CLAUSE LICENSE
|
|
2
|
+
Copyright 2024 LinkedIn Corporation
|
|
3
|
+
All Rights Reserved.
|
|
4
|
+
Redistribution and use in source and binary forms, with or
|
|
5
|
+
without modification, are permitted provided that the following
|
|
6
|
+
conditions are met:
|
|
7
|
+
1. Redistributions of source code must retain the above copyright
|
|
8
|
+
notice, this list of conditions and the following disclaimer.
|
|
9
|
+
2. Redistributions in binary form must reproduce the above
|
|
10
|
+
copyright notice, this list of conditions and the following
|
|
11
|
+
disclaimer in the documentation and/or other materials provided
|
|
12
|
+
with the distribution.
|
|
13
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
14
|
+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
15
|
+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
16
|
+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
17
|
+
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
18
|
+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
19
|
+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
20
|
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
21
|
+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
22
|
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
23
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
liger_kernel/ops/cross_entropy.py,sha256=YTHKVyPW748EWtbWJeKdIe9S1dEq6i90_PbBuCD-9s0,9178
|
|
3
|
+
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=58MmDhLJGR5b8ixztkhR707yp0VY28oBRASFVwGbeV8,7346
|
|
4
|
+
liger_kernel/ops/geglu.py,sha256=5tGinryOOYRpGtKwJ4B1ertwtzd81xdjevD3Ha7H1AY,3849
|
|
5
|
+
liger_kernel/ops/rms_norm.py,sha256=AQ1jaCXUlrBazqAPg-Cpf2K5OsO4byDKcdfWsGy9-zI,4848
|
|
6
|
+
liger_kernel/ops/rope.py,sha256=fYBct8gDQfKPZdMWlzkZZ8kBzh6nQ7DIpDsc7lZwM8c,8584
|
|
7
|
+
liger_kernel/ops/swiglu.py,sha256=MRbSIXsBLqlFr9ZdtuFqSjLJJ-716URmQIhxQ57GGEw,2915
|
|
8
|
+
liger_kernel/ops/utils.py,sha256=vsFIywd8LQlVPRA3RPZOm5HyN8c0cS4NFEEnwjNw-MI,1427
|
|
9
|
+
liger_kernel/transformers/__init__.py,sha256=nVvk0h7er3fdgubQF8Z8KjA3ew-q5oJHyJRg5cKmBoc,205
|
|
10
|
+
liger_kernel/transformers/cross_entropy.py,sha256=G-L4EaUYVc25NKZ2jrlaG-d5YUvDqJdUlawPN7K1d1g,389
|
|
11
|
+
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=h0AW9ubFGfz4DBwgh2CLW8rpKo9PvxYpB6AUzjx-1b0,501
|
|
12
|
+
liger_kernel/transformers/geglu.py,sha256=FrLBHZRdI68jw9RR6MSTE59-xCzueOwSRp9jL8y-j98,896
|
|
13
|
+
liger_kernel/transformers/monkey_patch.py,sha256=FjaRZVWm_ZMHO3NXc4IT6EpCTWJOdZKP72mZq01qbrA,5006
|
|
14
|
+
liger_kernel/transformers/rms_norm.py,sha256=2LHfEctSpzuNRaoZ9uUECSFK8fZeIxIsHm9QbEHZvDQ,452
|
|
15
|
+
liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
|
|
16
|
+
liger_kernel/transformers/swiglu.py,sha256=8kt4MffEZT5vx3k0WA-GO-WPLv5kGdnu_nAwlJyMI2U,1516
|
|
17
|
+
liger_kernel/transformers/trainer_integration.py,sha256=gt0fF-se2XiIB6PocHBPBuD6tLCOtQRcb20WfUS2ceA,1645
|
|
18
|
+
liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
19
|
+
liger_kernel/transformers/model/llama.py,sha256=4mfVTMrY7T-xiJeQJe02hBVnAwNCKlvLGp49gj6TWiU,5298
|
|
20
|
+
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
|
21
|
+
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
|
22
|
+
liger_kernel-0.1.0.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
23
|
+
liger_kernel-0.1.0.dist-info/METADATA,sha256=E_OSiFz2sC4jmWO4VH3sTXWiR3Ev7qNy5oSLSWk-s8g,504
|
|
24
|
+
liger_kernel-0.1.0.dist-info/NOTICE,sha256=BXkXY9aWvEy_7MAB57zDu1z8uMYT1i1l9B6EpHuBa8s,173
|
|
25
|
+
liger_kernel-0.1.0.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
|
26
|
+
liger_kernel-0.1.0.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
27
|
+
liger_kernel-0.1.0.dist-info/RECORD,,
|
|
@@ -1,24 +0,0 @@
|
|
|
1
|
-
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
liger_kernel/ops/cross_entropy.py,sha256=XRnLWW2Jo1sVllDbyTuM8ir_6WZR791fFgqoaIVzPrM,10665
|
|
3
|
-
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=-z5NDZ1a7htumYPJbmw1QRqgp9N_QKZZiuaZSPCb9Y0,7311
|
|
4
|
-
liger_kernel/ops/geglu.py,sha256=DiJSy4I8kouPFyNpKUuthfibZWRioPJMGR-4MJgebhg,3660
|
|
5
|
-
liger_kernel/ops/rms_norm.py,sha256=iQd8ZDzNM-3b05eLzjh1Jfj2C8QKAtg59h-b-XuIo5s,4299
|
|
6
|
-
liger_kernel/ops/rope.py,sha256=fYBct8gDQfKPZdMWlzkZZ8kBzh6nQ7DIpDsc7lZwM8c,8584
|
|
7
|
-
liger_kernel/ops/swiglu.py,sha256=__QsfYxKyZHtRScm31zL3sAOVEblQFqKj2ll8I4Odqg,2835
|
|
8
|
-
liger_kernel/ops/utils.py,sha256=cC7rvhiEBW-8x4qQRTUYWW790k3TA-S7pKbJmdRj-Xc,1080
|
|
9
|
-
liger_kernel/transformers/__init__.py,sha256=7rOw9yZ8kNXO483Colx-EUq8GcTCvCZxrxF-S7pmkkU,172
|
|
10
|
-
liger_kernel/transformers/cross_entropy.py,sha256=G-L4EaUYVc25NKZ2jrlaG-d5YUvDqJdUlawPN7K1d1g,389
|
|
11
|
-
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=h0AW9ubFGfz4DBwgh2CLW8rpKo9PvxYpB6AUzjx-1b0,501
|
|
12
|
-
liger_kernel/transformers/geglu.py,sha256=FrLBHZRdI68jw9RR6MSTE59-xCzueOwSRp9jL8y-j98,896
|
|
13
|
-
liger_kernel/transformers/monkey_patch.py,sha256=5h436874AENVnTjQAk4-Srp_GIr50CXAl2xeNTbqzJg,3988
|
|
14
|
-
liger_kernel/transformers/rms_norm.py,sha256=2LHfEctSpzuNRaoZ9uUECSFK8fZeIxIsHm9QbEHZvDQ,452
|
|
15
|
-
liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
|
|
16
|
-
liger_kernel/transformers/swiglu.py,sha256=8kt4MffEZT5vx3k0WA-GO-WPLv5kGdnu_nAwlJyMI2U,1516
|
|
17
|
-
liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
|
-
liger_kernel/transformers/model/llama.py,sha256=DJOjLT5-TGMLKaPqLqyW03rLae8lJTb3nwnfg2mVNXQ,5197
|
|
19
|
-
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
|
20
|
-
liger_kernel/triton/monkey_patch.py,sha256=yRNaGdyG5PrwX5ed_MQdqtqvvpVvQ7ZD2FQ_9W1q9u8,1629
|
|
21
|
-
liger_kernel-0.0.0.dist-info/METADATA,sha256=SBK5dFzMYYtFyorscmi__7u83TitHFpKMsRE9pUKXGI,461
|
|
22
|
-
liger_kernel-0.0.0.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
|
23
|
-
liger_kernel-0.0.0.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
24
|
-
liger_kernel-0.0.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|