quack-kernels 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -1
- quack/cross_entropy.py +197 -166
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +211 -181
- quack/softmax.py +409 -156
- quack/utils.py +197 -39
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/METADATA +4 -1
- quack_kernels-0.1.3.dist-info/RECORD +11 -0
- quack_kernels-0.1.2.dist-info/RECORD +0 -10
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/top_level.txt +0 -0
quack/rmsnorm.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
|
-
import math
|
|
4
|
-
import operator
|
|
5
3
|
|
|
6
4
|
import torch
|
|
5
|
+
from typing import Optional
|
|
7
6
|
|
|
8
7
|
import cuda.bindings.driver as cuda
|
|
9
8
|
|
|
@@ -12,181 +11,202 @@ import cutlass.cute as cute
|
|
|
12
11
|
from cutlass.cute.runtime import from_dlpack
|
|
13
12
|
|
|
14
13
|
import quack.utils as utils
|
|
14
|
+
from quack.reduction_base import ReductionBase, torch2cute_dtype_map
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RMSNorm(ReductionBase):
|
|
18
|
+
def __init__(self, dtype: cutlass.Numeric, N: int):
|
|
19
|
+
super().__init__(dtype, N, stage=1)
|
|
20
|
+
self.reload_from = None if N <= 16384 else "smem"
|
|
21
|
+
self.delay_w_load = False
|
|
22
|
+
|
|
23
|
+
def _calculate_threads_per_row(self):
|
|
24
|
+
N = self.N
|
|
25
|
+
return (
|
|
26
|
+
8
|
|
27
|
+
if N <= 64
|
|
28
|
+
else (
|
|
29
|
+
16
|
|
30
|
+
if N <= 128
|
|
31
|
+
else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
|
|
32
|
+
)
|
|
33
|
+
)
|
|
15
34
|
|
|
35
|
+
def _set_cluster_n(self):
|
|
36
|
+
N = self.N
|
|
37
|
+
# cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
|
|
38
|
+
# Similarly cluster_n = 8 is faster for N=128k
|
|
39
|
+
if cutlass.const_expr(self.dtype.width == 16):
|
|
40
|
+
cluster_n = (
|
|
41
|
+
1
|
|
42
|
+
if N <= 16 * 1024
|
|
43
|
+
else (
|
|
44
|
+
2
|
|
45
|
+
if N <= 32 * 1024
|
|
46
|
+
else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
|
|
47
|
+
)
|
|
48
|
+
)
|
|
49
|
+
else: # fp32
|
|
50
|
+
cluster_n = (
|
|
51
|
+
1
|
|
52
|
+
if N <= 32 * 1024
|
|
53
|
+
else (
|
|
54
|
+
2
|
|
55
|
+
if N <= 64 * 1024
|
|
56
|
+
else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
|
|
57
|
+
)
|
|
58
|
+
)
|
|
59
|
+
self.cluster_n = cluster_n
|
|
60
|
+
|
|
61
|
+
@cute.jit
|
|
62
|
+
def __call__(
|
|
63
|
+
self,
|
|
64
|
+
mX: cute.Tensor,
|
|
65
|
+
mW: cute.Tensor,
|
|
66
|
+
mO: cute.Tensor,
|
|
67
|
+
mRstd: Optional[cute.Tensor],
|
|
68
|
+
stream: cuda.CUstream,
|
|
69
|
+
eps: cutlass.Float32 = 1e-6,
|
|
70
|
+
):
|
|
71
|
+
assert mX.element_type == self.dtype
|
|
72
|
+
assert mO.element_type == self.dtype
|
|
73
|
+
self._set_cluster_n()
|
|
74
|
+
tiler_mn, tv_layout = self._get_tv_layout()
|
|
75
|
+
num_threads = cute.size(tv_layout, mode=[0])
|
|
76
|
+
num_warps = num_threads // cute.arch.WARP_SIZE
|
|
77
|
+
mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
|
|
78
|
+
mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
|
|
79
|
+
if cutlass.const_expr(mRstd is not None):
|
|
80
|
+
mRstd_expanded_layout = cute.append(
|
|
81
|
+
mRstd.layout, cute.make_layout((self.N,), stride=(0,))
|
|
82
|
+
)
|
|
83
|
+
mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
|
|
84
|
+
self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
|
|
85
|
+
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
86
|
+
block=[num_threads, 1, 1],
|
|
87
|
+
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
|
|
88
|
+
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
89
|
+
stream=stream,
|
|
90
|
+
)
|
|
16
91
|
|
|
17
|
-
@cute.kernel
|
|
18
|
-
def
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
):
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
for mT in (mX, mO, mRstd, idX)
|
|
53
|
-
]
|
|
54
|
-
gW = cute.local_tile(mW, tiler_mn, (0, 0 if cluster_n == 1 else cluster_y))
|
|
55
|
-
|
|
56
|
-
# declare the atoms which will be used later for memory copy
|
|
57
|
-
copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128)
|
|
58
|
-
copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128)
|
|
59
|
-
copy_atom_load_W = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128)
|
|
60
|
-
copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128)
|
|
61
|
-
|
|
62
|
-
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
|
|
63
|
-
thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
|
|
64
|
-
thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
|
|
65
|
-
|
|
66
|
-
tWgW = thr_copy_W.partition_S(gW)
|
|
67
|
-
tXgX = thr_copy_X.partition_S(gX)
|
|
68
|
-
tXsX = thr_copy_X.partition_D(sX)
|
|
69
|
-
tXgO, tXrRstd = [thr_copy_O.partition_D(gT) for gT in (gO, gRstd)]
|
|
70
|
-
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
71
|
-
|
|
72
|
-
# allocate fragments for gmem->rmem
|
|
73
|
-
tWrW = cute.make_fragment_like(tWgW)
|
|
74
|
-
tXrW = thr_copy_X.retile(tWrW)
|
|
75
|
-
tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
|
|
76
|
-
|
|
77
|
-
if cluster_n > 1:
|
|
78
|
-
if tidx == 0:
|
|
79
|
-
cute.arch.mbarrier_init_arrive_cnt(mbar_ptr, 1)
|
|
80
|
-
cute.arch.mbarrier_init_fence()
|
|
81
|
-
if tidx == 0:
|
|
82
|
-
cute.arch.mbarrier_init_tx_bytes(mbar_ptr, num_warps * cluster_n * cutlass.Float32.width // 8)
|
|
83
|
-
# Cluster arrive after barrier init
|
|
84
|
-
cute.arch.cluster_arrive_relaxed()
|
|
85
|
-
|
|
86
|
-
tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
87
|
-
row = tXcX[0][0]
|
|
88
|
-
if row < shape[0]:
|
|
89
|
-
cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
|
|
90
|
-
cute.arch.cp_async_commit_group()
|
|
91
|
-
|
|
92
|
-
tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
|
|
93
|
-
if not delay_w_load:
|
|
94
|
-
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
|
|
95
|
-
|
|
96
|
-
cute.arch.cp_async_wait_group(0)
|
|
97
|
-
cute.autovec_copy(tXsX, tXrX)
|
|
98
|
-
x = tXrX.load().to(cute.Float32)
|
|
99
|
-
threads_per_row = tv_layout.shape[0][0]
|
|
100
|
-
sum_sq_x = utils.row_reduce(
|
|
101
|
-
x * x,
|
|
102
|
-
cute.ReductionOp.ADD,
|
|
103
|
-
threads_per_row,
|
|
104
|
-
reduction_buffer,
|
|
105
|
-
mbar_ptr,
|
|
106
|
-
init_val=0.0,
|
|
107
|
-
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
|
|
108
|
-
)
|
|
109
|
-
rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
|
|
110
|
-
# Only the thread corresponding to column 0 writes out the rstd to gmem
|
|
111
|
-
if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
|
|
112
|
-
tXrRstd[0] = rstd
|
|
113
|
-
if delay_w_load:
|
|
114
|
-
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
|
|
115
|
-
if reload_from == "smem":
|
|
116
|
-
cute.autovec_copy(tXsX, tXrX)
|
|
117
|
-
x = tXrX.load().to(cute.Float32)
|
|
118
|
-
elif reload_from == "gmem":
|
|
119
|
-
cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
|
|
120
|
-
x = tXrX.load().to(cute.Float32)
|
|
121
|
-
x_hat = x * rstd
|
|
122
|
-
w = tXrW.load().to(cute.Float32)
|
|
123
|
-
y = x_hat * w
|
|
124
|
-
tXrO.store(y.to(tXrO.element_type))
|
|
125
|
-
tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
|
|
126
|
-
if row < shape[0]:
|
|
127
|
-
cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
@cute.jit
|
|
131
|
-
def rmsnorm_interface(
|
|
132
|
-
mX: cute.Tensor,
|
|
133
|
-
mW: cute.Tensor,
|
|
134
|
-
mO: cute.Tensor,
|
|
135
|
-
mRstd: cute.Tensor,
|
|
136
|
-
stream: cuda.CUstream,
|
|
137
|
-
N: cutlass.Constexpr,
|
|
138
|
-
eps: cutlass.Float32 = 1e-6,
|
|
139
|
-
copy_bits: cutlass.Constexpr = 128
|
|
140
|
-
):
|
|
141
|
-
# new_shape = (mX_.shape[0], cute.assume(mX_.shape[1], 128))
|
|
142
|
-
# breakpoint()
|
|
143
|
-
# mX = cute.make_tensor(mX_.iterator, cute.make_layout(new_shape, stride=mX_.stride))
|
|
144
|
-
vecsize = copy_bits // mX.element_type.width
|
|
145
|
-
assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
|
|
146
|
-
num_threads = 128 if N <= 16384 else 256
|
|
147
|
-
num_warps = num_threads // cute.arch.WARP_SIZE
|
|
148
|
-
assert num_threads % cute.arch.WARP_SIZE == 0
|
|
149
|
-
threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
|
|
150
|
-
# cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
|
|
151
|
-
# Similarly cluster_n = 8 is faster for N=128k
|
|
152
|
-
if cutlass.const_expr(mX.element_type.width == 16):
|
|
153
|
-
cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
|
|
154
|
-
else: # fp32
|
|
155
|
-
cluster_n = 1 if N <= 32 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)))
|
|
156
|
-
|
|
157
|
-
num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
|
|
158
|
-
cols_per_block = num_threads // threads_per_row
|
|
159
|
-
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) # This rounds up N
|
|
160
|
-
tv_layout = cute.make_layout(
|
|
161
|
-
((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
|
|
162
|
-
stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
|
|
163
|
-
)
|
|
164
|
-
|
|
165
|
-
mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
|
|
166
|
-
mW_expanded = cute.make_tensor(mW.iterator, mW_expanded_layout)
|
|
167
|
-
mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((N,), stride=(0,)))
|
|
168
|
-
mRstd_expanded = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
|
|
169
|
-
|
|
170
|
-
# reload_from = None if N <= 16384 else ("smem" if N <= 32768 else "gmem")
|
|
171
|
-
reload_from = None if N <= 16384 else "smem"
|
|
172
|
-
# delay_w_load = N > 64 * 1024
|
|
173
|
-
delay_w_load = False
|
|
174
|
-
N_rounded = tiler_mn[1]
|
|
175
|
-
rmsnorm_kernel(mX, mW_expanded, mO, mRstd_expanded, eps, tv_layout, tiler_mn, cluster_n, reload_from).launch(
|
|
176
|
-
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
|
|
177
|
-
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
|
178
|
-
# Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
|
|
179
|
-
cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
|
|
180
|
-
smem=cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + num_warps * cluster_n * (cutlass.Float32.width // 8) + (cutlass.Int64.width // 8),
|
|
181
|
-
stream=stream,
|
|
182
|
-
)
|
|
92
|
+
@cute.kernel
|
|
93
|
+
def kernel(
|
|
94
|
+
self,
|
|
95
|
+
mX: cute.Tensor,
|
|
96
|
+
mW: cute.Tensor,
|
|
97
|
+
mO: cute.Tensor,
|
|
98
|
+
mRstd: Optional[cute.Tensor],
|
|
99
|
+
eps: cute.Float32,
|
|
100
|
+
tv_layout: cute.Layout,
|
|
101
|
+
tiler_mn: cute.Shape,
|
|
102
|
+
reload_from: cutlass.Constexpr = None,
|
|
103
|
+
delay_w_load: cutlass.Constexpr = False,
|
|
104
|
+
):
|
|
105
|
+
tidx, _, _ = cute.arch.thread_idx()
|
|
106
|
+
bidx, cluster_y, _ = cute.arch.block_idx()
|
|
107
|
+
|
|
108
|
+
smem = cutlass.utils.SmemAllocator()
|
|
109
|
+
sX = smem.allocate_tensor(
|
|
110
|
+
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
|
|
111
|
+
)
|
|
112
|
+
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
|
113
|
+
|
|
114
|
+
shape = mX.shape
|
|
115
|
+
idX = cute.make_identity_tensor(shape)
|
|
116
|
+
# slice for CTAs
|
|
117
|
+
gX, gO, cX = [
|
|
118
|
+
cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
|
|
119
|
+
for mT in (mX, mO, idX)
|
|
120
|
+
]
|
|
121
|
+
gW = cute.local_tile(mW, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
|
|
122
|
+
gRstd = (
|
|
123
|
+
cute.local_tile(mRstd, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
|
|
124
|
+
if cutlass.const_expr(mRstd is not None)
|
|
125
|
+
else None
|
|
126
|
+
)
|
|
183
127
|
|
|
128
|
+
# declare the atoms which will be used later for memory copy
|
|
129
|
+
copy_atom_load_X = cute.make_copy_atom(
|
|
130
|
+
cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
|
|
131
|
+
)
|
|
132
|
+
copy_atom_load_X_async = cute.make_copy_atom(
|
|
133
|
+
cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
|
|
134
|
+
)
|
|
135
|
+
copy_atom_load_W = cute.make_copy_atom(
|
|
136
|
+
cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
|
|
137
|
+
)
|
|
138
|
+
copy_atom_store_O = cute.make_copy_atom(
|
|
139
|
+
cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
|
|
140
|
+
)
|
|
184
141
|
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
142
|
+
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
|
|
143
|
+
tidx
|
|
144
|
+
)
|
|
145
|
+
thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
|
|
146
|
+
thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
|
|
147
|
+
|
|
148
|
+
tWgW = thr_copy_W.partition_S(gW)
|
|
149
|
+
tXgX = thr_copy_X.partition_S(gX)
|
|
150
|
+
tXsX = thr_copy_X.partition_D(sX)
|
|
151
|
+
tXgO = thr_copy_O.partition_D(gO)
|
|
152
|
+
tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
|
|
153
|
+
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
154
|
+
|
|
155
|
+
# allocate fragments for gmem->rmem
|
|
156
|
+
tWrW = cute.make_fragment_like(tWgW)
|
|
157
|
+
tXrW = thr_copy_X.retile(tWrW)
|
|
158
|
+
tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
|
|
159
|
+
|
|
160
|
+
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
161
|
+
self._initialize_cluster(tidx, mbar_ptr, num_warps)
|
|
162
|
+
|
|
163
|
+
tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
164
|
+
row = tXcX[0][0]
|
|
165
|
+
if row < shape[0]:
|
|
166
|
+
cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
|
|
167
|
+
cute.arch.cp_async_commit_group()
|
|
168
|
+
|
|
169
|
+
tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
|
|
170
|
+
if not delay_w_load:
|
|
171
|
+
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
|
|
172
|
+
|
|
173
|
+
cute.arch.cp_async_wait_group(0)
|
|
174
|
+
cute.autovec_copy(tXsX, tXrX)
|
|
175
|
+
x = tXrX.load().to(cute.Float32)
|
|
176
|
+
threads_per_row = tv_layout.shape[0][0]
|
|
177
|
+
sum_sq_x = utils.row_reduce(
|
|
178
|
+
x * x,
|
|
179
|
+
cute.ReductionOp.ADD,
|
|
180
|
+
threads_per_row,
|
|
181
|
+
reduction_buffer[None, None, 0],
|
|
182
|
+
mbar_ptr,
|
|
183
|
+
init_val=0.0,
|
|
184
|
+
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
185
|
+
)
|
|
186
|
+
rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
|
|
187
|
+
if cutlass.const_expr(mRstd is not None):
|
|
188
|
+
# Only the thread corresponding to column 0 writes out the rstd to gmem
|
|
189
|
+
if (
|
|
190
|
+
tXcX[0][1] == 0
|
|
191
|
+
and row < shape[0]
|
|
192
|
+
and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
|
|
193
|
+
):
|
|
194
|
+
tXrRstd[0] = rstd
|
|
195
|
+
if delay_w_load:
|
|
196
|
+
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
|
|
197
|
+
if reload_from == "smem":
|
|
198
|
+
cute.autovec_copy(tXsX, tXrX)
|
|
199
|
+
x = tXrX.load().to(cute.Float32)
|
|
200
|
+
elif reload_from == "gmem":
|
|
201
|
+
cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
|
|
202
|
+
x = tXrX.load().to(cute.Float32)
|
|
203
|
+
x_hat = x * rstd
|
|
204
|
+
w = tXrW.load().to(cute.Float32)
|
|
205
|
+
y = x_hat * w
|
|
206
|
+
tXrO.store(y.to(tXrO.element_type))
|
|
207
|
+
tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
|
|
208
|
+
if row < shape[0]:
|
|
209
|
+
cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
|
|
190
210
|
|
|
191
211
|
|
|
192
212
|
def rmsnorm(
|
|
@@ -216,24 +236,32 @@ def rmsnorm(
|
|
|
216
236
|
M, N = x.shape
|
|
217
237
|
device = x.device
|
|
218
238
|
out = torch.empty_like(x)
|
|
219
|
-
rstd = torch.empty(M, device=device, dtype=torch.float32)
|
|
239
|
+
rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
|
|
220
240
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
221
241
|
convert_from_dlpack = lambda x: (
|
|
222
|
-
from_dlpack(x.detach(), assumed_align=16)
|
|
223
|
-
|
|
242
|
+
from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
243
|
+
mode=0, stride_order=(0, 1)
|
|
244
|
+
)
|
|
224
245
|
)
|
|
225
246
|
x_tensor, out_tensor = [
|
|
226
247
|
# utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
|
|
227
248
|
convert_from_dlpack(t)
|
|
228
249
|
for t in (x, out)
|
|
229
250
|
]
|
|
230
|
-
weight_tensor = utils.convert_from_dlpack(
|
|
231
|
-
|
|
251
|
+
weight_tensor = utils.convert_from_dlpack(
|
|
252
|
+
weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
|
|
253
|
+
)
|
|
254
|
+
rstd_tensor = (
|
|
255
|
+
from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
256
|
+
if rstd is not None
|
|
257
|
+
else None
|
|
258
|
+
)
|
|
232
259
|
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
233
|
-
compile_key = (dtype, N)
|
|
260
|
+
compile_key = (dtype, N, rstd is not None)
|
|
234
261
|
if compile_key not in rmsnorm.compile_cache:
|
|
262
|
+
rmsnorm_op = RMSNorm(dtype, N)
|
|
235
263
|
rmsnorm.compile_cache[compile_key] = cute.compile(
|
|
236
|
-
|
|
264
|
+
rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
|
|
237
265
|
)
|
|
238
266
|
rmsnorm.compile_cache[compile_key](
|
|
239
267
|
x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
|
|
@@ -246,7 +274,9 @@ rmsnorm.compile_cache = {}
|
|
|
246
274
|
|
|
247
275
|
def rmsnorm_ref(x, w, eps=1e-6):
|
|
248
276
|
x_f32 = x.float()
|
|
249
|
-
return (x_f32 / (torch.sqrt(torch.mean(x_f32
|
|
277
|
+
return (x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w).to(
|
|
278
|
+
x.dtype
|
|
279
|
+
)
|
|
250
280
|
|
|
251
281
|
|
|
252
282
|
def rstd_ref(x, eps=1e-6):
|