quack-kernels 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
quack/rmsnorm.py CHANGED
@@ -1,9 +1,8 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

- import math
- import operator

  import torch
+ from typing import Optional

  import cuda.bindings.driver as cuda

@@ -12,181 +11,202 @@ import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack

  import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+ class RMSNorm(ReductionBase):
+     def __init__(self, dtype: cutlass.Numeric, N: int):
+         super().__init__(dtype, N, stage=1)
+         self.reload_from = None if N <= 16384 else "smem"
+         self.delay_w_load = False
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )

+     def _set_cluster_n(self):
+         N = self.N
+         # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
+         # Similarly cluster_n = 8 is faster for N=128k
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else:  # fp32
+             cluster_n = (
+                 1
+                 if N <= 32 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mW: cute.Tensor,
+         mO: cute.Tensor,
+         mRstd: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+         eps: cutlass.Float32 = 1e-6,
+     ):
+         assert mX.element_type == self.dtype
+         assert mO.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+         mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+         if cutlass.const_expr(mRstd is not None):
+             mRstd_expanded_layout = cute.append(
+                 mRstd.layout, cute.make_layout((self.N,), stride=(0,))
+             )
+             mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+         self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )

- @cute.kernel
- def rmsnorm_kernel(
-     mX: cute.Tensor,
-     mW: cute.Tensor,
-     mO: cute.Tensor,
-     mRstd: cute.Tensor,
-     eps: cute.Float32,
-     tv_layout: cute.Layout,
-     tiler_mn: cute.Shape,
-     cluster_n: cutlass.Constexpr = 1,
-     reload_from: cutlass.Constexpr = None,
-     delay_w_load: cutlass.Constexpr = False,
- ):
-     tidx, _, _ = cute.arch.thread_idx()
-     bidx, cluster_y, _ = cute.arch.block_idx()
-
-     smem = cutlass.utils.SmemAllocator()
-     sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
-     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-     reduction_buffer_layout = cute.make_ordered_layout(
-         (num_warps // warps_per_row, (warps_per_row, cluster_n)),
-         order=(1, 0)
-     )
-     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
-     if cutlass.const_expr(cluster_n > 1):
-         mbar_ptr = smem.allocate(cutlass.Int64.width // 8, byte_alignment=8)
-     else:
-         mbar_ptr = None
-
-     shape = mX.shape
-     idX = cute.make_identity_tensor(shape)
-     # slice for CTAs
-     gX, gO, gRstd, cX = [
-         cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
-         for mT in (mX, mO, mRstd, idX)
-     ]
-     gW = cute.local_tile(mW, tiler_mn, (0, 0 if cluster_n == 1 else cluster_y))
-
-     # declare the atoms which will be used later for memory copy
-     copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128)
-     copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128)
-     copy_atom_load_W = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128)
-     copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128)
-
-     thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
-
-     tWgW = thr_copy_W.partition_S(gW)
-     tXgX = thr_copy_X.partition_S(gX)
-     tXsX = thr_copy_X.partition_D(sX)
-     tXgO, tXrRstd = [thr_copy_O.partition_D(gT) for gT in (gO, gRstd)]
-     tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
-     # allocate fragments for gmem->rmem
-     tWrW = cute.make_fragment_like(tWgW)
-     tXrW = thr_copy_X.retile(tWrW)
-     tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
-
-     if cluster_n > 1:
-         if tidx == 0:
-             cute.arch.mbarrier_init_arrive_cnt(mbar_ptr, 1)
-         cute.arch.mbarrier_init_fence()
-         if tidx == 0:
-             cute.arch.mbarrier_init_tx_bytes(mbar_ptr, num_warps * cluster_n * cutlass.Float32.width // 8)
-         # Cluster arrive after barrier init
-         cute.arch.cluster_arrive_relaxed()
-
-     tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
-     row = tXcX[0][0]
-     if row < shape[0]:
-         cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
-     cute.arch.cp_async_commit_group()
-
-     tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
-     if not delay_w_load:
-         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
-
-     cute.arch.cp_async_wait_group(0)
-     cute.autovec_copy(tXsX, tXrX)
-     x = tXrX.load().to(cute.Float32)
-     threads_per_row = tv_layout.shape[0][0]
-     sum_sq_x = utils.row_reduce(
-         x * x,
-         cute.ReductionOp.ADD,
-         threads_per_row,
-         reduction_buffer,
-         mbar_ptr,
-         init_val=0.0,
-         hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
-     )
-     rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
-     # Only the thread corresponding to column 0 writes out the rstd to gmem
-     if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
-         tXrRstd[0] = rstd
-     if delay_w_load:
-         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
-     if reload_from == "smem":
-         cute.autovec_copy(tXsX, tXrX)
-         x = tXrX.load().to(cute.Float32)
-     elif reload_from == "gmem":
-         cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
-         x = tXrX.load().to(cute.Float32)
-     x_hat = x * rstd
-     w = tXrW.load().to(cute.Float32)
-     y = x_hat * w
-     tXrO.store(y.to(tXrO.element_type))
-     tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
-     if row < shape[0]:
-         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
-
-
- @cute.jit
- def rmsnorm_interface(
-     mX: cute.Tensor,
-     mW: cute.Tensor,
-     mO: cute.Tensor,
-     mRstd: cute.Tensor,
-     stream: cuda.CUstream,
-     N: cutlass.Constexpr,
-     eps: cutlass.Float32 = 1e-6,
-     copy_bits: cutlass.Constexpr = 128
- ):
-     # new_shape = (mX_.shape[0], cute.assume(mX_.shape[1], 128))
-     # breakpoint()
-     # mX = cute.make_tensor(mX_.iterator, cute.make_layout(new_shape, stride=mX_.stride))
-     vecsize = copy_bits // mX.element_type.width
-     assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-     num_threads = 128 if N <= 16384 else 256
-     num_warps = num_threads // cute.arch.WARP_SIZE
-     assert num_threads % cute.arch.WARP_SIZE == 0
-     threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-     # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
-     # Similarly cluster_n = 8 is faster for N=128k
-     if cutlass.const_expr(mX.element_type.width == 16):
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-     else:  # fp32
-         cluster_n = 1 if N <= 32 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)))
-
-     num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-     cols_per_block = num_threads // threads_per_row
-     tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)  # This rounds up N
-     tv_layout = cute.make_layout(
-         ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-         stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-     )
-
-     mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
-     mW_expanded = cute.make_tensor(mW.iterator, mW_expanded_layout)
-     mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((N,), stride=(0,)))
-     mRstd_expanded = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-
-     # reload_from = None if N <= 16384 else ("smem" if N <= 32768 else "gmem")
-     reload_from = None if N <= 16384 else "smem"
-     # delay_w_load = N > 64 * 1024
-     delay_w_load = False
-     N_rounded = tiler_mn[1]
-     rmsnorm_kernel(mX, mW_expanded, mO, mRstd_expanded, eps, tv_layout, tiler_mn, cluster_n, reload_from).launch(
-         grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
-         block=[cute.size(tv_layout, mode=[0]), 1, 1],
-         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
-         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-         smem=cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + num_warps * cluster_n * (cutlass.Float32.width // 8) + (cutlass.Int64.width // 8),
-         stream=stream,
-     )
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,
+         mW: cute.Tensor,
+         mO: cute.Tensor,
+         mRstd: Optional[cute.Tensor],
+         eps: cute.Float32,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+         reload_from: cutlass.Constexpr = None,
+         delay_w_load: cutlass.Constexpr = False,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, cluster_y, _ = cute.arch.block_idx()
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, gO, cX = [
+             cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+             for mT in (mX, mO, idX)
+         ]
+         gW = cute.local_tile(mW, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+         gRstd = (
+             cute.local_tile(mRstd, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+             if cutlass.const_expr(mRstd is not None)
+             else None
+         )

+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_load_X_async = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_load_W = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+         )
+         copy_atom_store_O = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
+         )

- torch2cute_dtype_map = {
-     torch.float16: cutlass.Float16,
-     torch.bfloat16: cutlass.BFloat16,
-     torch.float32: cutlass.Float32,
- }
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
+             tidx
+         )
+         thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+         thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+         tWgW = thr_copy_W.partition_S(gW)
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXgO = thr_copy_O.partition_D(gO)
+         tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+         # allocate fragments for gmem->rmem
+         tWrW = cute.make_fragment_like(tWgW)
+         tXrW = thr_copy_X.retile(tWrW)
+         tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+         row = tXcX[0][0]
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+
+         tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
+         if not delay_w_load:
+             cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+
+         cute.arch.cp_async_wait_group(0)
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+         threads_per_row = tv_layout.shape[0][0]
+         sum_sq_x = utils.row_reduce(
+             x * x,
+             cute.ReductionOp.ADD,
+             threads_per_row,
+             reduction_buffer[None, None, 0],
+             mbar_ptr,
+             init_val=0.0,
+             hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+         )
+         rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
+         if cutlass.const_expr(mRstd is not None):
+             # Only the thread corresponding to column 0 writes out the rstd to gmem
+             if (
+                 tXcX[0][1] == 0
+                 and row < shape[0]
+                 and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+             ):
+                 tXrRstd[0] = rstd
+         if delay_w_load:
+             cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+         if reload_from == "smem":
+             cute.autovec_copy(tXsX, tXrX)
+             x = tXrX.load().to(cute.Float32)
+         elif reload_from == "gmem":
+             cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+             x = tXrX.load().to(cute.Float32)
+         x_hat = x * rstd
+         w = tXrW.load().to(cute.Float32)
+         y = x_hat * w
+         tXrO.store(y.to(tXrO.element_type))
+         tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+         if row < shape[0]:
+             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)


  def rmsnorm(
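
Note: the refactor above moves the kernel into an RMSNorm class built on ReductionBase, but the math is unchanged. Per row it computes rstd = 1 / sqrt(mean(x^2) + eps) and y = x * rstd * w, optionally writing rstd out. A minimal PyTorch sketch of that per-row computation for orientation (illustration only, not code from the package; the function name is made up here, and the fp32 accumulation mirrors what the kernel does):

    import torch

    def rmsnorm_rowwise_sketch(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
        # x: (M, N), w: (N,); accumulate in fp32 like the kernel, cast back at the end
        x_f32 = x.float()
        rstd = torch.rsqrt(x_f32.square().mean(dim=-1, keepdim=True) + eps)  # (M, 1)
        y = (x_f32 * rstd * w.float()).to(x.dtype)                           # (M, N)
        return y, rstd.squeeze(-1)
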
@@ -216,24 +236,32 @@ def rmsnorm(
      M, N = x.shape
      device = x.device
      out = torch.empty_like(x)
-     rstd = torch.empty(M, device=device, dtype=torch.float32)
+     rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
      dtype = torch2cute_dtype_map[x.dtype]
      convert_from_dlpack = lambda x: (
-         from_dlpack(x.detach(), assumed_align=16)
-         .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+         from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
      )
      x_tensor, out_tensor = [
          # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
          convert_from_dlpack(t)
          for t in (x, out)
      ]
-     weight_tensor = utils.convert_from_dlpack(weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width)
-     rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+     weight_tensor = utils.convert_from_dlpack(
+         weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+     )
+     rstd_tensor = (
+         from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+         if rstd is not None
+         else None
+     )
      current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-     compile_key = (dtype, N)
+     compile_key = (dtype, N, rstd is not None)
      if compile_key not in rmsnorm.compile_cache:
+         rmsnorm_op = RMSNorm(dtype, N)
          rmsnorm.compile_cache[compile_key] = cute.compile(
-             rmsnorm_interface, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, N
+             rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
          )
      rmsnorm.compile_cache[compile_key](
          x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
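
For context, the host wrapper now builds a per-shape RMSNorm instance and caches the compiled kernel keyed on (dtype, N, rstd is not None). A hedged usage sketch of the wrapper (the `return_rstd` keyword is referenced in this hunk, but the full `rmsnorm` signature and return convention are not shown, so treat those details as assumptions):

    import torch
    from quack.rmsnorm import rmsnorm

    x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(8192, device="cuda", dtype=torch.float32)

    # First call for a given (dtype, N, rstd is not None) compiles and caches the kernel
    out = rmsnorm(x, w)
    # Assumed: passing return_rstd=True also allocates rstd and uses a separate cache entry
    out_with_rstd = rmsnorm(x, w, return_rstd=True)
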
@@ -246,7 +274,9 @@ rmsnorm.compile_cache = {}

  def rmsnorm_ref(x, w, eps=1e-6):
      x_f32 = x.float()
-     return (x_f32 / (torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1, keepdim=True) + eps)) * w).to(x.dtype)
+     return (x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w).to(
+         x.dtype
+     )


  def rstd_ref(x, eps=1e-6):
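
The reference helpers above make a quick numerical sanity check of the compiled kernel straightforward; a sketch of such a check (assuming a CUDA device is available and that `rmsnorm` and `rmsnorm_ref` are importable from `quack.rmsnorm` as defined in this file):

    import torch
    from quack.rmsnorm import rmsnorm, rmsnorm_ref

    x = torch.randn(512, 4096, device="cuda", dtype=torch.float16)
    w = torch.randn(4096, device="cuda", dtype=torch.float32)

    out = rmsnorm(x, w)
    torch.testing.assert_close(out, rmsnorm_ref(x, w), atol=1e-2, rtol=1e-2)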