quack-kernels 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -1
- quack/cross_entropy.py +201 -167
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +212 -181
- quack/softmax.py +417 -156
- quack/utils.py +206 -45
- quack_kernels-0.1.4.dist-info/METADATA +11 -0
- quack_kernels-0.1.4.dist-info/RECORD +11 -0
- quack_kernels-0.1.2.dist-info/METADATA +0 -8
- quack_kernels-0.1.2.dist-info/RECORD +0 -10
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/top_level.txt +0 -0
quack/rmsnorm.py
CHANGED
@@ -1,9 +1,8 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
-import math
-import operator
 
 import torch
+from typing import Optional
 
 import cuda.bindings.driver as cuda
 
@@ -12,181 +11,203 @@ import cutlass.cute as cute
 from cutlass.cute.runtime import from_dlpack
 
 import quack.utils as utils
+from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+class RMSNorm(ReductionBase):
+    def __init__(self, dtype: cutlass.Numeric, N: int):
+        super().__init__(dtype, N, stage=1)
+        self.reload_from = None if N <= 16384 else "smem"
+        self.delay_w_load = False
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
 
+    def _set_cluster_n(self):
+        N = self.N
+        # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
+        # Similarly cluster_n = 8 is faster for N=128k
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 32 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mO: cute.Tensor,
+        mRstd: Optional[cute.Tensor],
+        stream: cuda.CUstream,
+        eps: cutlass.Float32 = 1e-6,
+    ):
+        assert mX.element_type == self.dtype
+        assert mO.element_type == self.dtype
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+        if cutlass.const_expr(mRstd is not None):
+            mRstd_expanded_layout = cute.append(
+                mRstd.layout, cute.make_layout((self.N,), stride=(0,))
+            )
+            mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+        self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
+            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
 
-@cute.kernel
-def rmsnorm_kernel(
-    [... 10 removed lines (old kernel parameters) not shown in the registry diff view ...]
-):
-    [... 23 removed lines (old kernel prologue) not shown in the registry diff view ...]
-    ]
-    gW = cute.local_tile(mW, tiler_mn, (0, 0 if cluster_n == 1 else cluster_y))
-
-    # declare the atoms which will be used later for memory copy
-    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128)
-    copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128)
-    copy_atom_load_W = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128)
-    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128)
-
-    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
-
-    tWgW = thr_copy_W.partition_S(gW)
-    tXgX = thr_copy_X.partition_S(gX)
-    tXsX = thr_copy_X.partition_D(sX)
-    tXgO, tXrRstd = [thr_copy_O.partition_D(gT) for gT in (gO, gRstd)]
-    tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
-    # allocate fragments for gmem->rmem
-    tWrW = cute.make_fragment_like(tWgW)
-    tXrW = thr_copy_X.retile(tWrW)
-    tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
-
-    if cluster_n > 1:
-        if tidx == 0:
-            cute.arch.mbarrier_init_arrive_cnt(mbar_ptr, 1)
-        cute.arch.mbarrier_init_fence()
-        if tidx == 0:
-            cute.arch.mbarrier_init_tx_bytes(mbar_ptr, num_warps * cluster_n * cutlass.Float32.width // 8)
-        # Cluster arrive after barrier init
-        cute.arch.cluster_arrive_relaxed()
-
-    tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
-    row = tXcX[0][0]
-    if row < shape[0]:
-        cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
-    cute.arch.cp_async_commit_group()
-
-    tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
-    if not delay_w_load:
-        cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
-
-    cute.arch.cp_async_wait_group(0)
-    cute.autovec_copy(tXsX, tXrX)
-    x = tXrX.load().to(cute.Float32)
-    threads_per_row = tv_layout.shape[0][0]
-    sum_sq_x = utils.row_reduce(
-        x * x,
-        cute.ReductionOp.ADD,
-        threads_per_row,
-        reduction_buffer,
-        mbar_ptr,
-        init_val=0.0,
-        hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
-    )
-    rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
-    # Only the thread corresponding to column 0 writes out the rstd to gmem
-    if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
-        tXrRstd[0] = rstd
-    if delay_w_load:
-        cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
-    if reload_from == "smem":
-        cute.autovec_copy(tXsX, tXrX)
-        x = tXrX.load().to(cute.Float32)
-    elif reload_from == "gmem":
-        cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
-        x = tXrX.load().to(cute.Float32)
-    x_hat = x * rstd
-    w = tXrW.load().to(cute.Float32)
-    y = x_hat * w
-    tXrO.store(y.to(tXrO.element_type))
-    tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
-    if row < shape[0]:
-        cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
-
-
-@cute.jit
-def rmsnorm_interface(
-    mX: cute.Tensor,
-    mW: cute.Tensor,
-    mO: cute.Tensor,
-    mRstd: cute.Tensor,
-    stream: cuda.CUstream,
-    N: cutlass.Constexpr,
-    eps: cutlass.Float32 = 1e-6,
-    copy_bits: cutlass.Constexpr = 128
-):
-    # new_shape = (mX_.shape[0], cute.assume(mX_.shape[1], 128))
-    # breakpoint()
-    # mX = cute.make_tensor(mX_.iterator, cute.make_layout(new_shape, stride=mX_.stride))
-    vecsize = copy_bits // mX.element_type.width
-    assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-    num_threads = 128 if N <= 16384 else 256
-    num_warps = num_threads // cute.arch.WARP_SIZE
-    assert num_threads % cute.arch.WARP_SIZE == 0
-    threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-    # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
-    # Similarly cluster_n = 8 is faster for N=128k
-    if cutlass.const_expr(mX.element_type.width == 16):
-        cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-    else:  # fp32
-        cluster_n = 1 if N <= 32 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)))
-
-    num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-    cols_per_block = num_threads // threads_per_row
-    tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)  # This rounds up N
-    tv_layout = cute.make_layout(
-        ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-        stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-    )
-
-    mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
-    mW_expanded = cute.make_tensor(mW.iterator, mW_expanded_layout)
-    mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((N,), stride=(0,)))
-    mRstd_expanded = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-
-    # reload_from = None if N <= 16384 else ("smem" if N <= 32768 else "gmem")
-    reload_from = None if N <= 16384 else "smem"
-    # delay_w_load = N > 64 * 1024
-    delay_w_load = False
-    N_rounded = tiler_mn[1]
-    rmsnorm_kernel(mX, mW_expanded, mO, mRstd_expanded, eps, tv_layout, tiler_mn, cluster_n, reload_from).launch(
-        grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
-        block=[cute.size(tv_layout, mode=[0]), 1, 1],
-        # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
-        cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-        smem=cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + num_warps * cluster_n * (cutlass.Float32.width // 8) + (cutlass.Int64.width // 8),
-        stream=stream,
-    )
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mO: cute.Tensor,
+        mRstd: Optional[cute.Tensor],
+        eps: cute.Float32,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+        reload_from: cutlass.Constexpr = None,
+        delay_w_load: cutlass.Constexpr = False,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        shape = mX.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+        gRstd = (
+            cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+            if cutlass.const_expr(mRstd is not None)
+            else None
+        )
 
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_X_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_W = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store_O = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
+        )
 
-[... 5 removed lines not shown in the registry diff view ...]
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
+            tidx
+        )
+        thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+        tWgW = thr_copy_W.partition_S(gW)
+        tXgX = thr_copy_X.partition_S(gX)
+        tXsX = thr_copy_X.partition_D(sX)
+        tXgO = thr_copy_O.partition_D(gO)
+        tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+        # allocate fragments for gmem->rmem
+        tWrW = cute.make_fragment_like(tWgW)
+        tXrW = thr_copy_X.retile(tWrW)
+        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+        row = tXcX[0][0]
+        if row < shape[0]:
+            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+
+        tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
+        if cutlass.const_expr(not delay_w_load):
+            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+
+        cute.arch.cp_async_wait_group(0)
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+        threads_per_row = tv_layout.shape[0][0]
+        sum_sq_x = utils.row_reduce(
+            x * x,
+            cute.ReductionOp.ADD,
+            threads_per_row,
+            reduction_buffer[None, None, 0],
+            mbar_ptr,
+            init_val=0.0,
+            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+        )
+        rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
+        if cutlass.const_expr(mRstd is not None):
+            # Only the thread corresponding to column 0 writes out the rstd to gmem
+            if (
+                tXcX[0][1] == 0
+                and row < shape[0]
+                and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+            ):
+                tXrRstd[0] = rstd
+        if cutlass.const_expr(delay_w_load):
+            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+        if cutlass.const_expr(reload_from == "smem"):
+            cute.autovec_copy(tXsX, tXrX)
+            x = tXrX.load().to(cute.Float32)
+        elif cutlass.const_expr(reload_from == "gmem"):
+            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+            x = tXrX.load().to(cute.Float32)
+        x_hat = x * rstd
+        w = tXrW.load().to(cute.Float32)
+        y = x_hat * w
+        tXrO.store(y.to(tXrO.element_type))
+        tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+        if row < shape[0]:
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
 
 
 def rmsnorm(
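The hunk above folds the per-N launch heuristics that previously lived in rmsnorm_interface into RMSNorm._calculate_threads_per_row and RMSNorm._set_cluster_n. As a quick illustration, here is a small standalone sketch (not part of the package) that reproduces those thresholds for the 16-bit path; the helper names are made up for this example, and only the thresholds are copied from the diff.

# Standalone sketch of the size heuristics added above (illustration only).
def threads_per_row(N: int) -> int:
    # Thresholds copied from RMSNorm._calculate_threads_per_row
    if N <= 64:
        return 8
    if N <= 128:
        return 16
    return 32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))

def cluster_n_fp16(N: int) -> int:
    # Thresholds copied from RMSNorm._set_cluster_n for a 16-bit dtype
    if N <= 16 * 1024:
        return 1
    if N <= 32 * 1024:
        return 2
    return 4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)

for n in (4096, 32 * 1024, 128 * 1024):
    print(n, threads_per_row(n), cluster_n_fp16(n))
# 4096 -> 64 threads per row, cluster_n 1; 32768 -> 256, 2; 131072 -> 256, 8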
@@ -216,24 +237,32 @@ def rmsnorm(
     M, N = x.shape
     device = x.device
     out = torch.empty_like(x)
-    rstd = torch.empty(M, device=device, dtype=torch.float32)
+    rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda x: (
-        from_dlpack(x.detach(), assumed_align=16)
-        [... removed line not shown in the registry diff view ...]
+        from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
     )
     x_tensor, out_tensor = [
         # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
         convert_from_dlpack(t)
         for t in (x, out)
     ]
-    weight_tensor = utils.convert_from_dlpack(
-        [... removed line not shown in the registry diff view ...]
+    weight_tensor = utils.convert_from_dlpack(
+        weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+    )
+    rstd_tensor = (
+        from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+        if rstd is not None
+        else None
+    )
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-    compile_key = (dtype, N)
+    compile_key = (dtype, N, rstd is not None)
     if compile_key not in rmsnorm.compile_cache:
+        rmsnorm_op = RMSNorm(dtype, N)
         rmsnorm.compile_cache[compile_key] = cute.compile(
-            [... removed line not shown in the registry diff view ...]
+            rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
         )
     rmsnorm.compile_cache[compile_key](
         x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
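The host wrapper above now keys its compile cache on (dtype, N, rstd is not None), so the variant that writes rstd and the one that skips it are compiled and cached separately. A minimal, self-contained sketch of that caching pattern follows; build_kernel is a made-up stand-in for the cute.compile(RMSNorm(dtype, N), ...) step and only illustrates why the flag belongs in the key.

# Sketch of the compile-cache pattern used in rmsnorm() above (illustration only).
compile_cache = {}

def build_kernel(dtype, N, want_rstd):
    # Placeholder for the expensive compile step; returns a trivial callable.
    return lambda *args: (dtype, N, want_rstd, args)

def launch(dtype, N, want_rstd, *args):
    key = (dtype, N, want_rstd)  # one specialization per dtype, row width, and rstd flag
    if key not in compile_cache:
        compile_cache[key] = build_kernel(dtype, N, want_rstd)
    return compile_cache[key](*args)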
@@ -246,7 +275,9 @@ rmsnorm.compile_cache = {}
 
 def rmsnorm_ref(x, w, eps=1e-6):
     x_f32 = x.float()
-    return (x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w).to(x.dtype)
+    return (x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w).to(
+        x.dtype
+    )
 
 
 def rstd_ref(x, eps=1e-6):
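For context, a hedged usage sketch (not taken from the diff) of how the updated rmsnorm host function might be checked against the rmsnorm_ref reference shown above. The exact rmsnorm signature and return value are not visible in this diff, so the keyword names and the single-tensor return below are assumptions based on the variables used in its body (x, weight, eps, return_rstd).

# Usage sketch under the assumptions stated above (requires a CUDA device).
import torch
from quack.rmsnorm import rmsnorm, rmsnorm_ref

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
w = torch.randn(8192, device="cuda", dtype=torch.float32)

out = rmsnorm(x, w, eps=1e-6)      # keyword name and plain-tensor return are assumed
ref = rmsnorm_ref(x, w, eps=1e-6)  # reference from the diff above
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)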