quack-kernels 0.2.2-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/rmsnorm.py
CHANGED
@@ -1,94 +1,54 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
-
+import math
+from typing import Optional, Tuple, Type
 from functools import partial
 
 import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
-from cutlass import Float32, Int32
-from cutlass import const_expr
-from cutlass.cute.runtime import from_dlpack
+from cutlass import Float32, Int32, const_expr
 
 import torch
 from torch import Tensor
 
 import quack.utils as utils
+import quack.copy_utils as copy_utils
+import quack.layout_utils as layout_utils
+from quack.compile_utils import make_fake_tensor as fake_tensor
 from quack.reduce import row_reduce
 from quack.reduction_base import ReductionBase
 from quack.cute_dsl_utils import torch2cute_dtype_map
 
 
 class RMSNorm(ReductionBase):
-    def __init__(self, dtype: cutlass.Numeric, N: int):
-        super().__init__(dtype, N, stage=1)
-        self.
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, is_layernorm: bool = False):
+        super().__init__(dtype, N, stage=2 if is_layernorm else 1)
+        self.is_layernorm = is_layernorm
+        self.reload_from = None if N <= (16384 if is_layernorm else 8192) else "smem"
         self.delay_w_load = False
 
-    def
-        """Calculate the number of threads per row for the RMSNorm kernel."""
+    def _threads_per_row(self):
        N = self.N
-
-
-
-
-        elif N <= 3072:
-            return 32
-        elif N <= 6144:
-            return 64
-        elif N <= 16384:
-            return 128
-        else:
-            return 256
+        for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
+            if N <= limit:
+                return threads
+        return 256
 
     def _set_cluster_n(self):
-        """
-        Set the number of clusters for the RMSNorm kernel.
-        Stored in self.cluster_n.
-        """
        N = self.N
-
        # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
        # Similarly cluster_n = 8 is faster for N=128k
        if const_expr(self.dtype.width == 16):
-
-            if N <= 16 * 1024:
-                cluster_n = 1
-            elif N <= 32 * 1024:
-                cluster_n = 2
-            elif N <= 64 * 1024:
-                cluster_n = 4
-            elif N <= 128 * 1024:
-                cluster_n = 8
-            else:
-                cluster_n = 16
+            thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
        else:
-
-
-
-
-
-
-                cluster_n = 4
-            elif N <= 256 * 1024:
-                cluster_n = 8
-            else:
-                cluster_n = 16
-
-        self.cluster_n = cluster_n
-
-    def _smem_size_in_bytes(self, tiler_mn, num_warps, dtype_res=None):
-        return (
-            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
-            + (
-                cute.size_in_bytes(dtype_res, cute.make_layout(tiler_mn))
-                if dtype_res is not None
-                else 0
-            )
-            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
-            + self.stage * (cutlass.Int64.width // 8)
-        )
+            thresholds = [(32 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
+        for limit, cluster in thresholds:
+            if N <= limit:
+                self.cluster_n = cluster
+                return
+        self.cluster_n = 16
 
     @cute.jit
     def __call__(
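Both `_threads_per_row` and `_set_cluster_n` above now walk a small (limit, value) table instead of an if/elif ladder. A minimal standalone sketch of that lookup pattern, using the thresholds visible in the added lines (the helper name `pick_by_threshold` is made up for illustration and is not part of quack):

# Hypothetical helper illustrating the threshold-table lookup used in 0.2.3.
def pick_by_threshold(n, table, default):
    for limit, value in table:
        if n <= limit:
            return value
    return default

# RMSNorm forward thread count per row, same table as _threads_per_row above
threads = pick_by_threshold(4096, [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)], 256)
assert threads == 64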
@@ -100,60 +60,32 @@ class RMSNorm(ReductionBase):
         mO: cute.Tensor,
         mResO: Optional[cute.Tensor],
         mRstd: Optional[cute.Tensor],
+        mMean: Optional[cute.Tensor],
+        eps: Float32,
         stream: cuda.CUstream,
-        eps: Float32 = 1e-6,
     ):
-        semistatic_shape = (*mX.shape[:-1], self.N)  # Set last dimension to be statically N
-        new_stride = lambda t: (
-            cute.assume(t.stride[0], divby=128 // t.element_type.width),
-            t.stride[1],
-        )
-        mX, mRes, mO, mResO = [
-            cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
-            if const_expr(t is not None)
-            else None
-            for t in (mX, mRes, mO, mResO)
-        ]
         assert mX.element_type == self.dtype
-        assert mO.element_type == self.dtype
         self._set_cluster_n()
         largest_dtype_width = const_expr(
-            max(
-                mX.element_type.width,
-                mRes.element_type.width if mRes is not None else 0,
-                mO.element_type.width,
-                mResO.element_type.width if mResO is not None else 0,
-            )
-        )
-        tiler_mn, tv_layout = self._get_tv_layout(
-            num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+            max(*(t.element_type.width for t in [mX, mRes, mW, mB, mO, mResO] if t is not None))
         )
-
-
-
-
-
-        )
-
-
-
-
-
-            mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
-        if const_expr(mRstd is not None):
-            mRstd_expanded_layout = cute.append(
-                mRstd.layout, cute.make_layout((self.N,), stride=(0,))
-            )
-            mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+        vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+        tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
+        num_threads = tiled_copy.size
+        mW, mB = [
+            layout_utils.expand(mT, dim=0, size=tiler_mn[0]) if const_expr(mT is not None) else None
+            for mT in (mW, mB)
+        ]
+        mRstd, mMean = [
+            layout_utils.expand(mT, dim=1, size=self.N) if const_expr(mT is not None) else None
+            for mT in (mRstd, mMean)
+        ]
         self.kernel(
-            mX, mW, mB, mRes, mO, mResO, mRstd,
+            mX, mW, mB, mRes, mO, mResO, mRstd, mMean, eps, tiler_mn, tiled_copy, threads_per_row
         ).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=
-            smem=self._smem_size_in_bytes(
-                tiler_mn, num_warps, dtype_res=mRes.element_type if mRes is not None else None
-            ),
+            cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
             stream=stream,
         )
 
@@ -167,24 +99,20 @@ class RMSNorm(ReductionBase):
         mO: cute.Tensor,
         mResO: Optional[cute.Tensor],
         mRstd: Optional[cute.Tensor],
-
-
+        mMean: Optional[cute.Tensor],
+        eps: Float32,
         tiler_mn: cute.Shape,
-
-
+        tiled_copy: cute.TiledCopy,
+        threads_per_row: cutlass.Constexpr[int],
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
-        if const_expr(self.cluster_n
-
-        else:
-            cluster_y = const_expr(0)
+        cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
+        tv_layout = tiled_copy.layout_tv_tiled
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type,
-            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
-            byte_alignment=16,
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
         )
         if const_expr(mRes is not None):
             sRes = smem.allocate_tensor(
@@ -197,34 +125,16 @@ class RMSNorm(ReductionBase):
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
-
-
-            for mT in (mX, mRes, mO, mResO)
+        gX, gRes, gO, gResO, gRstd, gMean, cX = [
+            cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) if mT is not None else None
+            for mT in (mX, mRes, mO, mResO, mRstd, mMean, idX)
         ]
-        gX, gRes, gO, gResO = [
-            cute.local_tile(mT, tiler_mn, (0, cluster_y)) if mT is not None else None
-            for mT in (mX, mRes, mO, mResO)
-        ]
-        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
         gW, gB = [
             cute.local_tile(mT, tiler_mn, (0, cluster_y)) if const_expr(mT is not None) else None
             for mT in (mW, mB)
         ]
-        gRstd = (
-            cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
-            if const_expr(mRstd is not None)
-            else None
-        )
 
-
-        num_copy_elems_X = tv_layout.shape[1][0]
-        copy_atom_load_X_async = utils.get_copy_atom(
-            mX.element_type, num_copy_elems_X, is_async=True
-        )
-        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
-            tidx
-        )
+        thr_copy_X = tiled_copy.get_slice(tidx)
 
         tXgW = thr_copy_X.partition_S(gW) if const_expr(mW is not None) else None
         tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
@@ -237,26 +147,27 @@ class RMSNorm(ReductionBase):
         if const_expr(mResO is not None):
             tXgResO = thr_copy_X.partition_D(gResO)
         tXrRstd = thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None
+        tXrMean = thr_copy_X.partition_D(gMean) if const_expr(mMean is not None) else None
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
         # allocate fragments for gmem->rmem
         tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
-        if const_expr(mW is not None):
-            tXrW.fill(0.0)
         tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
         tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
         if const_expr(mRes is not None):
             tXrRes = cute.make_fragment_like(tXgRes)
 
-        num_warps = cute.size(
+        num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
         self._initialize_cluster(tidx, mbar_ptr, num_warps)
 
-        is_even_N =
+        is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
         tXpX = (
-
+            copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+            if not is_even_N
+            else None
         )
-        # Each copy will use the same
-        copy = partial(
+        # Each copy will use the same predicate
+        copy = partial(copy_utils.copy, pred=tXpX)
 
         row = tXcX[0][0]
         if row < shape[0]:
@@ -265,7 +176,7 @@ class RMSNorm(ReductionBase):
                 copy(tXgRes, tXsRes, is_async=True)
         cute.arch.cp_async_commit_group()
 
-        if const_expr(not delay_w_load):
+        if const_expr(not self.delay_w_load):
             if const_expr(mW is not None):
                 copy(tXgW, tXrW)
             if const_expr(mB is not None):
@@ -283,17 +194,61 @@ class RMSNorm(ReductionBase):
             if row < shape[0]:
                 copy(tXrResO, tXgResO)
 
-
-
-
-
-
-
-
-
-
-
-
+        mean, rstd = None, None
+        if const_expr(self.is_layernorm):
+            # LayerNorm: compute mean first, then variance
+            sum_x = row_reduce(
+                x,
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
+                init_val=0.0,
+                hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+            )
+            mean = sum_x / shape[1]
+            if const_expr(mMean is not None):
+                # Only the thread corresponding to column 0 writes out the mean to gmem
+                if (
+                    tXcX[0][1] == 0
+                    and row < shape[0]
+                    and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+                ):
+                    tXrMean[0] = mean
+            if const_expr(self.reload_from == "smem"):
+                cute.autovec_copy(tXsX, tXrX)
+                x = tXrX.load().to(cute.Float32)
+                if const_expr(mRes is not None):
+                    cute.autovec_copy(tXsRes, tXrRes)
+                    x += tXrRes.load().to(cute.Float32)
+            elif const_expr(self.reload_from == "gmem"):
+                copy(tXgX, tXrX)
+                x = tXrX.load().to(cute.Float32)
+                if const_expr(mRes is not None):
+                    copy(tXgRes, tXrRes)
+                    x += tXrRes.load().to(cute.Float32)
+            sum_sq_x_sub_mean = row_reduce(
+                (x - mean) * (x - mean),
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 1],
+                mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
+                init_val=0.0,
+            )
+            rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
+        else:
+            # RMSNorm: compute sum of squares directly
+            mean = const_expr(0.0)
+            sum_sq_x = row_reduce(
+                x * x,
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr,
+                init_val=0.0,
+                hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+            )
+            rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
         if const_expr(mRstd is not None):
             # Only the thread corresponding to column 0 writes out the rstd to gmem
             if (
@@ -302,21 +257,24 @@ class RMSNorm(ReductionBase):
                 and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
             ):
                 tXrRstd[0] = rstd
-        if const_expr(delay_w_load):
+        if const_expr(self.delay_w_load):
            if const_expr(mW is not None):
                copy(tXgW, tXrW)
            if const_expr(mB is not None):
                copy(tXgB, tXrB)
-        if const_expr(reload_from == "smem" or reload_from == "gmem"):
-            if const_expr(reload_from == "smem"):
+        if const_expr(self.reload_from == "smem" or self.reload_from == "gmem"):
+            if const_expr(self.reload_from == "smem"):
                cute.autovec_copy(tXsX, tXrX)
+                if const_expr(mRes is not None):
+                    cute.autovec_copy(tXsRes, tXrRes)
            else:
                copy(tXgX, tXrX)
+                if const_expr(mRes is not None):
+                    copy(tXgRes, tXrRes)
            x = tXrX.load().to(cute.Float32)
            if const_expr(mRes is not None):
-                cute.autovec_copy(tXsRes, tXrRes)
                x += tXrRes.load().to(cute.Float32)
-        x_hat = x * rstd
+        x_hat = (x - mean) * rstd if const_expr(self.is_layernorm) else x * rstd
        y = x_hat
        if const_expr(mW is not None):
            y *= tXrW.load().to(cute.Float32)
@@ -329,10 +287,10 @@ class RMSNorm(ReductionBase):
 
 @torch.library.custom_op(
     "quack::_rmsnorm_fwd",
-    mutates_args=("out", "rstd", "residual_out"),
+    mutates_args=("out", "rstd", "mean", "residual_out"),
     device_types="cuda",
     # We need to specify the schema manually since we're mutating an optional tensor
-    schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(
+    schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor(a5!)? mean, Tensor? residual, Tensor(a7!)? residual_out, float eps=1e-6, bool is_layernorm=False) -> ()",
 )
 def _rmsnorm_fwd(
     x: Tensor,
@@ -340,102 +298,73 @@ def _rmsnorm_fwd(
     out: Tensor,
     bias: Optional[Tensor] = None,
     rstd: Optional[Tensor] = None,
+    mean: Optional[Tensor] = None,
     residual: Optional[Tensor] = None,
     residual_out: Optional[Tensor] = None,
     eps: float = 1e-6,
+    is_layernorm: bool = False,
 ) -> None:
-    """RMSNorm forward pass.
+    """RMSNorm/LayerNorm forward pass.
     Args:
         x: Input tensor of shape (M, N)
         weight: Optional weight tensor of shape (N,)
         eps: Small value for numerical stability
+        is_layernorm: If True, compute LayerNorm instead of RMSNorm
     Returns:
         Normalized output tensor of same shape as x
     """
-
-
-    assert x.dtype in
+    # Don't need to check is_cuda since torch.library ensures that
+    supported_types = {torch.float16, torch.bfloat16, torch.float32}
+    assert x.dtype in supported_types, "Unsupported dtype"
     if weight is not None:
-        assert weight.
-        assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
-        assert weight.is_cuda, "Weight tensor must be on CUDA device"
-        assert weight.dtype in [
-            torch.float32,
-            torch.bfloat16,
-            torch.float16,
-        ], "Weight must be float32, float16 or bfloat16"
+        assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16"
     if residual is not None:
-        assert residual.
-        assert residual.is_cuda
-        assert residual.dtype in [
-            torch.float16,
-            torch.bfloat16,
-            torch.float32,
-        ], "Residual must be float16, bfloat16, or float32"
+        assert residual.dtype in supported_types, "Residual must be float16, bfloat16, or float32"
 
     _, N = x.shape
-
-
-
-        from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
-    )
-    x_tensor, res_tensor, out_tensor, res_out_tensor = [
-        convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
+    dtype, out_dtype, weight_dtype, bias_dtype, res_dtype, res_out_dtype = [
+        torch2cute_dtype_map[t.dtype] if t is not None else None
+        for t in [x, out, weight, bias, residual, residual_out]
     ]
-    # handle weight divisibility based on weight dtype
-    if weight is not None:
-        weight_dtype = torch2cute_dtype_map[weight.dtype]
-        weight_tensor = utils.convert_from_dlpack(
-            weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
-        )
-    else:
-        weight_tensor = None
-    if bias is not None:
-        bias_dtype = torch2cute_dtype_map[bias.dtype]
-        bias_tensor = utils.convert_from_dlpack(
-            bias.detach(), leading_dim=0, divisibility=128 // bias_dtype.width
-        )
-    else:
-        bias_tensor = None
-    rstd_tensor = (
-        from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
-        if rstd is not None
-        else None
-    )
-    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
     compile_key = (
-        N,
         dtype,
-
-
-
-
+        out_dtype,
+        res_dtype,
+        weight_dtype,
+        bias_dtype,
+        res_out_dtype,
+        N,
         rstd is not None,
+        mean is not None,
+        is_layernorm,
     )
     if compile_key not in _rmsnorm_fwd.compile_cache:
-
+        batch_sym = cute.sym_int()
+        all_dtypes = [dtype, out_dtype, res_dtype, weight_dtype, bias_dtype, res_out_dtype]
+        div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
+        x_cute, out_cute, res_cute, res_out_cute = [
+            fake_tensor(dt, (batch_sym, N), div)
+            for dt in [dtype, out_dtype, res_dtype, res_out_dtype]
+        ]
+        weight_cute, bias_cute = [fake_tensor(dt, (N,), div) for dt in [weight_dtype, bias_dtype]]
+        rstd_cute = fake_tensor(Float32, (batch_sym,)) if rstd is not None else None
+        mean_cute = fake_tensor(Float32, (batch_sym,)) if mean is not None else None
         _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
-
-
-
-
-
-
-
-
-
-            eps,
+            RMSNorm(dtype, N, is_layernorm=is_layernorm),
+            x_cute,
+            weight_cute,
+            bias_cute,
+            res_cute,
+            out_cute,
+            res_out_cute,
+            rstd_cute,
+            mean_cute,
+            Float32(0),  # eps, just for compilation
+            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+            options="--enable-tvm-ffi",
         )
     _rmsnorm_fwd.compile_cache[compile_key](
-
-        weight_tensor,
-        bias_tensor,
-        res_tensor,
-        out_tensor,
-        res_out_tensor,
-        rstd_tensor,
-        current_stream,
-        eps,
+        x, weight, bias, residual, out, residual_out, rstd, mean, eps
     )
 
 
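The rewritten `_rmsnorm_fwd` keys its compile cache on the cute dtypes of every tensor and derives the assumed divisibility with `math.gcd`, as in the `div = math.gcd(N, *(128 // dt.width ...))` line above. A small worked example of that expression (standalone; the widths below are illustrative):

import math

# bf16 input/output (16-bit) plus an fp32 residual (32-bit), row length N = 8192
N = 8192
widths = [16, 16, 32]
div = math.gcd(N, *(128 // w for w in widths))
# 128 // 16 = 8 elements per 128-bit copy for bf16, 128 // 32 = 4 for fp32,
# so div = gcd(8192, 8, 8, 4) = 4: the widest vector size valid for every tensor.
print(div)  # 4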
@@ -466,7 +395,7 @@ def rmsnorm_fwd(
         )
     else:
         residual_out = None
-    _rmsnorm_fwd(x, weight, out, bias, rstd, residual, residual_out, eps
+    _rmsnorm_fwd(x, weight, out, bias, rstd, None, residual, residual_out, eps, False)
     # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
     if residual_out is None:
         residual_out = x
@@ -492,13 +421,19 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
     """Reference implementation for RMSNorm backward pass."""
     x_f32 = x.float()
     x_hat = x_f32 * rstd.unsqueeze(1)
-
+    if w is not None:
+        wdy = dout * w
+    else:
+        wdy = dout
     c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
     dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)
 
     # dL/dW
-
-
+    if w is not None:
+        dw = (dout * x_hat).sum(dim=0)
+        return dx.to(x.dtype), dw.to(w.dtype)
+    else:
+        return dx.to(x.dtype), None
 
 
 class RMSNormBackward(ReductionBase):
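`rmsnorm_bwd_ref` now also covers the weight-free case, returning `(dx, None)` when `w` is None. A hedged cross-check of the reference against autograd, assuming only the `rmsnorm_bwd_ref(x, w, dout, rstd, eps)` signature shown above and the `quack.rmsnorm` import path for this file:

import torch
from quack.rmsnorm import rmsnorm_bwd_ref  # assumed import path for this module

torch.manual_seed(0)
eps = 1e-6
x = torch.randn(4, 64, dtype=torch.float64, requires_grad=True)
w = torch.randn(64, dtype=torch.float64, requires_grad=True)
dout = torch.randn(4, 64, dtype=torch.float64)

# Hand-written RMSNorm forward so autograd produces the exact gradients
rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
(x * rstd * w).backward(dout)

dx_ref, dw_ref = rmsnorm_bwd_ref(x.detach(), w.detach(), dout, rstd.detach().squeeze(1), eps=eps)
print(torch.allclose(dx_ref, x.grad, rtol=1e-4, atol=1e-5))  # True
print(torch.allclose(dw_ref, w.grad, rtol=1e-4, atol=1e-5))  # True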
@@ -510,94 +445,57 @@ class RMSNormBackward(ReductionBase):
             # Not enough smem
             raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
 
-    def
+    def _num_threads(self):
         return 128 if self.N <= 4096 else 256
 
-    def
+    def _threads_per_row(self):
         N = self.N
-
-
-
-
-            16
-            if N <= 128
-            else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
-        )
-        )
+        for limit, threads in [(64, 8), (128, 16), (256, 32), (512, 64), (4096, 128)]:
+            if N <= limit:
+                return threads
+        return 256
 
     def _set_cluster_n(self):
         N = self.N
-
-
-
-
-
-        self.cluster_n = cluster_n
-
-    def _smem_size_in_bytes(self, tiler_mn, num_warps, do_dtype=None):
-        if do_dtype is None:
-            do_dtype = self.dtype
-        return (
-            # We need space for X and dO, and multiply by 2 due to double buffering
-            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
-            + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
-            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
-            + self.stage * (cutlass.Int64.width // 8) * 2  # mult 2 as we need 2 mbar per stage
-        )
+        for limit, cluster in [(8 * 1024, 1), (16 * 1024, 2), (32 * 1024, 4), (64 * 1024, 8)]:
+            if N <= limit:
+                self.cluster_n = cluster
+                return
+        self.cluster_n = 16
 
     @cute.jit
     def __call__(
         self,
         mX: cute.Tensor,
-        mW: cute.Tensor,
+        mW: Optional[cute.Tensor],
         mdO: cute.Tensor,
         mdResO: Optional[cute.Tensor],
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
-        mdW: cute.Tensor,
+        mdW: Optional[cute.Tensor],
         mdRes: Optional[cute.Tensor],
         mdB: Optional[cute.Tensor],
         sm_count: Int32,
         stream: cuda.CUstream,
     ):
-
-        new_stride = lambda t: (
-            cute.assume(t.stride[0], divby=128 // t.element_type.width),
-            t.stride[1],
-        )
-        mX, mdO, mdResO, mdX, mdRes = [
-            cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
-            if const_expr(t is not None)
-            else None
-            for t in (mX, mdO, mdResO, mdX, mdRes)
-        ]
+        assert mX.element_type == self.dtype
         self._set_cluster_n()
         largest_dtype_width = const_expr(
-            max(
-                mX.element_type.width,
-                mdO.element_type.width,
-                mdX.element_type.width,
-                mdResO.element_type.width if mdResO is not None else 0,
-                mdRes.element_type.width if mdRes is not None else 0,
-            )
+            max(*(t.element_type.width for t in [mX, mW, mdO, mdResO, mdX, mdRes] if t is not None))
         )
-
-
+        vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+        tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
+        num_threads = tiled_copy.size
+        mW = (
+            layout_utils.expand(mW, dim=0, size=tiler_mn[0]) if const_expr(mW is not None) else None
         )
-        num_threads = cute.size(tv_layout, mode=[0])
-        num_warps = num_threads // cute.arch.WARP_SIZE
-        if const_expr(mW is not None):
-            mW_expanded_layout = cute.prepend(
-                mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
-            )
-            mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
-
         num_blocks = sm_count
-        self.kernel(
+        self.kernel(
+            mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row
+        ).launch(
             grid=[num_blocks, self.cluster_n, 1],
             block=[num_threads, 1, 1],
             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
-            smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
             stream=stream,
         )
 
@@ -605,24 +503,23 @@ class RMSNormBackward(ReductionBase):
     def kernel(
         self,
         mX: cute.Tensor,
-        mW: cute.Tensor,
+        mW: Optional[cute.Tensor],
         mdO: cute.Tensor,
         mdResO: Optional[cute.Tensor],
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
-        mdW: cute.Tensor,
+        mdW: Optional[cute.Tensor],
         mdB: Optional[cute.Tensor],
         mdRes: Optional[cute.Tensor],
-        tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
+        tiled_copy: cute.TiledCopy,
+        threads_per_row: cutlass.Constexpr[int],
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx_start, _, _ = cute.arch.block_idx()
         gdim, _, _ = cute.arch.grid_dim()
-        if const_expr(self.cluster_n
-
-        else:
-            cluster_y = const_expr(0)
+        cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
+        tv_layout = tiled_copy.layout_tv_tiled
 
         shape = mX.shape
         M, N = shape[0], shape[1]
@@ -642,63 +539,20 @@ class RMSNormBackward(ReductionBase):
         else:
             mbar_full_ptr, mbar_empty_ptr = None, None
 
-
-        copy_atom_load_X = utils.get_copy_atom(mX.element_type, num_copy_elems_X, is_async=False)
-        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-        # Each copy will use the same number of elements as X
-        copy = partial(utils.copy, num_copy_elems=num_copy_elems_X)
-
-        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
-        tXgW = thr_copy_X.partition_S(gW)
-        tXrW = cute.make_fragment_like(tXgW)
-        # Need this, otherwise rW can have arbitrary values that changes the reduction
-        if not is_even_N:
-            tXrW.fill(0.0)
-
-        gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
-        tXpW = (
-            utils.predicate_k(thr_copy_X.partition_S(gW_coord), limit=shape[1])
-            if not is_even_N
-            else None
-        )
-        copy(tXgW, tXrW, pred=tXpW)
-        weight = tXrW.load().to(cute.Float32)
-
-        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-
-        self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
-
-        dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
-        tXpdW = (
-            utils.predicate_k(thr_copy_X.partition_S(dw_coord), limit=shape[1])
-            if not is_even_N
-            else None
-        )
-        if const_expr(mdB is not None):
-            db_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
-            tXpdB = (
-                utils.predicate_k(thr_copy_X.partition_S(db_coord), limit=shape[1])
-                if not is_even_N
-                else None
-            )
-
-        gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
-        tXgdW = thr_copy_X.partition_S(gdW)
-        # Always compute partial weight gradients in fp32
-        tXrdW = cute.make_fragment_like(tXgdW, Float32)
-
-        gdB = (
-            cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y))
-            if const_expr(mdB is not None)
-            else None
-        )
-        tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
-        tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None
+        thr_copy_X = tiled_copy.get_slice(tidx)
 
         gX, gdO, gdResO, gdX, gdRes, cX = [
             cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None
             for mT in (mX, mdO, mdResO, mdX, mdRes, idX)
         ]
+        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) if mW is not None else None
+        gdW, gdB = [
+            cute.local_tile(mT, (1, tiler_mn[1]), (bidx_start, cluster_y))
+            if const_expr(mT is not None)
+            else None
+            for mT in (mdW, mdB)
+        ]
+
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
         tXgdO = thr_copy_X.partition_S(gdO)
@@ -709,12 +563,6 @@ class RMSNormBackward(ReductionBase):
         if const_expr(mdRes is not None):
             tXgdRes = thr_copy_X.partition_D(gdRes)
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
-        # This doesn't change across iterations
-        tXpX = (
-            utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
-            if not is_even_N
-            else None
-        )
 
         tXrX, tXrdO, tXrdX = [
             cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
@@ -726,25 +574,57 @@ class RMSNormBackward(ReductionBase):
         if const_expr(mdRes is not None):
             tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
 
+        # This doesn't change across iterations
+        tXpX = (
+            None
+            if is_even_N
+            else copy_utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+        )
+        # Each copy will use the same number of elements as X
+        copy = partial(copy_utils.copy, pred=tXpX)
+
+        tXgdW, tXrdW = None, None
+        tXgdB, tXrdB = None, None
+        if const_expr(mdW is not None):
+            tXgdW = thr_copy_X.partition_S(gdW)
+            # Always compute partial weight gradients in fp32
+            tXrdW = cute.make_fragment_like(tXgdW, Float32)
+        if const_expr(mdB is not None):
+            tXgdB = thr_copy_X.partition_S(gdB)
+            # Always compute partial bias gradients in fp32
+            tXrdB = cute.make_fragment_like(tXgdB, Float32)
+
+        num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
+
+        self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
+
+        tXrW = None
+        if const_expr(mW is not None):
+            tXgW = thr_copy_X.partition_S(gW)
+            tXrW = cute.make_fragment_like(tXgW)
+            # Need this, otherwise rW can have arbitrary values that changes the reduction
+            if const_expr(not is_even_N):
+                tXrW.fill(0.0)
+            copy(tXgW, tXrW)
+
         # Prefetch the first batch
         row = tXcX[None, None, None, bidx_start][0][0]
         if row < M:
-
-
-
-
-
-
-
-
-            utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
+            copy(tXgX[None, None, None, bidx_start], tXsX[None, None, None, 0], is_async=True)
+            copy(tXgdO[None, None, None, bidx_start], tXsdO[None, None, None, 0], is_async=True)
+        else:
+            if const_expr(tiler_mn[0] > 1):
+                # Fill with zero, otherwise smem will be uninitialized, and we could read this back
+                # later into registers, causing wrong dW.
+                utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+                utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
         cute.arch.cp_async_commit_group()
 
         if const_expr(self.cluster_n > 1):
             cute.arch.cluster_wait()
 
-
-
+        if const_expr(mdW is not None):
+            tXrdW.fill(0.0)
         if const_expr(mdB is not None):
             tXrdB.fill(0.0)
         stage = Int32(0)
@@ -753,29 +633,31 @@ class RMSNormBackward(ReductionBase):
         for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
             row = tXcX[None, None, None, bidx][0][0]
             if row + gdim * tiler_mn[0] < M:  # Prefetch the next batch
-
-
-                copy(tXgX_cur, tXsX[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
-                copy(tXgdO_cur, tXsdO[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
-            elif tiler_mn[0] > 1:
-                utils.fill_oob(
+                copy(
+                    tXgX[None, None, None, bidx + gdim],
                     tXsX[None, None, None, stage ^ 1],
-
-                    fill_value=mX.element_type.zero,
+                    is_async=True,
                 )
-
+                copy(
+                    tXgdO[None, None, None, bidx + gdim],
                     tXsdO[None, None, None, stage ^ 1],
-
-                    fill_value=mdO.element_type.zero,
+                    is_async=True,
                 )
+            else:
+                if const_expr(tiler_mn[0] > 1):
+                    utils.fill_oob(
+                        tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
+                    )
+                    utils.fill_oob(
+                        tXsdO[None, None, None, stage ^ 1], None, fill_value=mdO.element_type.zero
+                    )
             cute.arch.cp_async_commit_group()
             rstd = cutlass.Float.zero
             if row < M or tiler_mn[0] == 1:
                 rstd = mRstd[row]
             if const_expr(mdResO is not None):
-                tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
                 if row < M or tiler_mn[0] == 1:
-                    copy(
+                    copy(tXgdResO[None, None, None, bidx], tXrdResO)
                 elif tiler_mn[0] > 1:
                     tXrdResO.fill(0.0)
             cute.arch.cp_async_wait_group(1)
@@ -783,10 +665,10 @@ class RMSNormBackward(ReductionBase):
             x = tXrX.load().to(cute.Float32)
             cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
             dout = tXrdO.load().to(cute.Float32)
-            if const_expr(mdResO is not None):
-                dout += tXrdResO.load().to(cute.Float32)
             x_hat = x * rstd
-            wdy = dout
+            wdy = dout
+            if const_expr(mW is not None):
+                wdy *= tXrW.load().to(Float32)
             if const_expr(self.cluster_n > 1):
                 cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
             mean_xhat_wdy = (
@@ -803,6 +685,10 @@ class RMSNormBackward(ReductionBase):
             )
 
             if const_expr(self.cluster_n > 1):
+                # Need this fence since the STAS from the producer is using the async proxy.
+                cute.arch.fence_proxy(
+                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                )
                 # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
                 # Requires adjusting the thread_count when initializing the mbar
                 cute.arch.sync_warp()
@@ -815,22 +701,22 @@ class RMSNormBackward(ReductionBase):
             if const_expr(self.reload_wdy == "smem"):
                 cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
                 dout = tXrdO.load().to(cute.Float32)
-
-
-
+                wdy = dout
+                if const_expr(mW is not None):
+                    wdy *= tXrW.load().to(Float32)
 
             dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+            if const_expr(mdResO is not None):
+                dx += tXrdResO.load().to(cute.Float32)
             tXrdX.store(dx.to(tXrdX.element_type))
             if row < M or tiler_mn[0] == 1:
-
-                copy(tXrdX, tXgdX_cur, pred=tXpX)
+                copy(tXrdX, tXgdX[None, None, None, bidx])
             if const_expr(mdRes is not None):
                 tXrdRes.store(dx.to(tXrdRes.element_type))
-                tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
                 if row < M or tiler_mn[0] == 1:
-                    copy(tXrdRes,
-
-
+                    copy(tXrdRes, tXgdRes[None, None, None, bidx])
+            if const_expr(mdW is not None):
+                tXrdW.store(tXrdW.load() + dout * x_hat)
             if const_expr(mdB is not None):
                 tXrdB.store(tXrdB.load() + dout)
 
@@ -839,29 +725,29 @@ class RMSNormBackward(ReductionBase):
             consumer_phase ^= 1
             producer_phase ^= 1
 
-        if const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
-            cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
-
         if const_expr(tiler_mn[0] > 1):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if const_expr(mdW is not None):
+                # reduction of dw_partial within the same threadblock
+                sdW = cute.make_tensor(
+                    cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+                    cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+                )
+                tXsdW = thr_copy_X.partition_D(sdW)
+                cute.arch.barrier()
+                row = tXcX[None, None, None, 0][0][0]
+                if row > 0:
+                    cute.autovec_copy(tXrdW, tXsdW)
+                cute.arch.barrier()
+                if row == 0:
+                    for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
+                        tXrdW_other = cute.make_fragment_like(tXrdW)
+                        tXsdW_other = cute.make_tensor(
+                            tXsdW.iterator + i * sdW.stride[0], tXsdW.layout
+                        )
+                        cute.autovec_copy(tXsdW_other, tXrdW_other)
+                        tXrdW.store(tXrdW.load() + tXrdW_other.load())
+                    copy(tXrdW, tXgdW)
+                cute.arch.barrier()
             if const_expr(mdB is not None):
                 sdB = cute.make_tensor(
                     cute.recast_ptr(sX.iterator, dtype=cute.Float32),
@@ -881,12 +767,21 @@ class RMSNormBackward(ReductionBase):
                         )
                         cute.autovec_copy(tXsdB_other, tXrdB_other)
                         tXrdB.store(tXrdB.load() + tXrdB_other.load())
-                    copy(tXrdB, tXgdB
+                    copy(tXrdB, tXgdB)
         else:
             # dw is already in fp32, so we can directly copy to global memory
-
+            if const_expr(mdW is not None):
+                copy(tXrdW, tXgdW)
             if const_expr(mdB is not None):
-                copy(tXrdB, tXgdB
+                copy(tXrdB, tXgdB)
+
+        if const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
+            # Assume state contains that next useful buffer
+            # So we only need to advance to num_stages - 1 times to last used buffer
+            stage ^= 1
+            if stage == 0:
+                producer_phase ^= 1
+            cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
 
 
 def _get_sm_count(N: int, device: torch.device) -> int:
@@ -911,120 +806,103 @@ def _get_sm_count(N: int, device: torch.device) -> int:
     mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
     device_types="cuda",
     # We need to specify the schema manually since we're mutating an optional tensor
-    schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!) dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual) -> ()",
+    schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count) -> ()",
 )
 def _rmsnorm_bwd(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor],
     dout: Tensor,
     rstd: Tensor,
     dx: Tensor,
-    dw_partial: Tensor,
+    dw_partial: Optional[Tensor],
     db_partial: Optional[Tensor] = None,
     dresidual_out: Optional[Tensor] = None,
     dresidual: Optional[Tensor] = None,
+    sm_count: Optional[int] = None,
 ) -> None:
     """RMSNorm backward pass.
     Args:
         x: Input tensor of shape (M, N)
-        weight:
+        weight: Optional weight tensor of shape (N,)
         dout: Upstream gradients tensor of shape (M, N)
         rstd: Reciprocal standard deviation tensor of shape (M,)
     Returns:
         Tuple of (dx, dw) where:
        - dx: Input gradients tensor of same shape as x
-        - dw: Weight gradients tensor of same shape as weight
+        - dw: Weight gradients tensor of same shape as weight (or None if weight is None)
     """
     assert x.dim() == 2, "Input must be 2D"
-    assert
-
-    assert x.
-
-
-
-
-
-    ], "Weight must be float32, float16 or bfloat16"
+    assert x.is_cuda, "Input tensor must be on CUDA device"
+    supported_types = {torch.float16, torch.bfloat16, torch.float32}
+    assert x.dtype in supported_types, "Unsupported dtype"
+    if weight is not None:
+        assert weight.dim() == 1, "Weight must be 1D"
+        assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+        assert weight.is_cuda, "Weight tensor must be on CUDA device"
+        assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16"
     if dresidual_out is not None:
         assert dresidual_out.shape == x.shape
         assert dresidual_out.is_cuda
-        assert dresidual_out.dtype in
-
-
-            torch.float32,
-        ], "Residual must be float16, bfloat16, or float32"
+        assert dresidual_out.dtype in supported_types, (
+            "Residual must be float16, bfloat16, or float32"
+        )
     if dresidual is not None:
         assert dresidual.shape == x.shape
         assert dresidual.is_cuda
-        assert dresidual.dtype in
-            torch.float16,
-            torch.bfloat16,
-            torch.float32,
-        ], "Residual must be float16, bfloat16, or float32"
+        assert dresidual.dtype in supported_types, "Residual must be float16, bfloat16, or float32"
 
     N = x.size(1)
-
-
-
-
-
-
-
-        for t in (x, dout, dresidual_out, dx, dresidual)
+    if dw_partial is None and db_partial is None:
+        assert sm_count is not None
+    else:
+        sm_count = dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0]
+    dtype, dout_dtype, dx_dtype, weight_dtype, dres_dtype, dres_out_dtype = [
+        torch2cute_dtype_map[t.dtype] if t is not None else None
+        for t in [x, dout, dx, weight, dresidual, dresidual_out]
     ]
-    # Handle weight div based on weight dtype
-    weight_dtype = torch2cute_dtype_map[weight.dtype]
-    weight_tensor = utils.convert_from_dlpack(
-        weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
-    )
-
-    dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
-    db_partial_tensor = (
-        from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
-        if db_partial is not None
-        else None
-    )
-    rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
-
-    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
     compile_key = (
         N,
-
-
-
-
-
+        dtype,
+        dout_dtype,
+        dx_dtype,
+        weight_dtype,
+        db_partial is not None,
+        dres_dtype,
+        dres_out_dtype,
     )
     if compile_key not in _rmsnorm_bwd.compile_cache:
-
+        batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int()
+        all_dtypes = [dtype, dout_dtype, dx_dtype, dres_dtype, dres_out_dtype]
+        div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
+        x_cute, dout_cute, dx_cute, dres_out_cute, dres_cute = [
+            fake_tensor(dt, (batch_sym, N), div)
+            for dt in [dtype, dout_dtype, dx_dtype, dres_out_dtype, dres_dtype]
+        ]
+        weight_cute = fake_tensor(weight_dtype, (N,), div)
+        rstd_cute = fake_tensor(Float32, (batch_sym,))
+        dw_partial_cute = (
+            fake_tensor(Float32, (batch_partial_sym, N), div) if dw_partial is not None else None
+        )
+        db_partial_cute = (
+            fake_tensor(Float32, (batch_partial_sym, N), div) if db_partial is not None else None
+        )
         _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
-
-
-
-
-
-
-
-
-
-
+            RMSNormBackward(dtype, N),
+            x_cute,
+            weight_cute,
+            dout_cute,
+            dres_out_cute,
+            rstd_cute,
+            dx_cute,
+            dw_partial_cute,
+            dres_cute,
+            db_partial_cute,
             sm_count,
-
+            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+            options="--enable-tvm-ffi",
         )
-
     _rmsnorm_bwd.compile_cache[compile_key](
-
-        weight_tensor,
-        dout_tensor,
-        dres_out_tensor,
-        rstd_tensor,
-        dx_tensor,
-        dw_partial_tensor,
-        dres_tensor,
-        db_partial_tensor,
-        sm_count,
-        current_stream,
+        x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count
     )
 
 
@@ -1033,30 +911,37 @@ _rmsnorm_bwd.compile_cache = {}
 
 def rmsnorm_bwd(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor],
     dout: Tensor,
     rstd: Tensor,
     dresidual_out: Optional[Tensor] = None,  # grad wrt residual_out
     has_bias: bool = False,
-
+    has_residual: bool = False,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
     device = x.device
     N = x.size(1)
-    sm_count = _get_sm_count(N, device)
     dx = torch.empty_like(x)
-
     if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
         dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
     else:
         dresidual = None
-
-
+    sm_count = _get_sm_count(N, device)
+    if weight is not None:
+        # Always store partial gradients in fp32 for numerical accuracy
+        dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+    else:
+        dw_partial = None
     db_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
-
+
+    _rmsnorm_bwd(
+        x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, sm_count
+    )
+
     # we have summed the partial gradients in fp32, now we convert back to the weight dtype
-    dw = dw_partial.sum(dim=0).to(weight.dtype)
+    dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None
     db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
     # dresidual is the same as dx in this case
-    if
+    if has_residual and dresidual is None:
         dresidual = dx
     return dx, dw, db, dresidual
 
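With the weight made optional end to end, `rmsnorm_bwd` above returns `dw=None` when called without a weight. A hedged usage sketch based on the signature in this hunk (assumes a CUDA device and the `quack.rmsnorm` import path; not a verbatim excerpt from the package):

import torch
from quack.rmsnorm import rmsnorm_bwd  # assumed import path for this module

x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
dout = torch.randn_like(x)
rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1) + 1e-6)

# weight=None: the kernel skips the dW partials, and dw/db come back as None
dx, dw, db, dres = rmsnorm_bwd(x, None, dout, rstd)
assert dx.shape == x.shape and dw is None and db is None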
@@ -1104,7 +989,6 @@ class RMSNormFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dout, *args):
         x, weight, rstd = ctx.saved_tensors
-        assert weight is not None, "RMSNorm backward doesn't support weight=None yet"
         has_bias = ctx.has_bias
         if ctx.prenorm and ctx.residual_dtype is not None:
             dresidual_out = args[0]
@@ -1114,11 +998,16 @@ class RMSNormFunction(torch.autograd.Function):
         x_shape_og = ctx.x_shape_og
         # Reshape dout to match the flattened shape used in forward
         dout = dout.view(-1, dout.shape[-1])
-
-
+        dx, dw, db, dresidual = rmsnorm_bwd(
+            x,
+            weight,
+            dout,
+            rstd,
+            dresidual_out,
+            has_bias,
+            has_residual=ctx.residual_dtype is not None,
+        )
         dx = dx.view(x_shape_og)
-        if dresidual_out is not None:
-            dresidual_out = dresidual_out.reshape(x_shape_og)
         if dresidual is not None:
             dresidual = dresidual.reshape(x_shape_og)
 
@@ -1148,7 +1037,7 @@ def rmsnorm(
     return RMSNormFunction.apply(x, weight, bias, residual, out_dtype, residual_dtype, eps, prenorm)
 
 
-class QuackRMSNorm(torch.nn.
+class QuackRMSNorm(torch.nn.RMSNorm):
     """RMSNorm module that behaves like torch.nn.RMSNorm.
 
     This class provides a drop-in replacement for torch.nn.RMSNorm that uses
@@ -1163,10 +1052,10 @@ class QuackRMSNorm(torch.nn.Module):
         eps (float): A small constant for numerical stability
     """
 
-    def __init__(
-
-
-
+    def __init__(
+        self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True, device=None, dtype=None
+    ):
+        super().__init__(dim, eps, elementwise_affine, device=device, dtype=dtype)
 
     def forward(self, x: Tensor) -> Tensor:
         """Apply RMSNorm to the input tensor.
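`QuackRMSNorm` now subclasses `torch.nn.RMSNorm` directly, so it is constructed exactly like the stock module. A short usage sketch matching the constructor added above (assumes a CUDA device):

import torch
from quack.rmsnorm import QuackRMSNorm  # assumed import path for this module

norm = QuackRMSNorm(4096, eps=1e-6).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
y = norm(x)  # forward() dispatches to the quack rmsnorm kernel
print(y.shape)  # torch.Size([8, 4096])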
@@ -1179,6 +1068,67 @@ class QuackRMSNorm(torch.nn.Module):
         """
         return rmsnorm(x, self.weight, eps=self.eps)
 
-
-
-
+
+def layernorm_fwd(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    eps: float = 1e-6,
+    return_rstd: bool = False,
+    return_mean: bool = False,
+):
+    """LayerNorm forward pass using the unified RMSNorm/LayerNorm kernel.
+
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,). Must be float32.
+        bias: Optional bias tensor of shape (N,). Must be float32.
+        eps: Small value for numerical stability
+        return_rstd: Whether to return the reciprocal standard deviation
+        return_mean: Whether to return the mean
+
+    Returns:
+        Normalized output tensor of same shape as x
+        If return_rstd is True, also returns rstd tensor of shape (M,)
+        If return_mean is True, also returns mean tensor of shape (M,)
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert weight.dim() == 1, "Weight must be 1D"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert weight.dtype == torch.float32, "Weight must be float32"
+    if bias is not None:
+        assert bias.dim() == 1, "Bias must be 1D"
+        assert bias.dtype == torch.float32, "Bias must be float32"
+
+    M, N = x.shape
+    device = x.device
+    out = torch.empty_like(x)
+    rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
+    mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
+
+    _rmsnorm_fwd(x, weight, out, bias, rstd, mean, None, None, eps, True)
+
+    if return_rstd and return_mean:
+        return out, rstd, mean
+    elif return_rstd:
+        return out, rstd
+    elif return_mean:
+        return out, mean
+    return out
+
+
+def layernorm_ref(x: Tensor, w: Tensor, eps: float = 1e-6) -> Tensor:
+    """Reference implementation for LayerNorm."""
+    x_f32 = x.float()
+    return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
+
+
+def layernorm_rstd_ref(x: torch.Tensor, eps: float = 1e-6):
+    x_f32 = x.float()
+    mean = x_f32.mean(dim=-1, keepdim=True)
+    var = ((x_f32 - mean) ** 2).mean(dim=-1)
+    return 1.0 / torch.sqrt(var + eps)
+
+
+def layernorm_mean_ref(x: torch.Tensor) -> torch.Tensor:
+    return x.float().mean(dim=-1)