quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
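Notable in this list: quack/layernorm.py is deleted while quack/rmsnorm.py gains a LayerNorm path (see the rmsnorm.py hunks below), and the dense_gemm modules are renamed to gemm_sm90.py/gemm_sm100.py. A hedged sketch of what the merged norm entry point looks like, based only on the _rmsnorm_fwd schema shown further down; shapes and dtypes are illustrative assumptions:

import torch
from quack.rmsnorm import _rmsnorm_fwd  # path as listed above

x = torch.randn(16, 1024, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)
rstd = torch.empty(16, device="cuda", dtype=torch.float32)
mean = torch.empty(16, device="cuda", dtype=torch.float32)
# LayerNorm is now a flag on the RMSNorm kernel rather than a separate module.
_rmsnorm_fwd(x, None, out, rstd=rstd, mean=mean, is_layernorm=True)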
quack/rmsnorm.py
CHANGED
@@ -1,156 +1,91 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
-
+import math
+from typing import Optional, Tuple, Type
 from functools import partial
 
 import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
-from cutlass import Float32, Int32
-from cutlass import const_expr
-from cutlass.cute.runtime import from_dlpack
+from cutlass import Float32, Int32, const_expr
 
 import torch
 from torch import Tensor
 
 import quack.utils as utils
+import quack.copy_utils as copy_utils
+import quack.layout_utils as layout_utils
+from quack.compile_utils import make_fake_tensor as fake_tensor
 from quack.reduce import row_reduce
 from quack.reduction_base import ReductionBase
 from quack.cute_dsl_utils import torch2cute_dtype_map
 
 
 class RMSNorm(ReductionBase):
-    def __init__(self, dtype: cutlass.Numeric, N: int):
-        super().__init__(dtype, N, stage=1)
-        self.
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, is_layernorm: bool = False):
+        super().__init__(dtype, N, stage=2 if is_layernorm else 1)
+        self.is_layernorm = is_layernorm
+        self.reload_from = None if N <= (16384 if is_layernorm else 8192) else "smem"
         self.delay_w_load = False
 
-    def
-        """Calculate the number of threads per row for the RMSNorm kernel."""
+    def _threads_per_row(self):
         N = self.N
-
-
-
-
-        elif N <= 3072:
-            return 32
-        elif N <= 6144:
-            return 64
-        elif N <= 16384:
-            return 128
-        else:
-            return 256
+        for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
+            if N <= limit:
+                return threads
+        return 256
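
The rewrite above replaces an if/elif ladder with a threshold table. A minimal standalone sketch of the same pattern, with the (limit, threads) pairs copied from the hunk:

def threads_per_row(N: int) -> int:
    # The first pair whose limit is >= N wins; 256 is the fallback.
    for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
        if N <= limit:
            return threads
    return 256

assert threads_per_row(64) == 8
assert threads_per_row(4096) == 64
assert threads_per_row(1 << 20) == 256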
 
     def _set_cluster_n(self):
-        """
-        Set the number of clusters for the RMSNorm kernel.
-        Stored in self.cluster_n.
-        """
         N = self.N
-
         # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
         # Similarly cluster_n = 8 is faster for N=128k
         if const_expr(self.dtype.width == 16):
-
-            if N <= 16 * 1024:
-                cluster_n = 1
-            elif N <= 32 * 1024:
-                cluster_n = 2
-            elif N <= 64 * 1024:
-                cluster_n = 4
-            elif N <= 128 * 1024:
-                cluster_n = 8
-            else:
-                cluster_n = 16
+            thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
         else:
-
-
-
-
-
-
-                cluster_n = 4
-            elif N <= 256 * 1024:
-                cluster_n = 8
-            else:
-                cluster_n = 16
-
-        self.cluster_n = cluster_n
-
-    def _smem_size_in_bytes(self, tiler_mn, num_warps, dtype_res=None):
-        return (
-            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
-            + (
-                cute.size_in_bytes(dtype_res, cute.make_layout(tiler_mn))
-                if dtype_res is not None
-                else 0
-            )
-            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
-            + self.stage * (cutlass.Int64.width // 8)
-        )
+            thresholds = [(32 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
+        for limit, cluster in thresholds:
+            if N <= limit:
+                self.cluster_n = cluster
+                return
+        self.cluster_n = 16
 
     @cute.jit
     def __call__(
         self,
         mX: cute.Tensor,
-        mW: cute.Tensor,
+        mW: Optional[cute.Tensor],
         mB: Optional[cute.Tensor],
         mRes: Optional[cute.Tensor],
         mO: cute.Tensor,
         mResO: Optional[cute.Tensor],
         mRstd: Optional[cute.Tensor],
+        mMean: Optional[cute.Tensor],
+        eps: Float32,
         stream: cuda.CUstream,
-        eps: Float32 = 1e-6,
     ):
-        semistatic_shape = (*mX.shape[:-1], self.N)  # Set last dimension to be statically N
-        new_stride = lambda t: (
-            cute.assume(t.stride[0], divby=128 // t.element_type.width),
-            t.stride[1],
-        )
-        mX, mRes, mO, mResO = [
-            cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
-            if const_expr(t is not None)
-            else None
-            for t in (mX, mRes, mO, mResO)
-        ]
         assert mX.element_type == self.dtype
-        assert mO.element_type == self.dtype
         self._set_cluster_n()
         largest_dtype_width = const_expr(
-            max(
-                mX.element_type.width,
-                mRes.element_type.width if mRes is not None else 0,
-                mO.element_type.width,
-                mResO.element_type.width if mResO is not None else 0,
-            )
-        )
-        tiler_mn, tv_layout = self._get_tv_layout(
-            num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+            max(*(t.element_type.width for t in [mX, mRes, mW, mB, mO, mResO] if t is not None))
         )
-
-
-
-        mW =
-
-
-
-
-
-
-            mRstd.layout, cute.make_layout((self.N,), stride=(0,))
-        )
-        mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+        vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+        tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
+        num_threads = tiled_copy.size
+        mW, mB = [
+            layout_utils.expand(mT, dim=0, size=tiler_mn[0]) if const_expr(mT is not None) else None
+            for mT in (mW, mB)
+        ]
+        mRstd, mMean = [
+            layout_utils.expand(mT, dim=1, size=self.N) if const_expr(mT is not None) else None
+            for mT in (mRstd, mMean)
+        ]
         self.kernel(
-            mX, mW, mB, mRes, mO, mResO, mRstd,
+            mX, mW, mB, mRes, mO, mResO, mRstd, mMean, eps, tiler_mn, tiled_copy, threads_per_row
         ).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=
-            smem=self._smem_size_in_bytes(
-                tiler_mn, num_warps, dtype_res=mRes.element_type if mRes is not None else None
-            ),
+            cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
             stream=stream,
         )
 
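layout_utils.expand (new in 0.2.3) is used above to give the 1-D weight/bias a broadcast row dimension and rstd/mean a broadcast column dimension. Its exact semantics are not shown in this diff; a rough NumPy analogue of the stride-0 broadcast it appears to perform:

import numpy as np

# Assumed behaviour of layout_utils.expand(t, dim=0, size=M): add a size-M
# dimension with stride 0, so every row aliases the same underlying data.
w = np.arange(4, dtype=np.float32)  # shape (N,) = (4,)
w2d = np.lib.stride_tricks.as_strided(w, shape=(8, 4), strides=(0, w.strides[0]))
assert w2d.shape == (8, 4)
assert (w2d == w).all()  # all 8 rows alias the same weights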
@@ -158,30 +93,26 @@ class RMSNorm(ReductionBase):
     def kernel(
         self,
         mX: cute.Tensor,
-        mW: cute.Tensor,
+        mW: Optional[cute.Tensor],
         mB: Optional[cute.Tensor],
         mRes: Optional[cute.Tensor],
         mO: cute.Tensor,
         mResO: Optional[cute.Tensor],
         mRstd: Optional[cute.Tensor],
-
-
+        mMean: Optional[cute.Tensor],
+        eps: Float32,
         tiler_mn: cute.Shape,
-
-
+        tiled_copy: cute.TiledCopy,
+        threads_per_row: cutlass.Constexpr[int],
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
-        if const_expr(self.cluster_n
-
-        else:
-            cluster_y = const_expr(0)
+        cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
+        tv_layout = tiled_copy.layout_tv_tiled
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type,
-            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
-            byte_alignment=16,
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
         )
         if const_expr(mRes is not None):
             sRes = smem.allocate_tensor(
@@ -194,73 +125,18 @@ class RMSNorm(ReductionBase):
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
-
-
-            for mT in (mX, mRes, mO, mResO)
+        gX, gRes, gO, gResO, gRstd, gMean, cX = [
+            cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) if mT is not None else None
+            for mT in (mX, mRes, mO, mResO, mRstd, mMean, idX)
         ]
-
-            cute.local_tile(mT, tiler_mn, (0, cluster_y)) if mT is not None else None
-            for mT in (
+        gW, gB = [
+            cute.local_tile(mT, tiler_mn, (0, cluster_y)) if const_expr(mT is not None) else None
+            for mT in (mW, mB)
         ]
-        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
-        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
-        gB = cute.local_tile(mB, tiler_mn, (0, cluster_y)) if const_expr(mB is not None) else None
-        gRstd = (
-            cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
-            if const_expr(mRstd is not None)
-            else None
-        )
 
-
-        num_copy_elems_X = tv_layout.shape[1][0]
-        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
-        copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
-        )
-        num_bits_per_copy_B = (
-            cutlass.const_expr(min(128, num_copy_elems_X * mB.element_type.width))
-            if const_expr(mB is not None)
-            else 0
-        )
-        copy_atom_load_B = (
-            cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
-            )
-            if const_expr(mB is not None)
-            else None
-        )
-        if const_expr(mRes is not None):
-            num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
-            copy_atom_load_Res_async = cute.make_copy_atom(
-                cute.nvgpu.cpasync.CopyG2SOp(),
-                mRes.element_type,
-                num_bits_per_copy=num_copy_bits_Res,
-            )
-        num_copy_bits_O = const_expr(min(128, num_copy_elems_X * mO.element_type.width))
-        copy_atom_store_O = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_copy_bits_O
-        )
-        if const_expr(mResO is not None):
-            num_copy_bits_ResO = const_expr(min(128, num_copy_elems_X * mResO.element_type.width))
-            copy_atom_store_ResO = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(),
-                mResO.element_type,
-                num_bits_per_copy=num_copy_bits_ResO,
-            )
+        thr_copy_X = tiled_copy.get_slice(tidx)
 
-
-            tidx
-        )
-
-        tXgW = thr_copy_X.partition_S(gW)
+        tXgW = thr_copy_X.partition_S(gW) if const_expr(mW is not None) else None
         tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
@@ -271,34 +147,40 @@ class RMSNorm(ReductionBase):
         if const_expr(mResO is not None):
             tXgResO = thr_copy_X.partition_D(gResO)
         tXrRstd = thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None
+        tXrMean = thr_copy_X.partition_D(gMean) if const_expr(mMean is not None) else None
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
         # allocate fragments for gmem->rmem
-        tXrW = cute.make_fragment_like(tXgW)
-        tXrW.fill(0.0)
+        tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
         tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
         tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
         if const_expr(mRes is not None):
             tXrRes = cute.make_fragment_like(tXgRes)
 
-        num_warps = cute.size(
+        num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
         self._initialize_cluster(tidx, mbar_ptr, num_warps)
 
-        is_even_N =
+        is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
         tXpX = (
-
+            copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+            if not is_even_N
+            else None
         )
+        # Each copy will use the same predicate
+        copy = partial(copy_utils.copy, pred=tXpX)
+
         row = tXcX[0][0]
         if row < shape[0]:
-
+            copy(tXgX, tXsX, is_async=True)
             if const_expr(mRes is not None):
-
+                copy(tXgRes, tXsRes, is_async=True)
         cute.arch.cp_async_commit_group()
 
-        if const_expr(not delay_w_load):
-
+        if const_expr(not self.delay_w_load):
+            if const_expr(mW is not None):
+                copy(tXgW, tXrW)
             if const_expr(mB is not None):
-
+                copy(tXgB, tXrB)
 
         cute.arch.cp_async_wait_group(0)
         cute.autovec_copy(tXsX, tXrX)
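The hunk above collapses the many per-tensor cute.copy(..., pred=...) calls into one copy callable with the boundary predicate pre-bound via functools.partial. The same trick in isolation; the copy function here is a stand-in for quack.copy_utils.copy, whose real signature is an assumption:

from functools import partial

def copy(src, dst, pred=None, is_async=False):
    # Stand-in for quack.copy_utils.copy (signature assumed, not confirmed).
    print(f"copy {src} -> {dst} (pred={pred}, is_async={is_async})")

tXpX = "boundary-predicate"          # placeholder value
copy = partial(copy, pred=tXpX)      # bind the predicate once, like the kernel does
copy("tXgX", "tXsX", is_async=True)  # every later call reuses the same predicate
copy("tXgW", "tXrW")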
@@ -310,19 +192,63 @@ class RMSNorm(ReductionBase):
             tXrResO = cute.make_fragment_like(tXgResO)
             tXrResO.store(x.to(tXrResO.element_type))
             if row < shape[0]:
-
-
-
-
-
-
-
-
-
-
-
-
-
+                copy(tXrResO, tXgResO)
+
+        mean, rstd = None, None
+        if const_expr(self.is_layernorm):
+            # LayerNorm: compute mean first, then variance
+            sum_x = row_reduce(
+                x,
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
+                init_val=0.0,
+                hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+            )
+            mean = sum_x / shape[1]
+            if const_expr(mMean is not None):
+                # Only the thread corresponding to column 0 writes out the mean to gmem
+                if (
+                    tXcX[0][1] == 0
+                    and row < shape[0]
+                    and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+                ):
+                    tXrMean[0] = mean
+            if const_expr(self.reload_from == "smem"):
+                cute.autovec_copy(tXsX, tXrX)
+                x = tXrX.load().to(cute.Float32)
+                if const_expr(mRes is not None):
+                    cute.autovec_copy(tXsRes, tXrRes)
+                    x += tXrRes.load().to(cute.Float32)
+            elif const_expr(self.reload_from == "gmem"):
+                copy(tXgX, tXrX)
+                x = tXrX.load().to(cute.Float32)
+                if const_expr(mRes is not None):
+                    copy(tXgRes, tXrRes)
+                    x += tXrRes.load().to(cute.Float32)
+            sum_sq_x_sub_mean = row_reduce(
+                (x - mean) * (x - mean),
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 1],
+                mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
+                init_val=0.0,
+            )
+            rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
+        else:
+            # RMSNorm: compute sum of squares directly
+            mean = const_expr(0.0)
+            sum_sq_x = row_reduce(
+                x * x,
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr,
+                init_val=0.0,
+                hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+            )
+            rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
         if const_expr(mRstd is not None):
             # Only the thread corresponding to column 0 writes out the rstd to gmem
             if (
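The branch above mirrors the textbook definitions: LayerNorm needs two reductions (the mean, then the variance of x - mean, hence stage=2), while RMSNorm needs a single sum-of-squares reduction. A plain-PyTorch restatement of what the two code paths compute per row, with eps placed as in the hunk:

import torch

def norm_ref(x: torch.Tensor, eps: float = 1e-6, is_layernorm: bool = False):
    x = x.float()
    if is_layernorm:
        mean = x.mean(dim=-1, keepdim=True)                   # first reduction
        var = (x - mean).square().mean(dim=-1, keepdim=True)  # second reduction
        return (x - mean) * torch.rsqrt(var + eps)
    # RMSNorm: one reduction, mean is identically 0
    return x * torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + eps)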
@@ -331,139 +257,114 @@ class RMSNorm(ReductionBase):
                 and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
             ):
                 tXrRstd[0] = rstd
-        if const_expr(delay_w_load):
-
+        if const_expr(self.delay_w_load):
+            if const_expr(mW is not None):
+                copy(tXgW, tXrW)
             if const_expr(mB is not None):
-
-        if const_expr(reload_from == "smem" or reload_from == "gmem"):
-            if const_expr(reload_from == "smem"):
+                copy(tXgB, tXrB)
+        if const_expr(self.reload_from == "smem" or self.reload_from == "gmem"):
+            if const_expr(self.reload_from == "smem"):
                 cute.autovec_copy(tXsX, tXrX)
+                if const_expr(mRes is not None):
+                    cute.autovec_copy(tXsRes, tXrRes)
             else:
-
+                copy(tXgX, tXrX)
+                if const_expr(mRes is not None):
+                    copy(tXgRes, tXrRes)
             x = tXrX.load().to(cute.Float32)
             if const_expr(mRes is not None):
-                cute.autovec_copy(tXsRes, tXrRes)
                 x += tXrRes.load().to(cute.Float32)
-        x_hat = x * rstd
-
-
+        x_hat = (x - mean) * rstd if const_expr(self.is_layernorm) else x * rstd
+        y = x_hat
+        if const_expr(mW is not None):
+            y *= tXrW.load().to(cute.Float32)
         if const_expr(mB is not None):
-
-            y = y + b
+            y += tXrB.load().to(cute.Float32)
         tXrO.store(y.to(tXrO.element_type))
         if row < shape[0]:
-
+            copy(tXrO, tXgO)
 
 
288
|
@torch.library.custom_op(
|
|
359
289
|
"quack::_rmsnorm_fwd",
|
|
360
|
-
mutates_args=("out", "rstd", "residual_out"),
|
|
290
|
+
mutates_args=("out", "rstd", "mean", "residual_out"),
|
|
361
291
|
device_types="cuda",
|
|
362
292
|
# We need to specify the schema manually since we're mutating an optional tensor
|
|
363
|
-
schema="(Tensor x, Tensor weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(
|
|
293
|
+
schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor(a5!)? mean, Tensor? residual, Tensor(a7!)? residual_out, float eps=1e-6, bool is_layernorm=False) -> ()",
|
|
364
294
|
)
|
|
365
295
|
def _rmsnorm_fwd(
|
|
366
296
|
x: Tensor,
|
|
367
|
-
weight: Tensor,
|
|
297
|
+
weight: Optional[Tensor],
|
|
368
298
|
out: Tensor,
|
|
369
299
|
bias: Optional[Tensor] = None,
|
|
370
300
|
rstd: Optional[Tensor] = None,
|
|
301
|
+
mean: Optional[Tensor] = None,
|
|
371
302
|
residual: Optional[Tensor] = None,
|
|
372
303
|
residual_out: Optional[Tensor] = None,
|
|
373
304
|
eps: float = 1e-6,
|
|
305
|
+
is_layernorm: bool = False,
|
|
374
306
|
) -> None:
|
|
375
|
-
"""RMSNorm forward pass.
|
|
307
|
+
"""RMSNorm/LayerNorm forward pass.
|
|
376
308
|
Args:
|
|
377
309
|
x: Input tensor of shape (M, N)
|
|
378
|
-
weight:
|
|
310
|
+
weight: Optional weight tensor of shape (N,)
|
|
379
311
|
eps: Small value for numerical stability
|
|
312
|
+
is_layernorm: If True, compute LayerNorm instead of RMSNorm
|
|
380
313
|
Returns:
|
|
381
314
|
Normalized output tensor of same shape as x
|
|
382
315
|
"""
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
assert x.
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
assert weight.dtype in [
|
|
389
|
-
torch.float32,
|
|
390
|
-
torch.bfloat16,
|
|
391
|
-
torch.float16,
|
|
392
|
-
], "Weight must be float32, float16 or bfloat16"
|
|
316
|
+
# Don't need to check is_cuda since torch.library ensures that
|
|
317
|
+
supported_types = {torch.float16, torch.bfloat16, torch.float32}
|
|
318
|
+
assert x.dtype in supported_types, "Unsupported dtype"
|
|
319
|
+
if weight is not None:
|
|
320
|
+
assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16"
|
|
393
321
|
if residual is not None:
|
|
394
|
-
assert residual.
|
|
395
|
-
assert residual.is_cuda
|
|
396
|
-
assert residual.dtype in [
|
|
397
|
-
torch.float16,
|
|
398
|
-
torch.bfloat16,
|
|
399
|
-
torch.float32,
|
|
400
|
-
], "Residual must be float16, bfloat16, or float32"
|
|
322
|
+
assert residual.dtype in supported_types, "Residual must be float16, bfloat16, or float32"
|
|
401
323
|
|
|
402
324
|
_, N = x.shape
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
# from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
407
|
-
# mode=0, divisibility=128 // dtype.width
|
|
408
|
-
# )
|
|
409
|
-
# )
|
|
410
|
-
convert_from_dlpack = lambda x: (
|
|
411
|
-
from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
|
|
412
|
-
)
|
|
413
|
-
x_tensor, res_tensor, out_tensor, res_out_tensor = [
|
|
414
|
-
convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
|
|
325
|
+
dtype, out_dtype, weight_dtype, bias_dtype, res_dtype, res_out_dtype = [
|
|
326
|
+
torch2cute_dtype_map[t.dtype] if t is not None else None
|
|
327
|
+
for t in [x, out, weight, bias, residual, residual_out]
|
|
415
328
|
]
|
|
416
|
-
# handle weight divisibility based on weight dtype
|
|
417
|
-
weight_dtype = torch2cute_dtype_map[weight.dtype]
|
|
418
|
-
weight_tensor = utils.convert_from_dlpack(
|
|
419
|
-
weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
|
|
420
|
-
)
|
|
421
|
-
if bias is not None:
|
|
422
|
-
bias_dtype = torch2cute_dtype_map[bias.dtype]
|
|
423
|
-
bias_tensor = utils.convert_from_dlpack(
|
|
424
|
-
bias.detach(), leading_dim=0, divisibility=128 // bias_dtype.width
|
|
425
|
-
)
|
|
426
|
-
else:
|
|
427
|
-
bias_tensor = None
|
|
428
|
-
rstd_tensor = (
|
|
429
|
-
from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
430
|
-
if rstd is not None
|
|
431
|
-
else None
|
|
432
|
-
)
|
|
433
|
-
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
434
329
|
compile_key = (
|
|
435
|
-
N,
|
|
436
330
|
dtype,
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
331
|
+
out_dtype,
|
|
332
|
+
res_dtype,
|
|
333
|
+
weight_dtype,
|
|
334
|
+
bias_dtype,
|
|
335
|
+
res_out_dtype,
|
|
336
|
+
N,
|
|
441
337
|
rstd is not None,
|
|
338
|
+
mean is not None,
|
|
339
|
+
is_layernorm,
|
|
442
340
|
)
|
|
443
341
|
if compile_key not in _rmsnorm_fwd.compile_cache:
|
|
444
|
-
|
|
342
|
+
batch_sym = cute.sym_int()
|
|
343
|
+
all_dtypes = [dtype, out_dtype, res_dtype, weight_dtype, bias_dtype, res_out_dtype]
|
|
344
|
+
div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
|
|
345
|
+
x_cute, out_cute, res_cute, res_out_cute = [
|
|
346
|
+
fake_tensor(dt, (batch_sym, N), div)
|
|
347
|
+
for dt in [dtype, out_dtype, res_dtype, res_out_dtype]
|
|
348
|
+
]
|
|
349
|
+
weight_cute, bias_cute = [fake_tensor(dt, (N,), div) for dt in [weight_dtype, bias_dtype]]
|
|
350
|
+
rstd_cute = fake_tensor(Float32, (batch_sym,)) if rstd is not None else None
|
|
351
|
+
mean_cute = fake_tensor(Float32, (batch_sym,)) if mean is not None else None
|
|
445
352
|
_rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
eps,
|
|
353
|
+
RMSNorm(dtype, N, is_layernorm=is_layernorm),
|
|
354
|
+
x_cute,
|
|
355
|
+
weight_cute,
|
|
356
|
+
bias_cute,
|
|
357
|
+
res_cute,
|
|
358
|
+
out_cute,
|
|
359
|
+
res_out_cute,
|
|
360
|
+
rstd_cute,
|
|
361
|
+
mean_cute,
|
|
362
|
+
Float32(0), # eps, just for compilation
|
|
363
|
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
|
364
|
+
options="--enable-tvm-ffi",
|
|
456
365
|
)
|
|
457
366
|
_rmsnorm_fwd.compile_cache[compile_key](
|
|
458
|
-
|
|
459
|
-
weight_tensor,
|
|
460
|
-
bias_tensor,
|
|
461
|
-
res_tensor,
|
|
462
|
-
out_tensor,
|
|
463
|
-
res_out_tensor,
|
|
464
|
-
rstd_tensor,
|
|
465
|
-
current_stream,
|
|
466
|
-
eps,
|
|
367
|
+
x, weight, bias, residual, out, residual_out, rstd, mean, eps
|
|
467
368
|
)
|
|
468
369
|
|
|
469
370
|
|
|
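The rewritten _rmsnorm_fwd above compiles once per (dtypes, N, flags) key against symbolic "fake" tensors and then invokes the cached kernel directly on torch tensors via the new tvm-ffi path. A stripped-down sketch of that memoization pattern, with a stand-in compile function (cute.compile itself is not reproduced here):

def expensive_compile(key):
    # Stand-in for cute.compile(RMSNorm(...), fake tensors, ...); assumes only
    # that compilation is deterministic per key and returns a callable.
    return lambda *args: (key, len(args))

_cache = {}

def run(dtype, n, *tensors):
    key = (dtype, n)              # mirrors compile_key in the hunk above
    if key not in _cache:         # compile once per distinct signature
        _cache[key] = expensive_compile(key)
    return _cache[key](*tensors)  # the cached kernel runs on the real inputs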
@@ -472,7 +373,7 @@ _rmsnorm_fwd.compile_cache = {}
 
 def rmsnorm_fwd(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor] = None,
     bias: Optional[Tensor] = None,
     residual: Optional[Tensor] = None,
     out_dtype: Optional[torch.dtype] = None,
@@ -494,19 +395,20 @@ def rmsnorm_fwd(
         )
     else:
         residual_out = None
-    _rmsnorm_fwd(x, weight, out, bias, rstd, residual, residual_out, eps
+    _rmsnorm_fwd(x, weight, out, bias, rstd, None, residual, residual_out, eps, False)
     # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
     if residual_out is None:
         residual_out = x
     return out, residual_out, rstd
 
 
-def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
+def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6):
     x_f32 = x.float()
     if residual is not None:
         residual_f32 = residual.float()
         x_f32 += residual_f32
-
+    x_norm = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps))
+    out = x_norm * w if w is not None else x_norm
     if bias is not None:
         out = out + bias.float()
     if residual is None:
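Given the new optional-weight signatures above, a quick self-check one might run; a sketch that assumes the 0.2.3 wheel is installed and a CUDA device is available (whether rstd comes back populated may depend on arguments not visible in this diff):

import torch
from quack.rmsnorm import rmsnorm_fwd, rmsnorm_ref

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1024, device="cuda", dtype=torch.bfloat16)

out, residual_out, rstd = rmsnorm_fwd(x, w)  # weight is now optional ...
out_unweighted, _, _ = rmsnorm_fwd(x)        # ... so this also works in 0.2.3
torch.testing.assert_close(out.float(), rmsnorm_ref(x, w).float(), rtol=1e-2, atol=1e-2)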
@@ -519,13 +421,19 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
     """Reference implementation for RMSNorm backward pass."""
     x_f32 = x.float()
     x_hat = x_f32 * rstd.unsqueeze(1)
-
+    if w is not None:
+        wdy = dout * w
+    else:
+        wdy = dout
     c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
     dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)
 
     # dL/dW
-
-
+    if w is not None:
+        dw = (dout * x_hat).sum(dim=0)
+        return dx.to(x.dtype), dw.to(w.dtype)
+    else:
+        return dx.to(x.dtype), None
 
 
 class RMSNormBackward(ReductionBase):
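For reference, the dx formula in rmsnorm_bwd_ref above is the standard RMSNorm gradient. Differentiating y = x_hat * w through rstd (which depends on x via the sum of squares) gives, per row with n features:

\[
r = \Big(\tfrac{1}{n}\sum_j x_j^2 + \epsilon\Big)^{-1/2}, \qquad
\hat{x}_i = x_i\, r, \qquad
c_1 = \tfrac{1}{n}\sum_j \hat{x}_j\, wdy_j,
\]
\[
dx_i = \big(wdy_i - \hat{x}_i\, c_1\big)\, r, \qquad
dw = \sum_{\text{rows}} dout \odot \hat{x},
\]

where wdy = w ⊙ dout (or just dout when w is None), matching c1 and dx in the code.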
@@ -537,91 +445,57 @@ class RMSNormBackward(ReductionBase):
             # Not enough smem
             raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
 
-    def
+    def _num_threads(self):
         return 128 if self.N <= 4096 else 256
 
-    def
+    def _threads_per_row(self):
         N = self.N
-
-
-
-
-            16
-            if N <= 128
-            else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
-        )
-        )
+        for limit, threads in [(64, 8), (128, 16), (256, 32), (512, 64), (4096, 128)]:
+            if N <= limit:
+                return threads
+        return 256
 
     def _set_cluster_n(self):
         N = self.N
-
-
-
-
-
-        self.cluster_n = cluster_n
-
-    def _smem_size_in_bytes(self, tiler_mn, num_warps, do_dtype=None):
-        if do_dtype is None:
-            do_dtype = self.dtype
-        return (
-            # We need space for X and dO, and multiply by 2 due to double buffering
-            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
-            + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
-            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
-            + self.stage * (cutlass.Int64.width // 8) * 2  # mult 2 as we need 2 mbar per stage
-        )
+        for limit, cluster in [(8 * 1024, 1), (16 * 1024, 2), (32 * 1024, 4), (64 * 1024, 8)]:
+            if N <= limit:
+                self.cluster_n = cluster
+                return
+        self.cluster_n = 16
 
     @cute.jit
     def __call__(
         self,
         mX: cute.Tensor,
-        mW: cute.Tensor,
+        mW: Optional[cute.Tensor],
         mdO: cute.Tensor,
         mdResO: Optional[cute.Tensor],
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
-        mdW: cute.Tensor,
+        mdW: Optional[cute.Tensor],
         mdRes: Optional[cute.Tensor],
         mdB: Optional[cute.Tensor],
         sm_count: Int32,
         stream: cuda.CUstream,
     ):
-
-        new_stride = lambda t: (
-            cute.assume(t.stride[0], divby=128 // t.element_type.width),
-            t.stride[1],
-        )
-        mX, mdO, mdResO, mdX, mdRes = [
-            cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
-            if const_expr(t is not None)
-            else None
-            for t in (mX, mdO, mdResO, mdX, mdRes)
-        ]
+        assert mX.element_type == self.dtype
         self._set_cluster_n()
         largest_dtype_width = const_expr(
-            max(
-                mX.element_type.width,
-                mdO.element_type.width,
-                mdX.element_type.width,
-                mdResO.element_type.width if mdResO is not None else 0,
-                mdRes.element_type.width if mdRes is not None else 0,
-            )
+            max(*(t.element_type.width for t in [mX, mW, mdO, mdResO, mdX, mdRes] if t is not None))
         )
-
-
+        vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+        tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
+        num_threads = tiled_copy.size
+        mW = (
+            layout_utils.expand(mW, dim=0, size=tiler_mn[0]) if const_expr(mW is not None) else None
         )
-        num_threads = cute.size(tv_layout, mode=[0])
-        num_warps = num_threads // cute.arch.WARP_SIZE
-        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
-        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
-
         num_blocks = sm_count
-        self.kernel(
+        self.kernel(
+            mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row
+        ).launch(
             grid=[num_blocks, self.cluster_n, 1],
             block=[num_threads, 1, 1],
             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
-            smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
             stream=stream,
         )
 
@@ -629,24 +503,23 @@ class RMSNormBackward(ReductionBase):
     def kernel(
         self,
         mX: cute.Tensor,
-        mW: cute.Tensor,
+        mW: Optional[cute.Tensor],
         mdO: cute.Tensor,
         mdResO: Optional[cute.Tensor],
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
-        mdW: cute.Tensor,
+        mdW: Optional[cute.Tensor],
         mdB: Optional[cute.Tensor],
         mdRes: Optional[cute.Tensor],
-        tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
+        tiled_copy: cute.TiledCopy,
+        threads_per_row: cutlass.Constexpr[int],
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx_start, _, _ = cute.arch.block_idx()
         gdim, _, _ = cute.arch.grid_dim()
-        if const_expr(self.cluster_n
-
-        else:
-            cluster_y = const_expr(0)
+        cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
+        tv_layout = tiled_copy.layout_tv_tiled
 
         shape = mX.shape
         M, N = shape[0], shape[1]
@@ -666,103 +539,20 @@ class RMSNormBackward(ReductionBase):
         else:
             mbar_full_ptr, mbar_empty_ptr = None, None
 
-
-        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        num_copy_bits_dO = const_expr(min(128, num_copy_elems_X * mdO.element_type.width))
-        copy_atom_load_dO_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mdO.element_type, num_bits_per_copy=num_copy_bits_dO
-        )
-        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
-        copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
-        )
-        if const_expr(mdResO is not None):
-            num_copy_bits_dResO = const_expr(min(128, num_copy_elems_X * mdResO.element_type.width))
-            copy_atom_load_dResO = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(),
-                mdResO.element_type,
-                num_bits_per_copy=num_copy_bits_dResO,
-            )
-        num_copy_bits_dX = const_expr(min(128, num_copy_elems_X * mdX.element_type.width))
-        copy_atom_store_dX = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_dX
-        )
-        num_copy_bits_dW = const_expr(min(128, num_copy_elems_X * mdW.element_type.width))
-        copy_atom_store_dW = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_copy_bits_dW
-        )
-        if const_expr(mdB is not None):
-            num_copy_bits_dB = const_expr(min(128, num_copy_elems_X * mdB.element_type.width))
-            copy_atom_store_dB = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(), mdB.element_type, num_bits_per_copy=num_copy_bits_dB
-            )
-        if const_expr(mdRes is not None):
-            num_copy_bits_dRes = const_expr(min(128, num_copy_elems_X * mdRes.element_type.width))
-            copy_atom_load_dRes = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(),
-                mdRes.element_type,
-                num_bits_per_copy=num_copy_bits_dRes,
-            )
-
-        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-
-        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
-        tXgW = thr_copy_X.partition_S(gW)
-        tXrW = cute.make_fragment_like(tXgW)
-        # Need this, otherwise rW can have arbitrary values that changes the reduction
-        if not is_even_N:
-            tXrW.fill(0.0)
-
-        gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
-        tXpW = (
-            utils.predicate_k(thr_copy_X.partition_S(gW_coord), limit=shape[1])
-            if not is_even_N
-            else None
-        )
-        cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
-        weight = tXrW.load().to(cute.Float32)
-
-        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-
-        self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
-
-        dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
-        tXpdW = (
-            utils.predicate_k(thr_copy_X.partition_S(dw_coord), limit=shape[1])
-            if not is_even_N
-            else None
-        )
-        if const_expr(mdB is not None):
-            db_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
-            tXpdB = (
-                utils.predicate_k(thr_copy_X.partition_S(db_coord), limit=shape[1])
-                if not is_even_N
-                else None
-            )
-
-        gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
-        tXgdW = thr_copy_X.partition_S(gdW)
-        # Always compute partial weight gradients in fp32
-        tXrdW = cute.make_fragment_like(tXgdW, Float32)
-
-        gdB = (
-            cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y))
-            if const_expr(mdB is not None)
-            else None
-        )
-        tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
-        tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None
+        thr_copy_X = tiled_copy.get_slice(tidx)
 
         gX, gdO, gdResO, gdX, gdRes, cX = [
             cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None
             for mT in (mX, mdO, mdResO, mdX, mdRes, idX)
         ]
+        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) if mW is not None else None
+        gdW, gdB = [
+            cute.local_tile(mT, (1, tiler_mn[1]), (bidx_start, cluster_y))
+            if const_expr(mT is not None)
+            else None
+            for mT in (mdW, mdB)
+        ]
+
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
         tXgdO = thr_copy_X.partition_S(gdO)
@@ -773,12 +563,6 @@ class RMSNormBackward(ReductionBase):
         if const_expr(mdRes is not None):
             tXgdRes = thr_copy_X.partition_D(gdRes)
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
-        # This doesn't change across iterations
-        tXpX = (
-            utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
-            if not is_even_N
-            else None
-        )
 
         tXrX, tXrdO, tXrdX = [
             cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
@@ -790,28 +574,57 @@ class RMSNormBackward(ReductionBase):
         if const_expr(mdRes is not None):
             tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
 
-
-
+        # This doesn't change across iterations
+        tXpX = (
+            None
+            if is_even_N
+            else copy_utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+        )
+        # Each copy will use the same number of elements as X
+        copy = partial(copy_utils.copy, pred=tXpX)
+
+        tXgdW, tXrdW = None, None
+        tXgdB, tXrdB = None, None
+        if const_expr(mdW is not None):
+            tXgdW = thr_copy_X.partition_S(gdW)
+            # Always compute partial weight gradients in fp32
+            tXrdW = cute.make_fragment_like(tXgdW, Float32)
+        if const_expr(mdB is not None):
+            tXgdB = thr_copy_X.partition_S(gdB)
+            # Always compute partial bias gradients in fp32
+            tXrdB = cute.make_fragment_like(tXgdB, Float32)
+
+        num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
+
+        self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
+
+        tXrW = None
+        if const_expr(mW is not None):
+            tXgW = thr_copy_X.partition_S(gW)
+            tXrW = cute.make_fragment_like(tXgW)
+            # Need this, otherwise rW can have arbitrary values that changes the reduction
+            if const_expr(not is_even_N):
+                tXrW.fill(0.0)
+            copy(tXgW, tXrW)
 
         # Prefetch the first batch
         row = tXcX[None, None, None, bidx_start][0][0]
         if row < M:
-
-
-
-
-
-
-
-
-            utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
+            copy(tXgX[None, None, None, bidx_start], tXsX[None, None, None, 0], is_async=True)
+            copy(tXgdO[None, None, None, bidx_start], tXsdO[None, None, None, 0], is_async=True)
+        else:
+            if const_expr(tiler_mn[0] > 1):
+                # Fill with zero, otherwise smem will be uninitialized, and we could read this back
+                # later into registers, causing wrong dW.
+                utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+                utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
         cute.arch.cp_async_commit_group()
 
         if const_expr(self.cluster_n > 1):
            cute.arch.cluster_wait()
 
-
-
+        if const_expr(mdW is not None):
+            tXrdW.fill(0.0)
         if const_expr(mdB is not None):
            tXrdB.fill(0.0)
        stage = Int32(0)
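The persistent loop that follows ping-pongs between two shared-memory stages: stage ^ 1 is prefetched while stage is consumed. The control flow, reduced to its skeleton (barriers and cp.async omitted; load/compute are placeholders):

def pipeline(tiles, load, compute):
    # Two-slot double buffer mirroring the stage ^= 1 pattern in the kernel.
    buf = [None, None]
    stage = 0
    if tiles:
        buf[stage] = load(tiles[0])
    out = []
    for i, tile in enumerate(tiles):
        if i + 1 < len(tiles):
            buf[stage ^ 1] = load(tiles[i + 1])  # prefetch the next tile
        out.append(compute(buf[stage]))          # consume the current tile
        stage ^= 1
    return out

assert pipeline([1, 2, 3], load=lambda t: t * 10, compute=lambda v: v + 1) == [11, 21, 31]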
@@ -820,29 +633,31 @@ class RMSNormBackward(ReductionBase):
         for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
             row = tXcX[None, None, None, bidx][0][0]
             if row + gdim * tiler_mn[0] < M:  # Prefetch the next batch
-
-
-                copy_X(tXgX_cur, tXsX[None, None, None, stage ^ 1])
-                copy_dO(tXgdO_cur, tXsdO[None, None, None, stage ^ 1])
-            elif tiler_mn[0] > 1:
-                utils.fill_oob(
+                copy(
+                    tXgX[None, None, None, bidx + gdim],
                     tXsX[None, None, None, stage ^ 1],
-
-                    fill_value=mX.element_type.zero,
+                    is_async=True,
                 )
-
+                copy(
+                    tXgdO[None, None, None, bidx + gdim],
                     tXsdO[None, None, None, stage ^ 1],
-
-                    fill_value=mdO.element_type.zero,
+                    is_async=True,
                 )
+            else:
+                if const_expr(tiler_mn[0] > 1):
+                    utils.fill_oob(
+                        tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
+                    )
+                    utils.fill_oob(
+                        tXsdO[None, None, None, stage ^ 1], None, fill_value=mdO.element_type.zero
+                    )
             cute.arch.cp_async_commit_group()
             rstd = cutlass.Float32.zero
             if row < M or tiler_mn[0] == 1:
                 rstd = mRstd[row]
             if const_expr(mdResO is not None):
-                tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
                 if row < M or tiler_mn[0] == 1:
-
+                    copy(tXgdResO[None, None, None, bidx], tXrdResO)
                 elif tiler_mn[0] > 1:
                     tXrdResO.fill(0.0)
             cute.arch.cp_async_wait_group(1)
@@ -850,10 +665,10 @@ class RMSNormBackward(ReductionBase):
             x = tXrX.load().to(cute.Float32)
             cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
             dout = tXrdO.load().to(cute.Float32)
-            if const_expr(mdResO is not None):
-                dout += tXrdResO.load().to(cute.Float32)
             x_hat = x * rstd
-            wdy = dout * weight
+            wdy = dout
+            if const_expr(mW is not None):
+                wdy *= tXrW.load().to(Float32)
             if const_expr(self.cluster_n > 1):
                 cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
             mean_xhat_wdy = (
@@ -870,6 +685,10 @@ class RMSNormBackward(ReductionBase):
             )
 
             if const_expr(self.cluster_n > 1):
+                # Need this fence since the STAS from the producer is using the async proxy.
+                cute.arch.fence_proxy(
+                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                )
                 # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
                 # Requires adjusting the thread_count when initializing the mbar
                 cute.arch.sync_warp()
@@ -882,22 +701,22 @@ class RMSNormBackward(ReductionBase):
             if const_expr(self.reload_wdy == "smem"):
                 cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
                 dout = tXrdO.load().to(cute.Float32)
-
-
-
+                wdy = dout
+                if const_expr(mW is not None):
+                    wdy *= tXrW.load().to(Float32)
 
             dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+            if const_expr(mdResO is not None):
+                dx += tXrdResO.load().to(cute.Float32)
             tXrdX.store(dx.to(tXrdX.element_type))
             if row < M or tiler_mn[0] == 1:
-
-                cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+                copy(tXrdX, tXgdX[None, None, None, bidx])
             if const_expr(mdRes is not None):
                 tXrdRes.store(dx.to(tXrdRes.element_type))
-                tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
                 if row < M or tiler_mn[0] == 1:
-
-
-
+                    copy(tXrdRes, tXgdRes[None, None, None, bidx])
+            if const_expr(mdW is not None):
+                tXrdW.store(tXrdW.load() + dout * x_hat)
             if const_expr(mdB is not None):
                 tXrdB.store(tXrdB.load() + dout)
 
@@ -906,29 +725,29 @@ class RMSNormBackward(ReductionBase):
             consumer_phase ^= 1
             producer_phase ^= 1
 
-        if const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
-            cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
-
         if const_expr(tiler_mn[0] > 1):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if const_expr(mdW is not None):
+                # reduction of dw_partial within the same threadblock
+                sdW = cute.make_tensor(
+                    cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+                    cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+                )
+                tXsdW = thr_copy_X.partition_D(sdW)
+                cute.arch.barrier()
+                row = tXcX[None, None, None, 0][0][0]
+                if row > 0:
+                    cute.autovec_copy(tXrdW, tXsdW)
+                cute.arch.barrier()
+                if row == 0:
+                    for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
+                        tXrdW_other = cute.make_fragment_like(tXrdW)
+                        tXsdW_other = cute.make_tensor(
+                            tXsdW.iterator + i * sdW.stride[0], tXsdW.layout
+                        )
+                        cute.autovec_copy(tXsdW_other, tXrdW_other)
+                        tXrdW.store(tXrdW.load() + tXrdW_other.load())
+                    copy(tXrdW, tXgdW)
+                cute.arch.barrier()
             if const_expr(mdB is not None):
                 sdB = cute.make_tensor(
                     cute.recast_ptr(sX.iterator, dtype=cute.Float32),
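The new mdW block above sums each thread-row's fp32 partial into row 0 through shared memory before a single row writes dw_partial to gmem. A NumPy analogue of that rows-then-store reduction (sizes are illustrative):

import numpy as np

# One fp32 partial per thread row (tiler_mn[0] rows); row 0 ends up holding
# the block's dw before the single gmem store, as in the hunk above.
partials = np.random.rand(4, 1024).astype(np.float32)
dw_block = partials.sum(axis=0)
assert dw_block.shape == (1024,)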
@@ -948,12 +767,21 @@ class RMSNormBackward(ReductionBase):
                         )
                         cute.autovec_copy(tXsdB_other, tXrdB_other)
                         tXrdB.store(tXrdB.load() + tXrdB_other.load())
-
+                    copy(tXrdB, tXgdB)
         else:
             # dw is already in fp32, so we can directly copy to global memory
-
+            if const_expr(mdW is not None):
+                copy(tXrdW, tXgdW)
             if const_expr(mdB is not None):
-
+                copy(tXrdB, tXgdB)
+
+        if const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
+            # Assume state contains that next useful buffer
+            # So we only need to advance to num_stages - 1 times to last used buffer
+            stage ^= 1
+            if stage == 0:
+                producer_phase ^= 1
+            cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
 
 
 def _get_sm_count(N: int, device: torch.device) -> int:
@@ -978,120 +806,103 @@ def _get_sm_count(N: int, device: torch.device) -> int:
|
|
|
978
806
|
mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
|
|
979
807
|
device_types="cuda",
|
|
980
808
|
# We need to specify the schema manually since we're mutating an optional tensor
|
|
981
|
-
schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!) dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual) -> ()",
|
|
809
|
+
schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count) -> ()",
|
|
982
810
|
)
|
|
983
811
|
def _rmsnorm_bwd(
|
|
984
812
|
x: Tensor,
|
|
985
|
-
weight: Tensor,
|
|
813
|
+
weight: Optional[Tensor],
|
|
986
814
|
dout: Tensor,
|
|
987
815
|
rstd: Tensor,
|
|
988
816
|
dx: Tensor,
|
|
989
|
-
dw_partial: Tensor,
|
|
817
|
+
dw_partial: Optional[Tensor],
|
|
990
818
|
db_partial: Optional[Tensor] = None,
|
|
991
819
|
dresidual_out: Optional[Tensor] = None,
|
|
992
820
|
dresidual: Optional[Tensor] = None,
|
|
821
|
+
sm_count: Optional[int] = None,
|
|
993
822
|
) -> None:
|
|
994
823
|
"""RMSNorm backward pass.
|
|
995
824
|
Args:
|
|
996
825
|
x: Input tensor of shape (M, N)
|
|
997
|
-
weight:
|
|
826
|
+
weight: Optional weight tensor of shape (N,)
|
|
998
827
|
dout: Upstream gradients tensor of shape (M, N)
|
|
999
828
|
rstd: Reciprocal standard deviation tensor of shape (M,)
|
|
1000
829
|
Returns:
|
|
1001
830
|
Tuple of (dx, dw) where:
|
|
1002
831
|
- dx: Input gradients tensor of same shape as x
|
|
1003
|
-
- dw: Weight gradients tensor of same shape as weight
|
|
832
|
+
- dw: Weight gradients tensor of same shape as weight (or None if weight is None)
|
|
1004
833
|
"""
|
|
1005
834
|
assert x.dim() == 2, "Input must be 2D"
|
|
1006
|
-
assert
|
|
1007
|
-
|
|
1008
|
-
assert x.
|
|
1009
|
-
|
|
1010
|
-    assert weight.dtype in [
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ], "Weight must be float32, float16 or bfloat16"
+    assert x.is_cuda, "Input tensor must be on CUDA device"
+    supported_types = {torch.float16, torch.bfloat16, torch.float32}
+    assert x.dtype in supported_types, "Unsupported dtype"
+    if weight is not None:
+        assert weight.dim() == 1, "Weight must be 1D"
+        assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+        assert weight.is_cuda, "Weight tensor must be on CUDA device"
+        assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16"
     if dresidual_out is not None:
         assert dresidual_out.shape == x.shape
         assert dresidual_out.is_cuda
-        assert dresidual_out.dtype in [
-            torch.float16,
-            torch.bfloat16,
-            torch.float32,
-        ], "Residual must be float16, bfloat16, or float32"
+        assert dresidual_out.dtype in supported_types, (
+            "Residual must be float16, bfloat16, or float32"
+        )
     if dresidual is not None:
         assert dresidual.shape == x.shape
         assert dresidual.is_cuda
-        assert dresidual.dtype in [
-            torch.float16,
-            torch.bfloat16,
-            torch.float32,
-        ], "Residual must be float16, bfloat16, or float32"
+        assert dresidual.dtype in supported_types, "Residual must be float16, bfloat16, or float32"
 
     N = x.size(1)
-        for t in (x, dout, dresidual_out, dx, dresidual)
+    if dw_partial is None and db_partial is None:
+        assert sm_count is not None
+    else:
+        sm_count = dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0]
+    dtype, dout_dtype, dx_dtype, weight_dtype, dres_dtype, dres_out_dtype = [
+        torch2cute_dtype_map[t.dtype] if t is not None else None
+        for t in [x, dout, dx, weight, dresidual, dresidual_out]
     ]
-    # Handle weight div based on weight dtype
-    weight_dtype = torch2cute_dtype_map[weight.dtype]
-    weight_tensor = utils.convert_from_dlpack(
-        weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
-    )
-
-    dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
-    db_partial_tensor = (
-        from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
-        if db_partial is not None
-        else None
-    )
-    rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
-
-    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
     compile_key = (
         N,
+        dtype,
+        dout_dtype,
+        dx_dtype,
+        weight_dtype,
+        db_partial is not None,
+        dres_dtype,
+        dres_out_dtype,
     )
     if compile_key not in _rmsnorm_bwd.compile_cache:
+        batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int()
+        all_dtypes = [dtype, dout_dtype, dx_dtype, dres_dtype, dres_out_dtype]
+        div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
+        x_cute, dout_cute, dx_cute, dres_out_cute, dres_cute = [
+            fake_tensor(dt, (batch_sym, N), div)
+            for dt in [dtype, dout_dtype, dx_dtype, dres_out_dtype, dres_dtype]
+        ]
+        weight_cute = fake_tensor(weight_dtype, (N,), div)
+        rstd_cute = fake_tensor(Float32, (batch_sym,))
+        dw_partial_cute = (
+            fake_tensor(Float32, (batch_partial_sym, N), div) if dw_partial is not None else None
+        )
+        db_partial_cute = (
+            fake_tensor(Float32, (batch_partial_sym, N), div) if db_partial is not None else None
+        )
         _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
+            RMSNormBackward(dtype, N),
+            x_cute,
+            weight_cute,
+            dout_cute,
+            dres_out_cute,
+            rstd_cute,
+            dx_cute,
+            dw_partial_cute,
+            dres_cute,
+            db_partial_cute,
             sm_count,
+            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+            options="--enable-tvm-ffi",
         )
     _rmsnorm_bwd.compile_cache[compile_key](
-        weight_tensor,
-        dout_tensor,
-        dres_out_tensor,
-        rstd_tensor,
-        dx_tensor,
-        dw_partial_tensor,
-        dres_tensor,
-        db_partial_tensor,
-        sm_count,
-        current_stream,
+        x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count
     )
 
 
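The rewritten launcher above caches compiled kernels under a key built from N, the participating dtypes, and feature flags, so compilation happens once per configuration rather than per call. A minimal sketch of that memoization pattern, with hypothetical _build/launch names (not quack APIs):

# Sketch only: one compiled callable per (N, dtype, weight-dtype) key, mirroring
# how _rmsnorm_bwd.compile_cache avoids repeated cute.compile calls.
from typing import Callable, Dict, Optional, Tuple
import torch

_cache: Dict[Tuple, Callable] = {}

def _build(key: Tuple) -> Callable:
    # Stand-in for the expensive compile step; runs once per distinct key.
    def kernel(x: torch.Tensor, w: Optional[torch.Tensor]) -> torch.Tensor:
        return x * w if w is not None else x
    return kernel

def launch(x: torch.Tensor, w: Optional[torch.Tensor]) -> torch.Tensor:
    # Key on everything that changes the generated code, never on tensor values.
    key = (x.shape[-1], x.dtype, None if w is None else w.dtype)
    if key not in _cache:
        _cache[key] = _build(key)
    return _cache[key](x, w)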
@@ -1100,30 +911,37 @@ _rmsnorm_bwd.compile_cache = {}
 
 def rmsnorm_bwd(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor],
     dout: Tensor,
     rstd: Tensor,
     dresidual_out: Optional[Tensor] = None,  # grad wrt residual_out
     has_bias: bool = False,
+    has_residual: bool = False,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
     device = x.device
     N = x.size(1)
-    sm_count = _get_sm_count(N, device)
     dx = torch.empty_like(x)
     if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
         dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
     else:
         dresidual = None
+    sm_count = _get_sm_count(N, device)
+    if weight is not None:
+        # Always store partial gradients in fp32 for numerical accuracy
+        dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+    else:
+        dw_partial = None
     db_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
+
+    _rmsnorm_bwd(
+        x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, sm_count
+    )
+
     # we have summed the partial gradients in fp32, now we convert back to the weight dtype
-    dw = dw_partial.sum(dim=0).to(weight.dtype)
+    dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None
     db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
     # dresidual is the same as dx in this case
-    if dresidual_out is not None and dresidual is None:
+    if has_residual and dresidual is None:
         dresidual = dx
     return dx, dw, db, dresidual
 
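rmsnorm_bwd above allocates one fp32 partial row of the weight gradient per SM and reduces with sum(dim=0) afterwards. The same split-reduction pattern can be emulated in plain PyTorch; the function below is an illustrative sketch, not a quack API:

# Sketch: accumulate the weight gradient in fp32 partials, then reduce,
# as rmsnorm_bwd does with dw_partial / db_partial.
import torch

def dw_via_partials(dout: torch.Tensor, xhat: torch.Tensor, sm_count: int) -> torch.Tensor:
    M, N = dout.shape
    dw_partial = torch.zeros(sm_count, N, dtype=torch.float32)
    for i, rows in enumerate(torch.arange(M).chunk(sm_count)):
        # Each slot accumulates its share of rows in fp32 for accuracy.
        dw_partial[i] = (dout[rows].float() * xhat[rows].float()).sum(dim=0)
    return dw_partial.sum(dim=0)  # caller casts back to the weight dtype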
@@ -1180,11 +998,16 @@ class RMSNormFunction(torch.autograd.Function):
         x_shape_og = ctx.x_shape_og
         # Reshape dout to match the flattened shape used in forward
         dout = dout.view(-1, dout.shape[-1])
+        dx, dw, db, dresidual = rmsnorm_bwd(
+            x,
+            weight,
+            dout,
+            rstd,
+            dresidual_out,
+            has_bias,
+            has_residual=ctx.residual_dtype is not None,
+        )
         dx = dx.view(x_shape_og)
-        if dresidual_out is not None:
-            dresidual_out = dresidual_out.reshape(x_shape_og)
         if dresidual is not None:
             dresidual = dresidual.reshape(x_shape_og)
 
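The backward above follows the usual torch.autograd.Function shape discipline: flatten to 2D, run the fused backward, reshape gradients to the original shape. A self-contained toy analogue (hypothetical ScaleRows, not part of quack):

# Sketch of the flatten/compute/reshape autograd pattern used by RMSNormFunction.
import torch

class ScaleRows(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        ctx.x_shape_og = x.shape
        x2d = x.reshape(-1, x.shape[-1])  # kernels see a flat (M, N) view
        ctx.save_for_backward(x2d, w)
        return (x2d * w).view(ctx.x_shape_og)

    @staticmethod
    def backward(ctx, dout: torch.Tensor):
        x2d, w = ctx.saved_tensors
        dout2d = dout.reshape(-1, dout.shape[-1])
        dx = (dout2d * w).view(ctx.x_shape_og)  # grad wrt x, original shape
        dw = (dout2d * x2d).sum(dim=0)          # grad wrt w, reduced over rows
        return dx, dw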
@@ -1193,7 +1016,7 @@ class RMSNormFunction(torch.autograd.Function):
 
 def rmsnorm(
     x: Tensor,
-    weight: Tensor,
+    weight: Optional[Tensor] = None,
     bias: Optional[Tensor] = None,
     residual: Optional[Tensor] = None,
     out_dtype: Optional[torch.dtype] = None,
@@ -1205,7 +1028,7 @@ def rmsnorm(
 
     Args:
         x: Input tensor of shape (M, N)
-        weight: Weight tensor of shape (N,)
+        weight: Optional weight tensor of shape (N,)
         eps: Small value for numerical stability
 
     Returns:
@@ -1214,7 +1037,7 @@ def rmsnorm(
     return RMSNormFunction.apply(x, weight, bias, residual, out_dtype, residual_dtype, eps, prenorm)
 
 
-class QuackRMSNorm(torch.nn.Module):
+class QuackRMSNorm(torch.nn.RMSNorm):
     """RMSNorm module that behaves like torch.nn.RMSNorm.
 
     This class provides a drop-in replacement for torch.nn.RMSNorm that uses
@@ -1229,10 +1052,10 @@ class QuackRMSNorm(torch.nn.Module):
         eps (float): A small constant for numerical stability
     """
 
-    def __init__(
+    def __init__(
+        self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True, device=None, dtype=None
+    ):
+        super().__init__(dim, eps, elementwise_affine, device=device, dtype=dtype)
 
     def forward(self, x: Tensor) -> Tensor:
         """Apply RMSNorm to the input tensor.
@@ -1245,6 +1068,67 @@ class QuackRMSNorm(torch.nn.Module):
         """
         return rmsnorm(x, self.weight, eps=self.eps)
 
+
+def layernorm_fwd(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    eps: float = 1e-6,
+    return_rstd: bool = False,
+    return_mean: bool = False,
+):
+    """LayerNorm forward pass using the unified RMSNorm/LayerNorm kernel.
+
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,). Must be float32.
+        bias: Optional bias tensor of shape (N,). Must be float32.
+        eps: Small value for numerical stability
+        return_rstd: Whether to return the reciprocal standard deviation
+        return_mean: Whether to return the mean
+
+    Returns:
+        Normalized output tensor of same shape as x
+        If return_rstd is True, also returns rstd tensor of shape (M,)
+        If return_mean is True, also returns mean tensor of shape (M,)
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert weight.dim() == 1, "Weight must be 1D"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert weight.dtype == torch.float32, "Weight must be float32"
+    if bias is not None:
+        assert bias.dim() == 1, "Bias must be 1D"
+        assert bias.dtype == torch.float32, "Bias must be float32"
+
+    M, N = x.shape
+    device = x.device
+    out = torch.empty_like(x)
+    rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
+    mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
+
+    _rmsnorm_fwd(x, weight, out, bias, rstd, mean, None, None, eps, True)
+
+    if return_rstd and return_mean:
+        return out, rstd, mean
+    elif return_rstd:
+        return out, rstd
+    elif return_mean:
+        return out, mean
+    return out
+
+
+def layernorm_ref(x: Tensor, w: Tensor, eps: float = 1e-6) -> Tensor:
+    """Reference implementation for LayerNorm."""
+    x_f32 = x.float()
+    return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
+
+
+def layernorm_rstd_ref(x: torch.Tensor, eps: float = 1e-6):
+    x_f32 = x.float()
+    mean = x_f32.mean(dim=-1, keepdim=True)
+    var = ((x_f32 - mean) ** 2).mean(dim=-1)
+    return 1.0 / torch.sqrt(var + eps)
+
+
+def layernorm_mean_ref(x: torch.Tensor) -> torch.Tensor:
+    return x.float().mean(dim=-1)