quack-kernels 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/rmsnorm.py CHANGED
@@ -1,9 +1,8 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

- import math
- import operator

  import torch
+ from typing import Optional

  import cuda.bindings.driver as cuda
 
@@ -12,181 +11,203 @@ import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack

  import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+ class RMSNorm(ReductionBase):
+     def __init__(self, dtype: cutlass.Numeric, N: int):
+         super().__init__(dtype, N, stage=1)
+         self.reload_from = None if N <= 16384 else "smem"
+         self.delay_w_load = False
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )

+     def _set_cluster_n(self):
+         N = self.N
+         # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
+         # Similarly cluster_n = 8 is faster for N=128k
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else:  # fp32
+             cluster_n = (
+                 1
+                 if N <= 32 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mW: cute.Tensor,
+         mO: cute.Tensor,
+         mRstd: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+         eps: cutlass.Float32 = 1e-6,
+     ):
+         assert mX.element_type == self.dtype
+         assert mO.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+         mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+         if cutlass.const_expr(mRstd is not None):
+             mRstd_expanded_layout = cute.append(
+                 mRstd.layout, cute.make_layout((self.N,), stride=(0,))
+             )
+             mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+         self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )

- @cute.kernel
- def rmsnorm_kernel(
-     mX: cute.Tensor,
-     mW: cute.Tensor,
-     mO: cute.Tensor,
-     mRstd: cute.Tensor,
-     eps: cute.Float32,
-     tv_layout: cute.Layout,
-     tiler_mn: cute.Shape,
-     cluster_n: cutlass.Constexpr = 1,
-     reload_from: cutlass.Constexpr = None,
-     delay_w_load: cutlass.Constexpr = False,
- ):
-     tidx, _, _ = cute.arch.thread_idx()
-     bidx, cluster_y, _ = cute.arch.block_idx()
-
-     smem = cutlass.utils.SmemAllocator()
-     sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
-     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-     reduction_buffer_layout = cute.make_ordered_layout(
-         (num_warps // warps_per_row, (warps_per_row, cluster_n)),
-         order=(1, 0)
-     )
-     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
-     if cutlass.const_expr(cluster_n > 1):
-         mbar_ptr = smem.allocate(cutlass.Int64.width // 8, byte_alignment=8)
-     else:
-         mbar_ptr = None
-
-     shape = mX.shape
-     idX = cute.make_identity_tensor(shape)
-     # slice for CTAs
-     gX, gO, gRstd, cX = [
-         cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
-         for mT in (mX, mO, mRstd, idX)
-     ]
-     gW = cute.local_tile(mW, tiler_mn, (0, 0 if cluster_n == 1 else cluster_y))
-
-     # declare the atoms which will be used later for memory copy
-     copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128)
-     copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128)
-     copy_atom_load_W = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128)
-     copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128)
-
-     thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
-
-     tWgW = thr_copy_W.partition_S(gW)
-     tXgX = thr_copy_X.partition_S(gX)
-     tXsX = thr_copy_X.partition_D(sX)
-     tXgO, tXrRstd = [thr_copy_O.partition_D(gT) for gT in (gO, gRstd)]
-     tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
-     # allocate fragments for gmem->rmem
-     tWrW = cute.make_fragment_like(tWgW)
-     tXrW = thr_copy_X.retile(tWrW)
-     tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
-
-     if cluster_n > 1:
-         if tidx == 0:
-             cute.arch.mbarrier_init_arrive_cnt(mbar_ptr, 1)
-         cute.arch.mbarrier_init_fence()
-         if tidx == 0:
-             cute.arch.mbarrier_init_tx_bytes(mbar_ptr, num_warps * cluster_n * cutlass.Float32.width // 8)
-         # Cluster arrive after barrier init
-         cute.arch.cluster_arrive_relaxed()
-
-     tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
-     row = tXcX[0][0]
-     if row < shape[0]:
-         cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
-     cute.arch.cp_async_commit_group()
-
-     tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
-     if not delay_w_load:
-         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
-
-     cute.arch.cp_async_wait_group(0)
-     cute.autovec_copy(tXsX, tXrX)
-     x = tXrX.load().to(cute.Float32)
-     threads_per_row = tv_layout.shape[0][0]
-     sum_sq_x = utils.row_reduce(
-         x * x,
-         cute.ReductionOp.ADD,
-         threads_per_row,
-         reduction_buffer,
-         mbar_ptr,
-         init_val=0.0,
-         hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
-     )
-     rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
-     # Only the thread corresponding to column 0 writes out the rstd to gmem
-     if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
-         tXrRstd[0] = rstd
-     if delay_w_load:
-         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
-     if reload_from == "smem":
-         cute.autovec_copy(tXsX, tXrX)
-         x = tXrX.load().to(cute.Float32)
-     elif reload_from == "gmem":
-         cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
-         x = tXrX.load().to(cute.Float32)
-     x_hat = x * rstd
-     w = tXrW.load().to(cute.Float32)
-     y = x_hat * w
-     tXrO.store(y.to(tXrO.element_type))
-     tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
-     if row < shape[0]:
-         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
-
-
- @cute.jit
- def rmsnorm_interface(
-     mX: cute.Tensor,
-     mW: cute.Tensor,
-     mO: cute.Tensor,
-     mRstd: cute.Tensor,
-     stream: cuda.CUstream,
-     N: cutlass.Constexpr,
-     eps: cutlass.Float32 = 1e-6,
-     copy_bits: cutlass.Constexpr = 128
- ):
-     # new_shape = (mX_.shape[0], cute.assume(mX_.shape[1], 128))
-     # breakpoint()
-     # mX = cute.make_tensor(mX_.iterator, cute.make_layout(new_shape, stride=mX_.stride))
-     vecsize = copy_bits // mX.element_type.width
-     assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-     num_threads = 128 if N <= 16384 else 256
-     num_warps = num_threads // cute.arch.WARP_SIZE
-     assert num_threads % cute.arch.WARP_SIZE == 0
-     threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-     # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
-     # Similarly cluster_n = 8 is faster for N=128k
-     if cutlass.const_expr(mX.element_type.width == 16):
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-     else: # fp32
-         cluster_n = 1 if N <= 32 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)))
-
-     num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-     cols_per_block = num_threads // threads_per_row
-     tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) # This rounds up N
-     tv_layout = cute.make_layout(
-         ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-         stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-     )
-
-     mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
-     mW_expanded = cute.make_tensor(mW.iterator, mW_expanded_layout)
-     mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((N,), stride=(0,)))
-     mRstd_expanded = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-
-     # reload_from = None if N <= 16384 else ("smem" if N <= 32768 else "gmem")
-     reload_from = None if N <= 16384 else "smem"
-     # delay_w_load = N > 64 * 1024
-     delay_w_load = False
-     N_rounded = tiler_mn[1]
-     rmsnorm_kernel(mX, mW_expanded, mO, mRstd_expanded, eps, tv_layout, tiler_mn, cluster_n, reload_from).launch(
-         grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
-         block=[cute.size(tv_layout, mode=[0]), 1, 1],
-         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
-         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-         smem=cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + num_warps * cluster_n * (cutlass.Float32.width // 8) + (cutlass.Int64.width // 8),
-         stream=stream,
-     )
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,
+         mW: cute.Tensor,
+         mO: cute.Tensor,
+         mRstd: Optional[cute.Tensor],
+         eps: cute.Float32,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+         reload_from: cutlass.Constexpr = None,
+         delay_w_load: cutlass.Constexpr = False,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, _, _ = cute.arch.block_idx()
+         if cutlass.const_expr(self.cluster_n > 1):
+             cluster_y = cute.arch.block_idx()[1]
+         else:
+             cluster_y = cutlass.const_expr(0)
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+         gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+         gRstd = (
+             cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+             if cutlass.const_expr(mRstd is not None)
+             else None
+         )

+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_load_X_async = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_load_W = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+         )
+         copy_atom_store_O = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
+         )

- torch2cute_dtype_map = {
-     torch.float16: cutlass.Float16,
-     torch.bfloat16: cutlass.BFloat16,
-     torch.float32: cutlass.Float32,
- }
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
+             tidx
+         )
+         thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+         thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+         tWgW = thr_copy_W.partition_S(gW)
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXgO = thr_copy_O.partition_D(gO)
+         tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+         # allocate fragments for gmem->rmem
+         tWrW = cute.make_fragment_like(tWgW)
+         tXrW = thr_copy_X.retile(tWrW)
+         tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+         row = tXcX[0][0]
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+
+         tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
+         if cutlass.const_expr(not delay_w_load):
+             cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+
+         cute.arch.cp_async_wait_group(0)
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+         threads_per_row = tv_layout.shape[0][0]
+         sum_sq_x = utils.row_reduce(
+             x * x,
+             cute.ReductionOp.ADD,
+             threads_per_row,
+             reduction_buffer[None, None, 0],
+             mbar_ptr,
+             init_val=0.0,
+             hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+         )
+         rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
+         if cutlass.const_expr(mRstd is not None):
+             # Only the thread corresponding to column 0 writes out the rstd to gmem
+             if (
+                 tXcX[0][1] == 0
+                 and row < shape[0]
+                 and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+             ):
+                 tXrRstd[0] = rstd
+         if cutlass.const_expr(delay_w_load):
+             cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+         if cutlass.const_expr(reload_from == "smem"):
+             cute.autovec_copy(tXsX, tXrX)
+             x = tXrX.load().to(cute.Float32)
+         elif cutlass.const_expr(reload_from == "gmem"):
+             cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+             x = tXrX.load().to(cute.Float32)
+         x_hat = x * rstd
+         w = tXrW.load().to(cute.Float32)
+         y = x_hat * w
+         tXrO.store(y.to(tXrO.element_type))
+         tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+         if row < shape[0]:
+             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)


  def rmsnorm(
@@ -216,24 +237,32 @@ def rmsnorm(
      M, N = x.shape
      device = x.device
      out = torch.empty_like(x)
-     rstd = torch.empty(M, device=device, dtype=torch.float32)
+     rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
      dtype = torch2cute_dtype_map[x.dtype]
      convert_from_dlpack = lambda x: (
-         from_dlpack(x.detach(), assumed_align=16)
-         .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+         from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
      )
      x_tensor, out_tensor = [
          # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
          convert_from_dlpack(t)
          for t in (x, out)
      ]
-     weight_tensor = utils.convert_from_dlpack(weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width)
-     rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+     weight_tensor = utils.convert_from_dlpack(
+         weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+     )
+     rstd_tensor = (
+         from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+         if rstd is not None
+         else None
+     )
      current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-     compile_key = (dtype, N)
+     compile_key = (dtype, N, rstd is not None)
      if compile_key not in rmsnorm.compile_cache:
+         rmsnorm_op = RMSNorm(dtype, N)
          rmsnorm.compile_cache[compile_key] = cute.compile(
-             rmsnorm_interface, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, N
+             rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
          )
      rmsnorm.compile_cache[compile_key](
          x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
@@ -246,7 +275,9 @@ rmsnorm.compile_cache = {}
 
  def rmsnorm_ref(x, w, eps=1e-6):
      x_f32 = x.float()
-     return (x_f32 / (torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1, keepdim=True) + eps)) * w).to(x.dtype)
+     return (x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w).to(
+         x.dtype
+     )


  def rstd_ref(x, eps=1e-6):
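
For orientation, a minimal smoke test of the refactored entry point can be sketched as follows. This is not part of the package; the call `rmsnorm(x, weight, eps=...)` and the fp32 weight dtype are assumptions inferred from the hunks above (the full `rmsnorm` signature and its `return_rstd` default are not shown in this diff).

```python
# Hypothetical usage sketch based on the 0.1.4 API shown in this diff;
# requires a CUDA device and the quack-kernels wheel installed.
import torch

from quack.rmsnorm import rmsnorm, rmsnorm_ref

x = torch.randn(8, 8192, device="cuda", dtype=torch.bfloat16)
w = torch.randn(8192, device="cuda", dtype=torch.float32)  # weight handled with Float32 divisibility above

out = rmsnorm(x, w, eps=1e-6)      # first call compiles and caches a kernel for this (dtype, N, rstd) key
ref = rmsnorm_ref(x, w, eps=1e-6)  # pure-PyTorch reference defined in this file
torch.testing.assert_close(out.float(), ref.float(), rtol=2e-2, atol=2e-2)
```

Because the compile cache is keyed on `(dtype, N, rstd is not None)`, only the first call for each such key pays the `cute.compile` cost; subsequent calls reuse the cached kernel.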