quack-kernels 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -1
- quack/cross_entropy.py +199 -168
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +208 -195
- quack/softmax.py +409 -163
- quack/utils.py +249 -35
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/METADATA +4 -1
- quack_kernels-0.1.3.dist-info/RECORD +11 -0
- quack_kernels-0.1.1.dist-info/RECORD +0 -10
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/top_level.txt +0 -0
quack/rmsnorm.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
|
-
import math
|
|
4
|
-
import operator
|
|
5
3
|
|
|
6
4
|
import torch
|
|
5
|
+
from typing import Optional
|
|
7
6
|
|
|
8
7
|
import cuda.bindings.driver as cuda
|
|
9
8
|
|
|
@@ -12,198 +11,202 @@ import cutlass.cute as cute
|
|
|
12
11
|
from cutlass.cute.runtime import from_dlpack
|
|
13
12
|
|
|
14
13
|
import quack.utils as utils
|
|
14
|
+
from quack.reduction_base import ReductionBase, torch2cute_dtype_map
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class RMSNorm(ReductionBase):
|
|
18
|
+
def __init__(self, dtype: cutlass.Numeric, N: int):
|
|
19
|
+
super().__init__(dtype, N, stage=1)
|
|
20
|
+
self.reload_from = None if N <= 16384 else "smem"
|
|
21
|
+
self.delay_w_load = False
|
|
22
|
+
|
|
23
|
+
def _calculate_threads_per_row(self):
|
|
24
|
+
N = self.N
|
|
25
|
+
return (
|
|
26
|
+
8
|
|
27
|
+
if N <= 64
|
|
28
|
+
else (
|
|
29
|
+
16
|
|
30
|
+
if N <= 128
|
|
31
|
+
else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
|
|
32
|
+
)
|
|
33
|
+
)
|
|
15
34
|
|
|
35
|
+
def _set_cluster_n(self):
|
|
36
|
+
N = self.N
|
|
37
|
+
# cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
|
|
38
|
+
# Similarly cluster_n = 8 is faster for N=128k
|
|
39
|
+
if cutlass.const_expr(self.dtype.width == 16):
|
|
40
|
+
cluster_n = (
|
|
41
|
+
1
|
|
42
|
+
if N <= 16 * 1024
|
|
43
|
+
else (
|
|
44
|
+
2
|
|
45
|
+
if N <= 32 * 1024
|
|
46
|
+
else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
|
|
47
|
+
)
|
|
48
|
+
)
|
|
49
|
+
else: # fp32
|
|
50
|
+
cluster_n = (
|
|
51
|
+
1
|
|
52
|
+
if N <= 32 * 1024
|
|
53
|
+
else (
|
|
54
|
+
2
|
|
55
|
+
if N <= 64 * 1024
|
|
56
|
+
else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
|
|
57
|
+
)
|
|
58
|
+
)
|
|
59
|
+
self.cluster_n = cluster_n
|
|
60
|
+
|
|
61
|
+
@cute.jit
|
|
62
|
+
def __call__(
|
|
63
|
+
self,
|
|
64
|
+
mX: cute.Tensor,
|
|
65
|
+
mW: cute.Tensor,
|
|
66
|
+
mO: cute.Tensor,
|
|
67
|
+
mRstd: Optional[cute.Tensor],
|
|
68
|
+
stream: cuda.CUstream,
|
|
69
|
+
eps: cutlass.Float32 = 1e-6,
|
|
70
|
+
):
|
|
71
|
+
assert mX.element_type == self.dtype
|
|
72
|
+
assert mO.element_type == self.dtype
|
|
73
|
+
self._set_cluster_n()
|
|
74
|
+
tiler_mn, tv_layout = self._get_tv_layout()
|
|
75
|
+
num_threads = cute.size(tv_layout, mode=[0])
|
|
76
|
+
num_warps = num_threads // cute.arch.WARP_SIZE
|
|
77
|
+
mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
|
|
78
|
+
mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
|
|
79
|
+
if cutlass.const_expr(mRstd is not None):
|
|
80
|
+
mRstd_expanded_layout = cute.append(
|
|
81
|
+
mRstd.layout, cute.make_layout((self.N,), stride=(0,))
|
|
82
|
+
)
|
|
83
|
+
mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
|
|
84
|
+
self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
|
|
85
|
+
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
86
|
+
block=[num_threads, 1, 1],
|
|
87
|
+
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
|
|
88
|
+
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
89
|
+
stream=stream,
|
|
90
|
+
)
|
|
16
91
|
|
|
17
|
-
@cute.kernel
|
|
18
|
-
def
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
)
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
smem = cutlass.utils.SmemAllocator()
|
|
53
|
-
# Don't use blkX.layout here, because the stride is N, not N_rounded
|
|
54
|
-
sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
|
|
55
|
-
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
56
|
-
warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
|
|
57
|
-
# reduction_buffer_layout = cute.make_ordered_layout((num_warps // warps_per_row, warps_per_row), order=(1, 0))
|
|
58
|
-
reduction_buffer_layout = cute.make_ordered_layout((num_warps // warps_per_row, warps_per_row if cluster_n == 1 else (warps_per_row, cluster_n)), order=(1, 0))
|
|
59
|
-
reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
|
|
60
|
-
if cutlass.const_expr(cluster_n > 1):
|
|
61
|
-
mbar_ptr = smem.allocate(cutlass.Int64.width // 8, byte_alignment=8)
|
|
62
|
-
else:
|
|
63
|
-
mbar_ptr = None
|
|
64
|
-
|
|
65
|
-
tWgW = thr_copy_W.partition_S(blkW)
|
|
66
|
-
tXgX = thr_copy_X_async.partition_S(blkX)
|
|
67
|
-
tXsX = thr_copy_X_async.partition_S(sX)
|
|
68
|
-
tXgO, tXrRstd = [thr_copy_O.partition_D(blk) for blk in (blkOut, blkRstd)]
|
|
69
|
-
tXcX = thr_copy_X.partition_S(blkCrd)[(0, None), None, None]
|
|
70
|
-
|
|
71
|
-
# allocate fragments for gmem->rmem
|
|
72
|
-
tWrW = cute.make_fragment_like(tWgW)
|
|
73
|
-
tXrW = thr_copy_X.retile(tWrW)
|
|
74
|
-
tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
|
|
75
|
-
|
|
76
|
-
if cluster_n > 1:
|
|
77
|
-
if tidx == 0:
|
|
78
|
-
cute.arch.mbarrier_init_arrive_cnt(mbar_ptr, 1)
|
|
79
|
-
cute.arch.mbarrier_init_fence()
|
|
80
|
-
if tidx == 0:
|
|
81
|
-
cute.arch.mbarrier_init_tx_bytes(mbar_ptr, num_warps * cluster_n * cutlass.Float32.width // 8)
|
|
82
|
-
# Cluster arrive after barrier init
|
|
83
|
-
cute.arch.cluster_arrive_relaxed()
|
|
84
|
-
|
|
85
|
-
tXpX = cute.make_fragment_like(tXgX[(0, None), None, None], cutlass.Boolean)
|
|
86
|
-
for i in range(cute.size(tXpX)):
|
|
87
|
-
tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
|
|
88
|
-
# tXrX.fill(0.0)
|
|
89
|
-
if tXcX[0][0] < shape[0]:
|
|
90
|
-
# cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
|
|
91
|
-
cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
|
|
92
|
-
cute.arch.cp_async_commit_group()
|
|
92
|
+
@cute.kernel
|
|
93
|
+
def kernel(
|
|
94
|
+
self,
|
|
95
|
+
mX: cute.Tensor,
|
|
96
|
+
mW: cute.Tensor,
|
|
97
|
+
mO: cute.Tensor,
|
|
98
|
+
mRstd: Optional[cute.Tensor],
|
|
99
|
+
eps: cute.Float32,
|
|
100
|
+
tv_layout: cute.Layout,
|
|
101
|
+
tiler_mn: cute.Shape,
|
|
102
|
+
reload_from: cutlass.Constexpr = None,
|
|
103
|
+
delay_w_load: cutlass.Constexpr = False,
|
|
104
|
+
):
|
|
105
|
+
tidx, _, _ = cute.arch.thread_idx()
|
|
106
|
+
bidx, cluster_y, _ = cute.arch.block_idx()
|
|
107
|
+
|
|
108
|
+
smem = cutlass.utils.SmemAllocator()
|
|
109
|
+
sX = smem.allocate_tensor(
|
|
110
|
+
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
|
|
111
|
+
)
|
|
112
|
+
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
|
113
|
+
|
|
114
|
+
shape = mX.shape
|
|
115
|
+
idX = cute.make_identity_tensor(shape)
|
|
116
|
+
# slice for CTAs
|
|
117
|
+
gX, gO, cX = [
|
|
118
|
+
cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
|
|
119
|
+
for mT in (mX, mO, idX)
|
|
120
|
+
]
|
|
121
|
+
gW = cute.local_tile(mW, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
|
|
122
|
+
gRstd = (
|
|
123
|
+
cute.local_tile(mRstd, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
|
|
124
|
+
if cutlass.const_expr(mRstd is not None)
|
|
125
|
+
else None
|
|
126
|
+
)
|
|
93
127
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
128
|
+
# declare the atoms which will be used later for memory copy
|
|
129
|
+
copy_atom_load_X = cute.make_copy_atom(
|
|
130
|
+
cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
|
|
131
|
+
)
|
|
132
|
+
copy_atom_load_X_async = cute.make_copy_atom(
|
|
133
|
+
cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
|
|
134
|
+
)
|
|
135
|
+
copy_atom_load_W = cute.make_copy_atom(
|
|
136
|
+
cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
|
|
137
|
+
)
|
|
138
|
+
copy_atom_store_O = cute.make_copy_atom(
|
|
139
|
+
cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
|
|
140
|
+
)
|
|
100
141
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
x = tXrX.load().to(cute.Float32)
|
|
104
|
-
sum_sq_x = utils.warp_reduce(
|
|
105
|
-
(x * x).reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
|
|
106
|
-
operator.add,
|
|
107
|
-
width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
|
|
108
|
-
)
|
|
109
|
-
if cutlass.const_expr(cluster_n > 1):
|
|
110
|
-
cute.arch.cluster_wait()
|
|
111
|
-
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
|
|
112
|
-
sum_sq_x = utils.block_or_cluster_reduce(
|
|
113
|
-
sum_sq_x, operator.add, reduction_buffer, mbar_ptr, init_val=0.0
|
|
142
|
+
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
|
|
143
|
+
tidx
|
|
114
144
|
)
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
145
|
+
thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
|
|
146
|
+
thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
|
|
147
|
+
|
|
148
|
+
tWgW = thr_copy_W.partition_S(gW)
|
|
149
|
+
tXgX = thr_copy_X.partition_S(gX)
|
|
150
|
+
tXsX = thr_copy_X.partition_D(sX)
|
|
151
|
+
tXgO = thr_copy_O.partition_D(gO)
|
|
152
|
+
tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
|
|
153
|
+
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
154
|
+
|
|
155
|
+
# allocate fragments for gmem->rmem
|
|
156
|
+
tWrW = cute.make_fragment_like(tWgW)
|
|
157
|
+
tXrW = thr_copy_X.retile(tWrW)
|
|
158
|
+
tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
|
|
159
|
+
|
|
160
|
+
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
161
|
+
self._initialize_cluster(tidx, mbar_ptr, num_warps)
|
|
162
|
+
|
|
163
|
+
tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
164
|
+
row = tXcX[0][0]
|
|
165
|
+
if row < shape[0]:
|
|
166
|
+
cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
|
|
167
|
+
cute.arch.cp_async_commit_group()
|
|
168
|
+
|
|
169
|
+
tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
|
|
170
|
+
if not delay_w_load:
|
|
171
|
+
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
|
|
172
|
+
|
|
173
|
+
cute.arch.cp_async_wait_group(0)
|
|
126
174
|
cute.autovec_copy(tXsX, tXrX)
|
|
127
175
|
x = tXrX.load().to(cute.Float32)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
)
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
assert num_threads % cute.arch.WARP_SIZE == 0
|
|
163
|
-
threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
|
|
164
|
-
# cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
|
|
165
|
-
# Similarly cluster_n = 8 is faster for N=128k
|
|
166
|
-
if cutlass.const_expr(mX.element_type.width == 16):
|
|
167
|
-
cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
|
|
168
|
-
else: # fp32
|
|
169
|
-
cluster_n = 1 if N <= 32 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)))
|
|
170
|
-
|
|
171
|
-
num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
|
|
172
|
-
cols_per_block = num_threads // threads_per_row
|
|
173
|
-
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) # This rounds up N
|
|
174
|
-
tv_layout = cute.make_layout(
|
|
175
|
-
((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
|
|
176
|
-
stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
|
|
177
|
-
)
|
|
178
|
-
|
|
179
|
-
mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
|
|
180
|
-
mW_expanded = cute.make_tensor(mW.iterator, mW_expanded_layout)
|
|
181
|
-
mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((N,), stride=(0,)))
|
|
182
|
-
mRstd_expanded = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
|
|
183
|
-
idX = cute.make_identity_tensor(mX.shape)
|
|
184
|
-
gX, gW, gO, gRstd, cX = [cute.zipped_divide(mT, tiler_mn) for mT in (mX, mW_expanded, mOut, mRstd_expanded, idX)] # ((TileM,TileN),(RestM,RestN))
|
|
185
|
-
|
|
186
|
-
# reload_from = None if N <= 16384 else ("smem" if N <= 32768 else "gmem")
|
|
187
|
-
reload_from = None if N <= 16384 else "smem"
|
|
188
|
-
# delay_w_load = N > 64 * 1024
|
|
189
|
-
delay_w_load = False
|
|
190
|
-
N_rounded = tiler_mn[1]
|
|
191
|
-
rmsnorm_kernel(gX, gW, gO, gRstd, cX, eps, mX.shape, tv_layout, tiler_mn, cluster_n, reload_from).launch(
|
|
192
|
-
grid=[cute.size(gX, mode=[1, 0]), cluster_n, 1],
|
|
193
|
-
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
|
194
|
-
# Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
|
|
195
|
-
cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
|
|
196
|
-
# We don't want to use gX.layout[0] here since that has stride in N, not N_rounded, leading IMA on smem
|
|
197
|
-
smem=cute.size_in_bytes(mX.element_type, cute.make_layout(gX.shape[0])) + num_warps * cluster_n * (cutlass.Float32.width // 8) + (cutlass.Int64.width // 8),
|
|
198
|
-
stream=stream,
|
|
199
|
-
)
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
torch2cute_dtype_map = {
|
|
203
|
-
torch.float16: cutlass.Float16,
|
|
204
|
-
torch.bfloat16: cutlass.BFloat16,
|
|
205
|
-
torch.float32: cutlass.Float32,
|
|
206
|
-
}
|
|
176
|
+
threads_per_row = tv_layout.shape[0][0]
|
|
177
|
+
sum_sq_x = utils.row_reduce(
|
|
178
|
+
x * x,
|
|
179
|
+
cute.ReductionOp.ADD,
|
|
180
|
+
threads_per_row,
|
|
181
|
+
reduction_buffer[None, None, 0],
|
|
182
|
+
mbar_ptr,
|
|
183
|
+
init_val=0.0,
|
|
184
|
+
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
185
|
+
)
|
|
186
|
+
rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
|
|
187
|
+
if cutlass.const_expr(mRstd is not None):
|
|
188
|
+
# Only the thread corresponding to column 0 writes out the rstd to gmem
|
|
189
|
+
if (
|
|
190
|
+
tXcX[0][1] == 0
|
|
191
|
+
and row < shape[0]
|
|
192
|
+
and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
|
|
193
|
+
):
|
|
194
|
+
tXrRstd[0] = rstd
|
|
195
|
+
if delay_w_load:
|
|
196
|
+
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
|
|
197
|
+
if reload_from == "smem":
|
|
198
|
+
cute.autovec_copy(tXsX, tXrX)
|
|
199
|
+
x = tXrX.load().to(cute.Float32)
|
|
200
|
+
elif reload_from == "gmem":
|
|
201
|
+
cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
|
|
202
|
+
x = tXrX.load().to(cute.Float32)
|
|
203
|
+
x_hat = x * rstd
|
|
204
|
+
w = tXrW.load().to(cute.Float32)
|
|
205
|
+
y = x_hat * w
|
|
206
|
+
tXrO.store(y.to(tXrO.element_type))
|
|
207
|
+
tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
|
|
208
|
+
if row < shape[0]:
|
|
209
|
+
cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
|
|
207
210
|
|
|
208
211
|
|
|
209
212
|
def rmsnorm(
|
|
@@ -233,24 +236,32 @@ def rmsnorm(
|
|
|
233
236
|
M, N = x.shape
|
|
234
237
|
device = x.device
|
|
235
238
|
out = torch.empty_like(x)
|
|
236
|
-
rstd = torch.empty(M, device=device, dtype=torch.float32)
|
|
239
|
+
rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
|
|
237
240
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
238
241
|
convert_from_dlpack = lambda x: (
|
|
239
|
-
from_dlpack(x.detach(), assumed_align=16)
|
|
240
|
-
|
|
242
|
+
from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
243
|
+
mode=0, stride_order=(0, 1)
|
|
244
|
+
)
|
|
241
245
|
)
|
|
242
246
|
x_tensor, out_tensor = [
|
|
243
247
|
# utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
|
|
244
248
|
convert_from_dlpack(t)
|
|
245
249
|
for t in (x, out)
|
|
246
250
|
]
|
|
247
|
-
weight_tensor = utils.convert_from_dlpack(
|
|
248
|
-
|
|
251
|
+
weight_tensor = utils.convert_from_dlpack(
|
|
252
|
+
weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
|
|
253
|
+
)
|
|
254
|
+
rstd_tensor = (
|
|
255
|
+
from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
256
|
+
if rstd is not None
|
|
257
|
+
else None
|
|
258
|
+
)
|
|
249
259
|
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
250
|
-
compile_key = (dtype, N)
|
|
260
|
+
compile_key = (dtype, N, rstd is not None)
|
|
251
261
|
if compile_key not in rmsnorm.compile_cache:
|
|
262
|
+
rmsnorm_op = RMSNorm(dtype, N)
|
|
252
263
|
rmsnorm.compile_cache[compile_key] = cute.compile(
|
|
253
|
-
|
|
264
|
+
rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
|
|
254
265
|
)
|
|
255
266
|
rmsnorm.compile_cache[compile_key](
|
|
256
267
|
x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
|
|
@@ -263,7 +274,9 @@ rmsnorm.compile_cache = {}
|
|
|
263
274
|
|
|
264
275
|
def rmsnorm_ref(x, w, eps=1e-6):
|
|
265
276
|
x_f32 = x.float()
|
|
266
|
-
return (x_f32 / (torch.sqrt(torch.mean(x_f32
|
|
277
|
+
return (x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w).to(
|
|
278
|
+
x.dtype
|
|
279
|
+
)
|
|
267
280
|
|
|
268
281
|
|
|
269
282
|
def rstd_ref(x, eps=1e-6):
|