liger-kernel 0.0.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry, and is provided for informational purposes only.
- liger_kernel/env_report.py +46 -0
- liger_kernel/ops/cross_entropy.py +130 -63
- liger_kernel/ops/experimental/embedding.py +143 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
- liger_kernel/ops/geglu.py +54 -42
- liger_kernel/ops/kl_div.py +247 -0
- liger_kernel/ops/layer_norm.py +236 -0
- liger_kernel/ops/rms_norm.py +235 -81
- liger_kernel/ops/rope.py +91 -84
- liger_kernel/ops/swiglu.py +64 -57
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +23 -0
- liger_kernel/transformers/auto_model.py +33 -0
- liger_kernel/transformers/cross_entropy.py +11 -1
- liger_kernel/transformers/experimental/embedding.py +28 -0
- liger_kernel/transformers/functional.py +19 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
- liger_kernel/transformers/geglu.py +4 -2
- liger_kernel/transformers/kl_div.py +13 -0
- liger_kernel/transformers/layer_norm.py +30 -0
- liger_kernel/transformers/model/gemma.py +138 -0
- liger_kernel/transformers/model/llama.py +1 -1
- liger_kernel/transformers/model/mistral.py +138 -0
- liger_kernel/transformers/model/mixtral.py +158 -0
- liger_kernel/transformers/model/phi3.py +136 -0
- liger_kernel/transformers/model/qwen2.py +135 -0
- liger_kernel/transformers/model/qwen2_vl.py +172 -0
- liger_kernel/transformers/monkey_patch.py +629 -8
- liger_kernel/transformers/rms_norm.py +23 -4
- liger_kernel/transformers/swiglu.py +24 -0
- liger_kernel/transformers/trainer_integration.py +2 -0
- liger_kernel/triton/monkey_patch.py +0 -2
- liger_kernel-0.3.0.dist-info/METADATA +388 -0
- liger_kernel-0.3.0.dist-info/RECORD +42 -0
- {liger_kernel-0.0.1.dist-info → liger_kernel-0.3.0.dist-info}/WHEEL +1 -1
- liger_kernel-0.0.1.dist-info/METADATA +0 -16
- liger_kernel-0.0.1.dist-info/RECORD +0 -26
- {liger_kernel-0.0.1.dist-info → liger_kernel-0.3.0.dist-info}/LICENSE +0 -0
- {liger_kernel-0.0.1.dist-info → liger_kernel-0.3.0.dist-info}/NOTICE +0 -0
- {liger_kernel-0.0.1.dist-info → liger_kernel-0.3.0.dist-info}/top_level.txt +0 -0
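Most of the new surface area above lives under liger_kernel/transformers: per-model patches (gemma, mistral, mixtral, phi3, qwen2, qwen2_vl) and a much larger monkey_patch.py. As rough orientation only, not an excerpt from the wheel, the patching entry points are typically used as in the sketch below; the exact export names (here apply_liger_kernel_to_llama) and their keyword options should be verified against the installed 0.3.0 package.

import torch
from transformers import AutoModelForCausalLM

# Assumed entry point; monkey_patch.py in the file list above is where the
# per-model patch functions are defined in this release.
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama

# Swap the HF Llama building blocks (RMSNorm, RoPE, SwiGLU, cross entropy)
# for the Triton kernels shipped in this wheel, before the model is created.
apply_liger_kernel_to_llama()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
    device_map="auto",
)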
liger_kernel/ops/rms_norm.py
CHANGED
@@ -1,28 +1,66 @@
+"""
+This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+The following line
+https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
+is based on code from Unsloth, located at:
+https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+
+Modifications made by Yanning Chen, 2024.
+"""
+
+import operator
+
 import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import
+from liger_kernel.ops.utils import (
+    calculate_settings,
+    compare_version,
+    ensure_contiguous,
+)
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+_CASTING_MODE_NONE = tl.constexpr(-1)
+_CASTING_MODE_LLAMA = tl.constexpr(0)
+_CASTING_MODE_GEMMA = tl.constexpr(1)
 
 
 @triton.jit
-def
+def _rms_norm_forward_kernel(
     Y_ptr,
     Y_row_stride,
     X_ptr,
     X_row_stride,
     W_ptr,
     W_row_stride,
-
-
+    RSTD_ptr,
+    RSTD_row_stride,
     n_cols,
     eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
     BLOCK_SIZE: tl.constexpr,
 ):
     """
+    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
+
     Reference:
     1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
     2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
     """
 
     row_idx = tl.program_id(0)
@@ -31,137 +69,253 @@ def _rms_norm_forward(
 
     Y_ptr += row_idx * Y_row_stride
     X_ptr += row_idx * X_row_stride
-
+    RSTD_ptr += row_idx * RSTD_row_stride
 
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+    X_row_dtype = X_row.dtype
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
-
-
+    # On Llama, only rstd is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
+    mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
+    rstd = rsqrt(mean_square + eps)
+
+    # We can save time by caching rms with minimal memory overhead
+    # because rms is much smaller compared to X_row, as rms is for each row.
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
+    tl.store(RSTD_ptr, rstd)
 
-
-    tl.store(r_ptr, inv_var)
+    X_row = X_row * rstd
 
-
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
 
-
-
+    Y_row = X_row * (offset + W_row)
+
+    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
 
 
 @triton.jit
-def
+def _rms_norm_backward_kernel(
     dY_ptr,
     dY_row_stride,
     X_ptr,
     X_row_stride,
     W_ptr,
     W_row_stride,
-
-
+    RSTD_ptr,
+    RSTD_row_stride,
     dW_ptr,
     dW_row_stride,
     n_cols,
-
+    offset,
+    casting_mode: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    dx = (1 /
-    dw = sum(dy * (x /
+    dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whileas dot means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
+
     row_idx = tl.program_id(0)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
     dY_ptr += row_idx * dY_row_stride
     X_ptr += row_idx * X_row_stride
-
+    RSTD_ptr += row_idx * RSTD_row_stride
     dW_ptr += row_idx * dW_row_stride
 
     dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    original_x_dtype = X_row.dtype
 
-    # Get
-
+    # Get cached rms
+    rstd_row = tl.load(RSTD_ptr)
 
-
+    W_row = W_row + offset
 
-
-
+    X_row = X_row.to(tl.float32)
+
+    # Different bacward graphs for different casting modes
+    if casting_mode == _CASTING_MODE_LLAMA:
+        m = (dY_row * W_row).to(tl.float32)
+
+    elif casting_mode == _CASTING_MODE_GEMMA:
+        dY_row, W_row = (
+            dY_row.to(tl.float32),
+            W_row.to(tl.float32),
+        )
 
-
-
-
+        m = dY_row * W_row
+
+    dX_row = rstd_row * m
+
+    dX_row += (rstd_row) * (
+        -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
+    )
 
     # calculate the gradient of W
-
+    if casting_mode == _CASTING_MODE_LLAMA:
+        dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype)
+    else:
+        # here X_row is already in fp32 (see previous if block)
+        dW_row = dY_row * (X_row * rstd_row)
 
+    tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
+    tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
 
-class LigerRMSNormFunction(torch.autograd.Function):
-    @staticmethod
-    @ensure_contiguous
-    def forward(ctx, X, W, eps):
-        shape = X.shape
-        dim = shape[-1]
-        X = X.view(-1, dim)
-        n_rows, n_cols = X.shape
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
 
-
-
+_str_to_casting_mode = {
+    "llama": _CASTING_MODE_LLAMA.value,
+    "gemma": _CASTING_MODE_GEMMA.value,
+    "none": _CASTING_MODE_NONE.value,
+}
 
-
+
+def rms_norm_forward(X, W, eps, offset, casting_mode):
+    if not isinstance(casting_mode, int):
+        assert (
+            casting_mode in _str_to_casting_mode
+        ), f"Invalid casting mode: {casting_mode}"
+        casting_mode = _str_to_casting_mode[casting_mode]
+    else:
         assert (
-
-        ), "
+            casting_mode in _str_to_casting_mode.values()
+        ), f"Invalid casting mode: {casting_mode}"
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    # RSTD is to cache rstd for each row
+    # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
+    rstd_dtype = (
+        torch.float32
+        if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
+        else X.dtype
+    )
+    RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
+
+    # Check constraints.
+    assert (
+        X.shape[1] == W.shape[0]
+    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+
+    _rms_norm_forward_kernel[(n_rows,)](
+        Y,
+        Y.stride(0),
+        X,
+        X.stride(0),
+        W,
+        W.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        n_cols,
+        eps,
+        offset,
+        casting_mode,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
+
+
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
+    shape = dY.shape
+    dim = shape[-1]
+    dY = dY.view(-1, dim)
+    n_rows, n_cols = dY.shape
+    dW = torch.empty_like(
+        X,
+        dtype=(torch.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype),
+    )
+
+    # Here we use dY to store the value of dX to save memory
+    _rms_norm_backward_kernel[(n_rows,)](
+        dY,
+        dY.stride(0),
+        X,
+        X.stride(0),
+        W,
+        W.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        dW,
+        dW.stride(0),
+        n_cols,
+        offset,
+        casting_mode,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    dX = dY.view(*shape)
+    dW = torch.sum(dW, dim=0).to(W.dtype)
+    return dX, dW
+
+
+class LigerRMSNormFunction(torch.autograd.Function):
+    """
+    Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
+    weight tensor `W`, with an optional offset and casting mode.
+
+    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
+    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
+    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
+
+    In addition, different models cast their inputs at different places during RMSNorm computation. For
+    example, Gemma casts everything to fp32 nefore starting the computation, while Llama casts only the
+    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
+    support the following casting modes (they match HuggingFace Transformers' implementations):
+    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
+    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
+    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
+        """
+        X: (B, T, H) or (BxT, H)
+        W: (H,)
+        """
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
+            X, W, eps, offset, casting_mode
         )
-        ctx.
+        ctx.offset = offset
+        ctx.casting_mode = casting_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-
-
-        return Y.view(*shape)
+        ctx.save_for_backward(X, W, RSTD)
+        return Y
 
     @staticmethod
     @ensure_contiguous
     def backward(ctx, dY):
-
-
-
-        X, W,
-
-        dW = torch.zeros_like(X)
-
-        _rms_norm_backward[(n_rows,)](
+        """
+        Y: (B, T, H) or (BxT, H)
+        """
+        X, W, RSTD = ctx.saved_tensors
+        dX, dW = rms_norm_backward(
            dY,
-            dY.stride(0),
            X,
-            X.stride(0),
            W,
-
-
-
-
-
-            n_cols,
-            ctx.eps,
-            BLOCK_SIZE=ctx.BLOCK_SIZE,
-            num_warps=ctx.num_warps,
+            RSTD,
+            ctx.offset,
+            ctx.casting_mode,
+            ctx.BLOCK_SIZE,
+            ctx.num_warps,
        )
-        dX
-        dW = torch.sum(dW, dim=0)
-        return dX, dW, None
+        return dX, dW, None, None, None
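The rewritten rms_norm.py exposes the kernels through LigerRMSNormFunction, whose forward signature gained offset and casting_mode arguments. Below is a minimal sketch of exercising it against a hand-rolled Llama-style reference; it assumes a CUDA device with Triton available, and the reference_rms_norm helper is my own illustration, not code from the wheel.

import torch

from liger_kernel.ops.rms_norm import LigerRMSNormFunction


def reference_rms_norm(x, w, eps=1e-6, offset=0.0):
    # Llama-style casting: the inverse RMS is computed in fp32, the weight
    # multiplication happens in the original dtype (see the kernel above).
    x_fp32 = x.float()
    rstd = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
    return (x_fp32 * rstd).to(x.dtype) * (w + offset)


x = torch.randn(2, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# offset=0.0 and casting_mode="llama" mirror the new forward defaults.
y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama")
y.backward(torch.ones_like(y))

print(torch.allclose(y, reference_rms_norm(x, w), atol=1e-2, rtol=1e-2))

The "gemma" mode would pair naturally with offset=1.0, matching the (X / RMS(X)) * (W + 1.0) form described in the class docstring.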
liger_kernel/ops/rope.py
CHANGED
@@ -13,8 +13,8 @@ def _triton_rope(
     cos_row_stride,
     sin,
     sin_row_stride,
+    sl,
     bs: tl.constexpr,
-    sl: tl.constexpr,
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
@@ -117,6 +117,92 @@ def _triton_rope(
     tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
 
 
+def rope_forward(q, k, cos, sin):
+
+    # transpose it back to the physical shape because Triton looks at the physical storage
+    # note: q and k are incontiguous before the transformation and will become contiguous after transpose
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+
+    batch_size, seq_len, n_q_head, head_dim = q.shape
+    n_kv_head = k.shape[2]
+    pad_hd = triton.next_power_of_2(head_dim)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+    n_row = batch_size * seq_len
+
+    # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
+    q = q.contiguous()
+    k = k.contiguous()
+    cos = cos.contiguous()
+    sin = sin.contiguous()
+
+    _triton_rope[(n_row,)](
+        q,
+        q.stride(1),
+        k,
+        k.stride(1),
+        cos,
+        cos.stride(-2),
+        sin,
+        sin.stride(-2),
+        seq_len,
+        batch_size,
+        n_q_head,
+        n_kv_head,
+        head_dim,
+        pad_n_q_head,
+        pad_n_kv_head,
+        pad_hd,
+        BLOCK_SIZE=BLOCK_SIZE,
+        BACKWARD_PASS=False,
+    )
+    return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+def rope_backward(dq, dk, cos, sin):
+    dq = dq.transpose(1, 2)
+    dk = dk.transpose(1, 2)
+
+    batch_size, seq_len, n_q_head, head_dim = dq.shape
+    n_kv_head = dk.shape[2]
+    pad_hd = triton.next_power_of_2(head_dim)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+    n_row = batch_size * seq_len
+
+    # ensure dq and dk are contiguous
+    dq = dq.contiguous()
+    dk = dk.contiguous()
+
+    # backward is similar to forward except swapping few ops
+    _triton_rope[(n_row,)](
+        dq,
+        dq.stride(1),
+        dk,
+        dk.stride(1),
+        cos,
+        cos.stride(-2),
+        sin,
+        sin.stride(-2),
+        seq_len,
+        batch_size,
+        n_q_head,
+        n_kv_head,
+        head_dim,
+        pad_n_q_head,
+        pad_n_kv_head,
+        pad_hd,
+        BLOCK_SIZE=BLOCK_SIZE,
+        BACKWARD_PASS=True,
+    )
+    return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
 class LigerRopeFunction(torch.autograd.Function):
     """
     Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
@@ -138,50 +224,9 @@ class LigerRopeFunction(torch.autograd.Function):
         cos size: (1, seq_len, head_dim)
         sin size: (1, seq_len, head_dim)
         """
-
-        # transpose it back to the physical shape because Triton looks at the physical storage
-        # note: q and k are incontiguous before the transformation and will become contiguous after transpose
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-
-        batch_size, seq_len, n_q_head, head_dim = q.shape
-        n_kv_head = k.shape[2]
-        pad_hd = triton.next_power_of_2(head_dim)
-        pad_n_q_head = triton.next_power_of_2(n_q_head)
-        pad_n_kv_head = triton.next_power_of_2(n_kv_head)
-        BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
-
-        n_row = batch_size * seq_len
-
-        # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
-        q = q.contiguous()
-        k = k.contiguous()
-        cos = cos.contiguous()
-        sin = sin.contiguous()
-
-        _triton_rope[(n_row,)](
-            q,
-            q.stride(1),
-            k,
-            k.stride(1),
-            cos,
-            cos.stride(-2),
-            sin,
-            sin.stride(-2),
-            batch_size,
-            seq_len,
-            n_q_head,
-            n_kv_head,
-            head_dim,
-            pad_n_q_head,
-            pad_n_kv_head,
-            pad_hd,
-            BLOCK_SIZE=BLOCK_SIZE,
-            BACKWARD_PASS=False,
-        )
-
+        q, k, cos, sin = rope_forward(q, k, cos, sin)
         ctx.save_for_backward(cos, sin)
-        return q
+        return q, k
 
     def backward(ctx, dq, dk):
         """
@@ -192,43 +237,5 @@ class LigerRopeFunction(torch.autograd.Function):
         """
 
         cos, sin = ctx.saved_tensors
-
-        dq
-        dk = dk.transpose(1, 2)
-
-        batch_size, seq_len, n_q_head, head_dim = dq.shape
-        n_kv_head = dk.shape[2]
-        pad_hd = triton.next_power_of_2(head_dim)
-        pad_n_q_head = triton.next_power_of_2(n_q_head)
-        pad_n_kv_head = triton.next_power_of_2(n_kv_head)
-        BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
-
-        n_row = batch_size * seq_len
-
-        # ensure dq and dk are contiguous
-        dq = dq.contiguous()
-        dk = dk.contiguous()
-
-        # backward is similar to forward except swapping few ops
-        _triton_rope[(n_row,)](
-            dq,
-            dq.stride(1),
-            dk,
-            dk.stride(1),
-            cos,
-            cos.stride(-2),
-            sin,
-            sin.stride(-2),
-            batch_size,
-            seq_len,
-            n_q_head,
-            n_kv_head,
-            head_dim,
-            pad_n_q_head,
-            pad_n_kv_head,
-            pad_hd,
-            BLOCK_SIZE=BLOCK_SIZE,
-            BACKWARD_PASS=True,
-        )
-
-        return dq.transpose(1, 2), dk.transpose(1, 2), None, None, None, None
+        dq, dk = rope_backward(dq, dk, cos, sin)
+        return dq, dk, None, None, None, None
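On the RoPE side, the kernel launch now lives in the rope_forward/rope_backward helpers, and LigerRopeFunction returns the rotated q and k together. A minimal sketch of calling it is below; it assumes a CUDA device with Triton, cos/sin shaped (1, seq_len, head_dim) as stated in the docstring, and that the remaining forward arguments (e.g. position ids) keep their upstream defaults so they can be omitted.

import torch

from liger_kernel.ops.rope import LigerRopeFunction

batch, n_q_heads, n_kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64
device = "cuda"

q = torch.randn(batch, n_q_heads, seq_len, head_dim, device=device)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim, device=device)

# Build HF-style rotary tables of shape (1, seq_len, head_dim).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len, device=device).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None, :, :], emb.sin()[None, :, :]

# forward returns both rotated tensors; autograd routes the backward through rope_backward
q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 8, 16, 64]) torch.Size([2, 2, 16, 64])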