liger-kernel 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +46 -0
- liger_kernel/ops/cross_entropy.py +130 -63
- liger_kernel/ops/experimental/embedding.py +143 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
- liger_kernel/ops/geglu.py +54 -42
- liger_kernel/ops/kl_div.py +247 -0
- liger_kernel/ops/layer_norm.py +236 -0
- liger_kernel/ops/rms_norm.py +220 -84
- liger_kernel/ops/rope.py +91 -84
- liger_kernel/ops/swiglu.py +48 -41
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +22 -0
- liger_kernel/transformers/auto_model.py +33 -0
- liger_kernel/transformers/cross_entropy.py +11 -1
- liger_kernel/transformers/experimental/embedding.py +28 -0
- liger_kernel/transformers/functional.py +19 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
- liger_kernel/transformers/geglu.py +4 -2
- liger_kernel/transformers/kl_div.py +13 -0
- liger_kernel/transformers/layer_norm.py +30 -0
- liger_kernel/transformers/model/gemma.py +138 -0
- liger_kernel/transformers/model/llama.py +1 -1
- liger_kernel/transformers/model/mistral.py +138 -0
- liger_kernel/transformers/model/mixtral.py +158 -0
- liger_kernel/transformers/model/phi3.py +136 -0
- liger_kernel/transformers/model/qwen2.py +135 -0
- liger_kernel/transformers/model/qwen2_vl.py +172 -0
- liger_kernel/transformers/monkey_patch.py +605 -14
- liger_kernel/transformers/rms_norm.py +23 -4
- liger_kernel/transformers/swiglu.py +24 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel-0.3.0.dist-info/METADATA +388 -0
- liger_kernel-0.3.0.dist-info/RECORD +42 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/WHEEL +1 -1
- liger_kernel-0.1.0.dist-info/METADATA +0 -16
- liger_kernel-0.1.0.dist-info/RECORD +0 -27
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/LICENSE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/NOTICE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
CHANGED
@@ -1,26 +1,61 @@
+"""
+This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+The following line
+https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
+is based on code from Unsloth, located at:
+https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+
+Modifications made by Yanning Chen, 2024.
+"""
+
+import operator
+
 import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import
+from liger_kernel.ops.utils import (
+    calculate_settings,
+    compare_version,
+    ensure_contiguous,
+)
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+_CASTING_MODE_NONE = tl.constexpr(-1)
+_CASTING_MODE_LLAMA = tl.constexpr(0)
+_CASTING_MODE_GEMMA = tl.constexpr(1)
 
 
 @triton.jit
-def
+def _rms_norm_forward_kernel(
     Y_ptr,
     Y_row_stride,
     X_ptr,
     X_row_stride,
     W_ptr,
     W_row_stride,
-
-
+    RSTD_ptr,
+    RSTD_row_stride,
     n_cols,
     eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    y_i = (x_i / (RMS)) * wi, RMS = sqrt(sum(x_i^2) / N)
+    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
 
     Reference:
     1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@@ -34,42 +69,59 @@ def _rms_norm_forward(
 
     Y_ptr += row_idx * Y_row_stride
     X_ptr += row_idx * X_row_stride
-
+    RSTD_ptr += row_idx * RSTD_row_stride
 
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+    X_row_dtype = X_row.dtype
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
+    # On Llama, only rstd is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
     mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
-
+    rstd = rsqrt(mean_square + eps)
 
     # We can save time by caching rms with minimal memory overhead
     # because rms is much smaller compared to X_row, as rms is for each row.
     # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-    tl.store(
+    tl.store(RSTD_ptr, rstd)
+
+    X_row = X_row * rstd
+
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row *
+    Y_row = X_row * (offset + W_row)
 
     tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
 
 
 @triton.jit
-def
+def _rms_norm_backward_kernel(
     dY_ptr,
     dY_row_stride,
     X_ptr,
     X_row_stride,
     W_ptr,
     W_row_stride,
-
-
+    RSTD_ptr,
+    RSTD_row_stride,
     dW_ptr,
     dW_row_stride,
     n_cols,
-
+    offset,
+    casting_mode: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    dx = (1 / RMS) * [dy * w
+    dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whileas dot means dot product
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
 
@@ -79,75 +131,175 @@ def _rms_norm_backward(
 
     dY_ptr += row_idx * dY_row_stride
     X_ptr += row_idx * X_row_stride
-
+    RSTD_ptr += row_idx * RSTD_row_stride
     dW_ptr += row_idx * dW_row_stride
 
     dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    original_x_dtype = X_row.dtype
 
     # Get cached rms
-
-
-
-
-
-
-
-
-    *
+    rstd_row = tl.load(RSTD_ptr)
+
+    W_row = W_row + offset
+
+    X_row = X_row.to(tl.float32)
+
+    # Different bacward graphs for different casting modes
+    if casting_mode == _CASTING_MODE_LLAMA:
+        m = (dY_row * W_row).to(tl.float32)
+
+    elif casting_mode == _CASTING_MODE_GEMMA:
+        dY_row, W_row = (
+            dY_row.to(tl.float32),
+            W_row.to(tl.float32),
+        )
+
+        m = dY_row * W_row
+
+    dX_row = rstd_row * m
+
+    dX_row += (rstd_row) * (
+        -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
     )
-    tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
 
     # calculate the gradient of W
-
+    if casting_mode == _CASTING_MODE_LLAMA:
+        dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype)
+    else:
+        # here X_row is already in fp32 (see previous if block)
+        dW_row = dY_row * (X_row * rstd_row)
+
+    tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
     tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
 
 
+_str_to_casting_mode = {
+    "llama": _CASTING_MODE_LLAMA.value,
+    "gemma": _CASTING_MODE_GEMMA.value,
+    "none": _CASTING_MODE_NONE.value,
+}
+
+
+def rms_norm_forward(X, W, eps, offset, casting_mode):
+    if not isinstance(casting_mode, int):
+        assert (
+            casting_mode in _str_to_casting_mode
+        ), f"Invalid casting mode: {casting_mode}"
+        casting_mode = _str_to_casting_mode[casting_mode]
+    else:
+        assert (
+            casting_mode in _str_to_casting_mode.values()
+        ), f"Invalid casting mode: {casting_mode}"
+
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    # RSTD is to cache rstd for each row
+    # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
+    rstd_dtype = (
+        torch.float32
+        if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
+        else X.dtype
+    )
+    RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
+
+    # Check constraints.
+    assert (
+        X.shape[1] == W.shape[0]
+    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+
+    _rms_norm_forward_kernel[(n_rows,)](
+        Y,
+        Y.stride(0),
+        X,
+        X.stride(0),
+        W,
+        W.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        n_cols,
+        eps,
+        offset,
+        casting_mode,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
+
+
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
+    shape = dY.shape
+    dim = shape[-1]
+    dY = dY.view(-1, dim)
+    n_rows, n_cols = dY.shape
+    dW = torch.empty_like(
+        X,
+        dtype=(torch.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype),
+    )
+
+    # Here we use dY to store the value of dX to save memory
+    _rms_norm_backward_kernel[(n_rows,)](
+        dY,
+        dY.stride(0),
+        X,
+        X.stride(0),
+        W,
+        W.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        dW,
+        dW.stride(0),
+        n_cols,
+        offset,
+        casting_mode,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    dX = dY.view(*shape)
+    dW = torch.sum(dW, dim=0).to(W.dtype)
+    return dX, dW
+
+
 class LigerRMSNormFunction(torch.autograd.Function):
+    """
+    Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
+    weight tensor `W`, with an optional offset and casting mode.
+
+    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
+    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
+    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
+
+    In addition, different models cast their inputs at different places during RMSNorm computation. For
+    example, Gemma casts everything to fp32 nefore starting the computation, while Llama casts only the
+    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
+    support the following casting modes (they match HuggingFace Transformers' implementations):
+    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
+    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
+    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+    """
+
     @staticmethod
     @ensure_contiguous
-    def forward(ctx, X, W, eps):
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
         """
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-
-
-        dim = shape[-1]
-        X = X.view(-1, dim)
-        n_rows, n_cols = X.shape
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-        Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-        # r is to cache (1/rms) for each row
-        r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
-
-        # Check constraints.
-        assert (
-            X.shape[1] == W.shape[0]
-        ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
-
-        _rms_norm_forward[(n_rows,)](
-            Y,
-            Y.stride(0),
-            X,
-            X.stride(0),
-            W,
-            W.stride(0),
-            r,
-            r.stride(0),
-            n_cols,
-            eps,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
+            X, W, eps, offset, casting_mode
         )
-        ctx.
+        ctx.offset = offset
+        ctx.casting_mode = casting_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-
-
-        return Y.view(*shape)
+        ctx.save_for_backward(X, W, RSTD)
+        return Y
 
     @staticmethod
     @ensure_contiguous
@@ -155,31 +307,15 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         Y: (B, T, H) or (BxT, H)
        """
-
-
-        dim = shape[-1]
-        dY = dY.view(-1, dim)
-        X, W, r = ctx.saved_tensors
-        n_rows, n_cols = dY.shape
-        dW = torch.zeros_like(X)
-
-        # Here we use dY to store the value of dX to save memory
-        _rms_norm_backward[(n_rows,)](
+        X, W, RSTD = ctx.saved_tensors
+        dX, dW = rms_norm_backward(
             dY,
-            dY.stride(0),
             X,
-            X.stride(0),
             W,
-
-
-
-
-
-            n_cols,
-            ctx.eps,
-            BLOCK_SIZE=ctx.BLOCK_SIZE,
-            num_warps=ctx.num_warps,
+            RSTD,
+            ctx.offset,
+            ctx.casting_mode,
+            ctx.BLOCK_SIZE,
+            ctx.num_warps,
         )
-        dX
-        dW = torch.sum(dW, dim=0)
-        return dX, dW, None
+        return dX, dW, None, None, None
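For orientation, the 0.3.0 forward signature visible above is forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"). The snippet below is a minimal usage sketch, not part of the packaged code: it assumes a CUDA device with Triton installed, and the shapes, eps, and offset values are illustrative only.

# Minimal sketch of calling the new RMSNorm autograd function (assumes CUDA + Triton).
import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

X = torch.randn(2, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)  # (B, T, H)
W = torch.ones(64, device="cuda", dtype=torch.bfloat16, requires_grad=True)         # (H,)

# Gemma-style RMSNorm: weight offset of 1.0, computation cast to fp32 inside the kernel.
Y = LigerRMSNormFunction.apply(X, W, 1e-6, 1.0, "gemma")
Y.sum().backward()  # backward returns (dX, dW, None, None, None) for the five forward inputs

Because offset and casting_mode have defaults, the older three-argument call LigerRMSNormFunction.apply(X, W, eps) still type-checks against the new signature.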
liger_kernel/ops/rope.py
CHANGED
@@ -13,8 +13,8 @@ def _triton_rope(
     cos_row_stride,
     sin,
     sin_row_stride,
+    sl,
     bs: tl.constexpr,
-    sl: tl.constexpr,
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
@@ -117,6 +117,92 @@ def _triton_rope(
         tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
 
 
+def rope_forward(q, k, cos, sin):
+
+    # transpose it back to the physical shape because Triton looks at the physical storage
+    # note: q and k are incontiguous before the transformation and will become contiguous after transpose
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+
+    batch_size, seq_len, n_q_head, head_dim = q.shape
+    n_kv_head = k.shape[2]
+    pad_hd = triton.next_power_of_2(head_dim)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+    n_row = batch_size * seq_len
+
+    # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
+    q = q.contiguous()
+    k = k.contiguous()
+    cos = cos.contiguous()
+    sin = sin.contiguous()
+
+    _triton_rope[(n_row,)](
+        q,
+        q.stride(1),
+        k,
+        k.stride(1),
+        cos,
+        cos.stride(-2),
+        sin,
+        sin.stride(-2),
+        seq_len,
+        batch_size,
+        n_q_head,
+        n_kv_head,
+        head_dim,
+        pad_n_q_head,
+        pad_n_kv_head,
+        pad_hd,
+        BLOCK_SIZE=BLOCK_SIZE,
+        BACKWARD_PASS=False,
+    )
+    return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+def rope_backward(dq, dk, cos, sin):
+    dq = dq.transpose(1, 2)
+    dk = dk.transpose(1, 2)
+
+    batch_size, seq_len, n_q_head, head_dim = dq.shape
+    n_kv_head = dk.shape[2]
+    pad_hd = triton.next_power_of_2(head_dim)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+    n_row = batch_size * seq_len
+
+    # ensure dq and dk are contiguous
+    dq = dq.contiguous()
+    dk = dk.contiguous()
+
+    # backward is similar to forward except swapping few ops
+    _triton_rope[(n_row,)](
+        dq,
+        dq.stride(1),
+        dk,
+        dk.stride(1),
+        cos,
+        cos.stride(-2),
+        sin,
+        sin.stride(-2),
+        seq_len,
+        batch_size,
+        n_q_head,
+        n_kv_head,
+        head_dim,
+        pad_n_q_head,
+        pad_n_kv_head,
+        pad_hd,
+        BLOCK_SIZE=BLOCK_SIZE,
+        BACKWARD_PASS=True,
+    )
+    return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
 class LigerRopeFunction(torch.autograd.Function):
     """
     Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
@@ -138,50 +224,9 @@ class LigerRopeFunction(torch.autograd.Function):
         cos size: (1, seq_len, head_dim)
         sin size: (1, seq_len, head_dim)
         """
-
-        # transpose it back to the physical shape because Triton looks at the physical storage
-        # note: q and k are incontiguous before the transformation and will become contiguous after transpose
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-
-        batch_size, seq_len, n_q_head, head_dim = q.shape
-        n_kv_head = k.shape[2]
-        pad_hd = triton.next_power_of_2(head_dim)
-        pad_n_q_head = triton.next_power_of_2(n_q_head)
-        pad_n_kv_head = triton.next_power_of_2(n_kv_head)
-        BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
-
-        n_row = batch_size * seq_len
-
-        # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
-        q = q.contiguous()
-        k = k.contiguous()
-        cos = cos.contiguous()
-        sin = sin.contiguous()
-
-        _triton_rope[(n_row,)](
-            q,
-            q.stride(1),
-            k,
-            k.stride(1),
-            cos,
-            cos.stride(-2),
-            sin,
-            sin.stride(-2),
-            batch_size,
-            seq_len,
-            n_q_head,
-            n_kv_head,
-            head_dim,
-            pad_n_q_head,
-            pad_n_kv_head,
-            pad_hd,
-            BLOCK_SIZE=BLOCK_SIZE,
-            BACKWARD_PASS=False,
-        )
-
+        q, k, cos, sin = rope_forward(q, k, cos, sin)
         ctx.save_for_backward(cos, sin)
-        return q
+        return q, k
 
     def backward(ctx, dq, dk):
         """
@@ -192,43 +237,5 @@ class LigerRopeFunction(torch.autograd.Function):
         """
 
         cos, sin = ctx.saved_tensors
-
-        dq
-        dk = dk.transpose(1, 2)
-
-        batch_size, seq_len, n_q_head, head_dim = dq.shape
-        n_kv_head = dk.shape[2]
-        pad_hd = triton.next_power_of_2(head_dim)
-        pad_n_q_head = triton.next_power_of_2(n_q_head)
-        pad_n_kv_head = triton.next_power_of_2(n_kv_head)
-        BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
-
-        n_row = batch_size * seq_len
-
-        # ensure dq and dk are contiguous
-        dq = dq.contiguous()
-        dk = dk.contiguous()
-
-        # backward is similar to forward except swapping few ops
-        _triton_rope[(n_row,)](
-            dq,
-            dq.stride(1),
-            dk,
-            dk.stride(1),
-            cos,
-            cos.stride(-2),
-            sin,
-            sin.stride(-2),
-            batch_size,
-            seq_len,
-            n_q_head,
-            n_kv_head,
-            head_dim,
-            pad_n_q_head,
-            pad_n_kv_head,
-            pad_hd,
-            BLOCK_SIZE=BLOCK_SIZE,
-            BACKWARD_PASS=True,
-        )
-
-        return dq.transpose(1, 2), dk.transpose(1, 2), None, None, None, None
+        dq, dk = rope_backward(dq, dk, cos, sin)
+        return dq, dk, None, None, None, None
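Since LigerRopeFunction now returns both rotated tensors and routes through the new rope_forward/rope_backward helpers, here is a minimal usage sketch, not part of the packaged code. It assumes a CUDA device with Triton installed, assumes the remaining forward arguments keep their defaults (they are not shown in this hunk), and uses illustrative shapes; cos/sin follow the (1, seq_len, head_dim) size documented above.

# Minimal sketch of the RoPE autograd function after this change (assumes CUDA + Triton).
import torch
from liger_kernel.ops.rope import LigerRopeFunction

bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 4, 16, 64
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", requires_grad=True)
cos = torch.randn(1, seq_len, head_dim, device="cuda")  # illustrative values, not a real rotary table
sin = torch.randn(1, seq_len, head_dim, device="cuda")

q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)   # 0.3.0 returns the (q, k) pair, not just q
(q_rot.sum() + k_rot.sum()).backward()                    # backward reuses the same kernel with BACKWARD_PASS=True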
liger_kernel/ops/swiglu.py
CHANGED
@@ -60,54 +60,61 @@ def _swiglu_backward_kernel(
     tl.store(b_ptr + col_offsets, db_row, mask=mask)
 
 
+def swiglu_forward(a, b):
+    ori_shape = a.shape
+
+    n_cols = ori_shape[-1]
+    a = a.view(-1, n_cols)
+    b = b.view(-1, n_cols)
+    c = torch.empty_like(a)
+    n_rows = a.shape[0]
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    _swiglu_forward_kernel[(n_rows,)](
+        a,
+        b,
+        c,
+        c.stride(-2),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return a, b, c.view(*ori_shape)
+
+
+def swiglu_backward(a, b, dc):
+
+    ori_shape = dc.shape
+    n_cols = ori_shape[-1]
+    dc = dc.view(-1, n_cols)
+    n_rows = dc.shape[0]
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    _swiglu_backward_kernel[(n_rows,)](
+        dc,
+        a,
+        b,
+        dc.stride(-2),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return a.view(*ori_shape), b.view(*ori_shape)
+
+
 class LigerSiLUMulFunction(torch.autograd.Function):
     @staticmethod
     @ensure_contiguous
     def forward(ctx, a, b):
-
-
-        n_cols = ori_shape[-1]
-        a = a.view(-1, n_cols)
-        b = b.view(-1, n_cols)
-        c = torch.zeros_like(a)
-        n_rows = a.shape[0]
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-        _swiglu_forward_kernel[(n_rows,)](
-            a,
-            b,
-            c,
-            c.stride(-2),
-            n_cols=n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
+        a, b, c = swiglu_forward(a, b)
         ctx.save_for_backward(a, b)
-
-        return c.view(*ori_shape)
+        return c
 
     @staticmethod
     @ensure_contiguous
     def backward(ctx, dc):
-
-        ori_shape = dc.shape
-        n_cols = ori_shape[-1]
-        dc = dc.view(-1, n_cols)
         a, b = ctx.saved_tensors
-
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-        _swiglu_backward_kernel[(n_rows,)](
-            dc,
-            a,
-            b,
-            dc.stride(-2),
-            n_cols=n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
-        return a.view(*ori_shape), b.view(*ori_shape)
+        a, b = swiglu_backward(a, b, dc)
+        return a, b