quack-kernels 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.5"
+ __version__ = "0.1.7"
 
  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
quack/cross_entropy.py CHANGED
@@ -104,7 +104,10 @@ class CrossEntropy(ReductionBase):
  shape: cute.Shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
+ # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+ mX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+ gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 
  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
@@ -150,7 +153,9 @@ class CrossEntropy(ReductionBase):
 
  target_logit = cute.Float32.zero
  if row < shape[0] and tXcX[0][1] == 0:
- target_logit = cute.Float32(mX[row, target])
+ # Use Int64 for indexing to deal with large tensors
+ mX_off = utils.domain_offset_i64((row, 0), mX)
+ target_logit = cute.Float32(mX_off[0, target])
 
  threads_per_row = tv_layout.shape[0][0]
  if cutlass.const_expr(not self.online_softmax):
@@ -363,11 +368,10 @@ class CrossEntropyBackward:
  )
 
  idX = cute.make_identity_tensor(shape)
-
- gX, gdX, cX, gTarget, gDLoss, gLse = [
- cute.local_tile(mT, tiler_mn, (bidx, bidy))
- for mT in (mX, mdX, idX, mTarget, mDLoss, mLSE)
- ]
+ # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+ mX, mdX = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX)]
+ gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))
 
  copy_atom_load_X = cute.make_copy_atom(
  cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
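The hunks above stop indexing `mX` directly with 32-bit coordinates and instead rebase the tensor through `utils.domain_offset_i64` before tiling: the flattened element offset `row * stride` no longer fits in a signed 32-bit integer once the logits tensor holds more than 2^31 elements. A quick worked example of where that breaks (shapes are hypothetical, not taken from the diff):

```python
# Illustrative only: why 32-bit offsets wrap for large logits tensors.
M, N = 8_388_608, 512                 # 8M rows x 512 classes = 2**32 elements
row = M - 1
offset = row * N                      # exact value in Python: 4_294_966_784
offset_i32 = (offset + 2**31) % 2**32 - 2**31   # what a wrapped int32 would hold
print(offset, offset_i32)             # 4294966784 vs -512
```

Computing the offset in Int64 and folding it into the base pointer, as `domain_offset_i64` does, keeps the per-tile coordinates small while the base address stays exact.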
quack/layernorm.py ADDED
@@ -0,0 +1,351 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+
+ import torch
+ from typing import Optional
+
+ import cuda.bindings.driver as cuda
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.runtime import from_dlpack
+ import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+ class LayerNorm(ReductionBase):
+ def __init__(self, dtype: cutlass.Numeric, N: int):
+ super().__init__(dtype, N, stage=2) # 2 stages for mean and var
+ self.reload_from = None if N <= 16384 else "smem"
+ self.delay_w_load = False
+
+ def _calculate_threads_per_row(self):
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ )
+ )
+
+ def _set_cluster_n(self):
+ N = self.N
+ # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
+ # Similarly cluster_n = 8 is faster for N=128k
+ if cutlass.const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else: # fp32
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ stream: cuda.CUstream,
+ eps: cutlass.Float32 = 1e-6,
+ ):
+ assert mX.element_type == self.dtype
+ assert mO.element_type == self.dtype
+ self._set_cluster_n()
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = cute.size(tv_layout, mode=[0])
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+ if cutlass.const_expr(mRstd is not None):
+ mRstd_expanded_layout = cute.append(
+ mRstd.layout, cute.make_layout((self.N,), stride=(0,))
+ )
+ mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+ if cutlass.const_expr(mMean is not None):
+ mMean_expanded_layout = cute.append(
+ mMean.layout, cute.make_layout((self.N,), stride=(0,))
+ )
+ mMean = cute.make_tensor(mMean.iterator, mMean_expanded_layout)
+ self.kernel(mX, mW, mO, mRstd, mMean, eps, tv_layout, tiler_mn, self.reload_from).launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ eps: cute.Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ reload_from: cutlass.Constexpr = None,
+ delay_w_load: cutlass.Constexpr = False,
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if cutlass.const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = cutlass.const_expr(0)
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+ # slice for CTAs
+ # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+ mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+ gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+ gRstd = (
+ cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+ if cutlass.const_expr(mRstd is not None)
+ else None
+ )
+ gMean = (
+ cute.local_tile(mMean, tiler_mn, (bidx, cluster_y))
+ if cutlass.const_expr(mMean is not None)
+ else None
+ )
+
+ # declare the atoms which will be used later for memory copy
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+ )
+ copy_atom_load_X_async = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+ )
+ copy_atom_load_W = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+ )
+ copy_atom_store_O = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
+ )
+
+ thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
+ tidx
+ )
+ thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+ tWgW = thr_copy_W.partition_S(gW)
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXgO = thr_copy_O.partition_D(gO)
+ tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+ tXrMean = thr_copy_O.partition_D(gMean) if cutlass.const_expr(mMean is not None) else None
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+ # allocate fragments for gmem->rmem
+ tWrW = cute.make_fragment_like(tWgW)
+ tXrW = thr_copy_X.retile(tWrW)
+ tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ row = tXcX[0][0]
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+
+ tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
+ if cutlass.const_expr(not delay_w_load):
+ cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+
+ cute.arch.cp_async_wait_group(0)
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ threads_per_row = tv_layout.shape[0][0]
+ sum_x = utils.row_reduce(
+ x,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ )
+ mean = sum_x / shape[1]
+ if cutlass.const_expr(reload_from == "smem"):
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ elif cutlass.const_expr(reload_from == "gmem"):
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ x = tXrX.load().to(cute.Float32)
+
+ sum_sq_x_sub_mean = utils.row_reduce(
+ (x - mean) * (x - mean),
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 1],
+ mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ )
+ rstd = utils.rsqrt(sum_sq_x_sub_mean / shape[1] + eps)
+ if cutlass.const_expr(mRstd is not None):
+ # Only the thread corresponding to column 0 writes out the rstd to gmem
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXrRstd[0] = rstd
+ if cutlass.const_expr(mMean is not None):
+ # Only the thread corresponding to column 0 writes out the mean to gmem
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXrMean[0] = mean
+ if cutlass.const_expr(delay_w_load):
+ cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+ if cutlass.const_expr(reload_from == "smem"):
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ elif cutlass.const_expr(reload_from == "gmem"):
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ x = tXrX.load().to(cute.Float32)
+ x_hat = (x - mean) * rstd
+ w = tXrW.load().to(cute.Float32)
+ y = x_hat * w
+ tXrO.store(y.to(tXrO.element_type))
+ tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+ if row < shape[0]:
+ cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+ def layernorm(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+ return_rstd: bool = False,
+ return_mean: bool = False,
+ ) -> torch.Tensor:
+ """LayerNorm forward pass.
+
+ Args:
+ x: Input tensor of shape (M, N)
+ weight: Weight tensor of shape (N,)
+ eps: Small value for numerical stability
+ return_rstd: Whether to return the reciprocal standard deviation
+ return_mean: Whether to return the mean
+
+ Returns:
+ Normalized output tensor of same shape as x
+ If return_rstd is True, also returns rstd tensor of shape (M,)
+ If return_mean is True, also returns mean tensor of shape (M,)
+ """
+ assert x.dim() == 2, "Input must be 2D"
+ assert weight.dim() == 1, "Weight must be 1D"
+ assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+ assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+ assert weight.dtype == torch.float32, "Weight must be float32"
+ M, N = x.shape
+ device = x.device
+ out = torch.empty_like(x)
+ rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
+ mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
+ dtype = torch2cute_dtype_map[x.dtype]
+ convert_from_dlpack = lambda x: (
+ from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0, stride_order=(0, 1)
+ )
+ )
+ x_tensor, out_tensor = [
+ # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
+ convert_from_dlpack(t)
+ for t in (x, out)
+ ]
+ weight_tensor = utils.convert_from_dlpack(
+ weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+ )
+ rstd_tensor = (
+ from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ if rstd is not None
+ else None
+ )
+ mean_tensor = (
+ from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ if mean is not None
+ else None
+ )
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+ compile_key = (dtype, N, rstd is not None, mean is not None)
+ if compile_key not in layernorm.compile_cache:
+ rmsnorm_op = LayerNorm(dtype, N)
+ layernorm.compile_cache[compile_key] = cute.compile(
+ rmsnorm_op,
+ x_tensor,
+ weight_tensor,
+ out_tensor,
+ rstd_tensor,
+ mean_tensor,
+ current_stream,
+ )
+ layernorm.compile_cache[compile_key](
+ x_tensor, weight_tensor, out_tensor, rstd_tensor, mean_tensor, current_stream, eps
+ )
+ return (
+ (out, rstd, mean)
+ if return_mean and return_rstd
+ else (
+ (out, rstd)
+ if return_rstd and not return_mean
+ else ((out, mean) if return_mean and not return_rstd else (out))
+ )
+ )
+
+
+ layernorm.compile_cache = {}
+
+
+ def layernorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
+ x_f32 = x.float()
+ return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
+
+
+ def rstd_ref(x: torch.Tensor, eps: float = 1e-6):
+ x_f32 = x.float()
+ mean = x_f32.mean(dim=-1, keepdim=True)
+ var = ((x_f32 - mean) ** 2).mean(dim=-1)
+ return 1.0 / torch.sqrt(var + eps)
+
+
+ def mean_ref(x: torch.Tensor) -> torch.Tensor:
+ return x.float().mean(dim=-1)
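For reference, a minimal usage sketch of the new `layernorm` entry point, based on the docstring and assertions above (2D CUDA input in fp16/bf16/fp32, 1D fp32 weight). Shapes, dtype, and tolerances are illustrative and assume the wheel is installed on a machine with a CUDA GPU:

```python
# Hedged usage sketch for quack.layernorm.layernorm (added in 0.1.7).
import torch
from quack.layernorm import layernorm, layernorm_ref

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
w = torch.randn(8192, device="cuda", dtype=torch.float32)   # weight must be fp32

# Returns (out, rstd, mean) when both flags are set, per the return logic above.
out, rstd, mean = layernorm(x, w, eps=1e-6, return_rstd=True, return_mean=True)
torch.testing.assert_close(out, layernorm_ref(x, w), atol=2e-2, rtol=2e-2)
```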
quack/reduction_base.py CHANGED
@@ -68,7 +68,7 @@ class ReductionBase:
  )
 
  def _allocate_reduction_buffer_and_mbar(
- self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+ self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False
  ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
  reduction_buffer = smem.allocate_tensor(
  self.reduction_dtype,
@@ -76,20 +76,28 @@
  byte_alignment=4,
  )
  if cutlass.const_expr(self.cluster_n > 1):
- mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+ mbar_ptr = smem.allocate_array(
+ cutlass.Int64, num_elems=self.stage if not is_persistent else self.stage * 2
+ )
  else:
  mbar_ptr = None
  return reduction_buffer, mbar_ptr
 
  @cute.jit
- def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
+ def _initialize_cluster(
+ self,
+ tidx: cutlass.Int32,
+ mbar_ptr: cute.Pointer,
+ num_warps: int,
+ is_persistent: bool = False,
+ ):
  if cutlass.const_expr(self.cluster_n > 1):
- if tidx < self.stage:
+ if tidx < self.stage: # Initialize full barrier
  cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+ if cutlass.const_expr(is_persistent): # Initialize empty barrier
+ cute.arch.mbarrier_init(
+ mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
+ )
  cute.arch.mbarrier_init_fence()
- if tidx < self.stage:
- cute.arch.mbarrier_arrive_and_expect_tx(
- mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
- )
  # Cluster arrive after barrier init
  cute.arch.cluster_arrive_relaxed()
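With `is_persistent=True`, the changes above allocate and initialize two mbarriers per stage (a "full" set and an "empty" set) instead of one, and the `mbarrier_arrive_and_expect_tx` call moves out of initialization and into the reduction itself (see the `cluster_reduce` change in quack/utils.py below). The persistent RMSNormBackward kernel then reuses these barriers across a grid-stride loop by toggling a stage index and two phase bits. A conceptual, host-side sketch of that bookkeeping (plain Python, no GPU; the helper name is ours, not part of quack):

```python
# Models the `stage ^= 1` / phase-flip pattern used by RMSNormBackward.kernel.
def barrier_schedule(num_iters: int):
    stage, consumer_phase, producer_phase = 0, 0, 1
    for it in range(num_iters):
        # Per iteration: wait on empty[stage] with producer_phase before refilling
        # the buffer, and on full[stage] with consumer_phase before reading it.
        yield it, stage, consumer_phase, producer_phase
        stage ^= 1
        if stage == 0:               # both stages used once -> flip both phases
            consumer_phase ^= 1
            producer_phase ^= 1

print(list(barrier_schedule(4)))
# [(0, 0, 0, 1), (1, 1, 0, 1), (2, 0, 1, 0), (3, 1, 1, 0)]
```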
quack/rmsnorm.py CHANGED
@@ -1,6 +1,5 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
-
  import torch
  from typing import Optional
 
@@ -117,7 +116,10 @@ class RMSNorm(ReductionBase):
  shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+ # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+ mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+ gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
  gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
  gRstd = (
  cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
@@ -154,6 +156,7 @@ class RMSNorm(ReductionBase):
 
  # allocate fragments for gmem->rmem
  tWrW = cute.make_fragment_like(tWgW)
+ tWrW.fill(0.0)
  tXrW = thr_copy_X.retile(tWrW)
  tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
 
@@ -297,8 +300,14 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
 
  class RMSNormBackward(ReductionBase):
  def __init__(self, dtype: cutlass.Numeric, N: int):
- # 1 stage for computing mean of x_hat * wdy
- super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+ # 2 stages for double buffering when computing mean of x_hat * wdy
+ super().__init__(dtype, N, stage=2, reduction_dtype=cutlass.Float32)
+ if self.N > 128 * 1024 and self.dtype.width >= 32:
+ # Not enough smem
+ raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
+
+ def _get_num_threads(self):
+ return 128 if self.N <= 4096 else 256
 
  def _calculate_threads_per_row(self):
  N = self.N
@@ -308,44 +317,38 @@ class RMSNormBackward(ReductionBase):
  else (
  16
  if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
  )
  )
 
  def _set_cluster_n(self):
  N = self.N
- if cutlass.const_expr(self.dtype.width == 16):
- cluster_n = (
- 1
- if N <= 16 * 1024
- else (
- 2
- if N <= 32 * 1024
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
- )
- )
- else: # fp32
- cluster_n = (
- 1
- if N <= 32 * 1024
- else (
- 2
- if N <= 64 * 1024
- else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
- )
- )
+ cluster_n = (
+ 1
+ if N <= 8 * 1024
+ else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)))
+ )
  self.cluster_n = cluster_n
 
+ def _smem_size_in_bytes(self, tiler_mn, num_warps):
+ return (
+ # Multiply by 2 since we need space for X and dOut,
+ # and multiply by another 2 due to double buffering
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 * 2
+ + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8) * 2 # mult 2 as we need 2 mbar per stage
+ )
+
  @cute.jit
  def __call__(
  self,
  mX: cute.Tensor,
  mW: cute.Tensor,
- mDout: cute.Tensor,
+ mdOut: cute.Tensor,
  mRstd: cute.Tensor,
- mDx: cute.Tensor,
- mDw: cute.Tensor,
- sm_count: cutlass.Constexpr,
+ mdX: cute.Tensor,
+ mdW: cute.Tensor,
+ sm_count: cutlass.Int32,
  stream: cuda.CUstream,
  ):
  self._set_cluster_n()
@@ -356,14 +359,8 @@ class RMSNormBackward(ReductionBase):
  mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
  mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
 
- mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
- mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-
- num_blocks = (
- sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
- )
-
- self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+ num_blocks = sm_count
+ self.kernel(mX, mW, mdOut, mRstd, mdX, mdW, tv_layout, tiler_mn).launch(
  grid=[num_blocks, self.cluster_n, 1],
  block=[num_threads, 1, 1],
  cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
@@ -376,177 +373,244 @@ class RMSNormBackward(ReductionBase):
  self,
  mX: cute.Tensor,
  mW: cute.Tensor,
- mDout: cute.Tensor,
+ mdOut: cute.Tensor,
  mRstd: cute.Tensor,
- mDx: cute.Tensor,
- mDw: cute.Tensor,
- sm_count: cutlass.Constexpr,
+ mdX: cute.Tensor,
+ mdW: cute.Tensor,
  tv_layout: cute.Layout,
  tiler_mn: cute.Shape,
  ):
  tidx, _, _ = cute.arch.thread_idx()
- bidx, cluster_y, _ = cute.arch.block_idx()
+ bidx_start, _, _ = cute.arch.block_idx()
  gdim, _, _ = cute.arch.grid_dim()
+ if cutlass.const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = cutlass.const_expr(0)
 
  shape = mX.shape
  M, N = shape[0], shape[1]
+ is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
 
  idX = cute.make_identity_tensor(shape)
 
  smem = cutlass.utils.SmemAllocator()
- reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+ smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
+ sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
+ sdOut = smem.allocate_tensor(mdOut.element_type, smem_layout, byte_alignment=16)
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout, is_persistent=True
+ )
+ if cutlass.const_expr(mbar_ptr is not None):
+ mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
+ else:
+ mbar_full_ptr, mbar_empty_ptr = None, None
 
  copy_atom_load_X = cute.make_copy_atom(
  cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
  )
-
+ copy_atom_load_X_async = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+ )
  copy_atom_load_W = cute.make_copy_atom(
  cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
  )
-
  copy_atom_store_dX = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mDx.element_type, num_bits_per_copy=128
+ cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=128
  )
-
- copy_atom_dw = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+ copy_atom_store_dW = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=128
  )
 
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_X_async = cute.make_tiled_copy(
+ copy_atom_load_X_async, tv_layout, tiler_mn
+ ).get_slice(tidx)
  thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_dw = cute.make_tiled_copy(copy_atom_dw, tv_layout, tiler_mn).get_slice(tidx)
- thr_store_dx = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_dW = cute.make_tiled_copy(copy_atom_store_dW, tv_layout, tiler_mn).get_slice(tidx)
+ thr_store_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
 
- gW = cute.local_tile(mW, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
  tWgW = thr_copy_W.partition_S(gW)
  tWrW = cute.make_fragment_like(tWgW)
+ # Need this, otherwise rW can have arbitrary values that changes the reduction
+ if not is_even_N:
+ tWrW.fill(0.0)
  tXrW = thr_copy_X.retile(tWrW)
 
- gW_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
-
- tWpW = utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+ gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+ tWpW = (
+ utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+ if not is_even_N
+ else None
+ )
  cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
  weight = tXrW.load().to(cute.Float32)
 
  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
 
- self._initialize_cluster(tidx, mbar_ptr, num_warps)
-
- dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
- tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+ self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
 
- gDw = cute.local_tile(mDw, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
- tDwgDw = thr_copy_dw.partition_D(gDw)
- tDwrDw = cute.make_fragment_like(tDwgDw)
- dw_accumulator = thr_copy_X.retile(tDwrDw)
- dw_accumulator.fill(0.0)
-
- M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+ dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+ tdWpdW = (
+ utils.predicate_k(thr_copy_dW.partition_S(dw_coord), limit=shape[1])
+ if not is_even_N
+ else None
+ )
 
- jump = sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+ gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
+ tdWgdW = thr_copy_dW.partition_D(gdW)
+ tdWrdW = cute.make_fragment_like(tdWgdW, cutlass.Float32)
+ tXrdW = thr_copy_X.retile(tdWrdW)
 
- if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
+ gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
+ gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
+ gdX = cute.local_tile(mdX, tiler_mn, (None, cluster_y))
+ cX = cute.local_tile(idX, tiler_mn, (None, cluster_y))
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXgdOut = thr_copy_X.partition_S(gdOut)
+ tXsdOut = thr_copy_X.partition_D(sdOut)
+ tXgdX = thr_store_dX.partition_D(gdX)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
+ # This doesn't change across iterations
+ tXpX = (
+ utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+ if not is_even_N
+ else None
+ )
 
- ## need to update range_dynamic since it will be deprecated soon
- for row_offset in cutlass.range_dynamic(bidx, M_pad, jump):
- gX = cute.local_tile(
- mX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ tXrX, tXrdOut, tXrdX = [
+ cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdOut, tXgdX)
+ ]
+
+ # Prefetch the first batch
+ row = tXcX[None, None, None, bidx_start][0][0]
+ if row < M:
+ tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
+ tXgdOut_cur = utils.coord_offset_i64(bidx_start, tXgdOut, dim=3)[None, None, None, 0]
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgX_cur,
+ tXsX[None, None, None, 0],
+ pred=tXpX,
  )
- gDout = cute.local_tile(
- mDout, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgdOut_cur,
+ tXsdOut[None, None, None, 0],
+ pred=tXpX,
  )
- gRstd = cute.local_tile(
- mRstd, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
- )
- gDx = cute.local_tile(
- mDx, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
- )
- cX = cute.local_tile(
- idX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
- )
-
- tXgX = thr_copy_X.partition_S(gX)
- thrDout = thr_copy_X.partition_S(gDout)
- tXrRstd = thr_copy_W.partition_S(gRstd)
- thrDx = thr_store_dx.partition_D(gDx)
- tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
- tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
-
- tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ elif tiler_mn[0] > 1:
+ # Fill with zero, otherwise smem will be uninitialized, and we could read this back
+ # later into registers, causing wrong dW.
+ utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+ utils.fill_oob(tXsdOut[None, None, None, 0], None, fill_value=mdOut.element_type.zero)
+ cute.arch.cp_async_commit_group()
 
- if tXcX[0][0] < shape[0]:
- cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
- cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_wait()
 
+ threads_per_row = tv_layout.shape[0][0]
+ tXrdW.fill(0.0)
+ stage = cutlass.Int32(0)
+ producer_phase = cutlass.Int32(1)
+ consumer_phase = cutlass.Int32(0)
+ for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
+ row = tXcX[None, None, None, bidx][0][0]
+ rstd = cutlass.Float.zero
+ if row + gdim * tiler_mn[0] < M: # Prefetch the next batch
+ tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
+ tXgdOut_cur = utils.coord_offset_i64(bidx + gdim, tXgdOut, dim=3)[
+ None, None, None, 0
+ ]
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgX_cur,
+ tXsX[None, None, None, stage ^ 1],
+ pred=tXpX,
+ )
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgdOut_cur,
+ tXsdOut[None, None, None, stage ^ 1],
+ pred=tXpX,
+ )
+ elif tiler_mn[0] > 1:
+ utils.fill_oob(
+ tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
+ )
+ utils.fill_oob(
+ tXsdOut[None, None, None, stage ^ 1], None, fill_value=mdOut.element_type.zero
+ )
+ cute.arch.cp_async_commit_group()
+ if row < M or tiler_mn[0] == 1:
+ rstd = mRstd[row]
+ cute.arch.cp_async_wait_group(1)
+ cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
  x = tXrX.load().to(cute.Float32)
- dout = frgDout.load().to(cute.Float32)
-
- rstd = tXrRstd[0]
+ cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+ dout = tXrdOut.load().to(cute.Float32)
  x_hat = x * rstd
  wdy = dout * weight
-
- threads_per_row = tv_layout.shape[0][0]
-
- row = tXcX[0][0]
  if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
- cute.arch.barrier()
-
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
  mean_xhat_wdy = (
  utils.row_reduce(
  x_hat * wdy,
  cute.ReductionOp.ADD,
  threads_per_row,
- reduction_buffer[None, None, 0],
- mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+ reduction_buffer[None, None, stage],
+ mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None,
+ phase=consumer_phase,
  init_val=0.0,
- hook_fn=cute.arch.cluster_wait
- if cutlass.const_expr(self.cluster_n > 1)
- else None,
  )
  / shape[1]
  )
-
- dx = (wdy - x_hat * mean_xhat_wdy) * rstd
- frgDx.store(dx.to(frgDout.element_type))
-
- if row < M:
- cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
-
  if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
- cute.arch.barrier()
-
- if row < M:
- dw_row = dout * x_hat
- current_dw = dw_accumulator.load().to(cute.Float32)
- updated_dw = current_dw + dw_row
- dw_accumulator.store(updated_dw.to(dw_accumulator.element_type))
-
- """
- if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
- cute.arch.barrier()
- """
- """
- if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
+ # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
+ # Requires adjusting the thread_count when initializing the mbar
+ cute.arch.sync_warp()
+ lane_idx = cute.arch.lane_idx()
+ if lane_idx < self.cluster_n:
+ cute.arch.mbarrier_arrive(
+ mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
+ )
+ dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+ tXrdX.store(dx.to(tXrdOut.element_type))
+ if row < M or tiler_mn[0] == 1:
+ tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
+ cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+ tXrdW.store(tXrdW.load() + dout * x_hat)
+ stage ^= 1
+ if stage == 0:
+ consumer_phase ^= 1
+ producer_phase ^= 1
+
+ if cutlass.const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
+
+ if cutlass.const_expr(tiler_mn[0] > 1):
+ # reduction of dw_partial within the same threadblock
+ sdW = cute.make_tensor(
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ tXsdW = thr_copy_X.partition_D(sdW)
  cute.arch.barrier()
- """
-
- cute.autovec_copy(dw_accumulator, tDwrDw)
- cute.copy(copy_atom_dw, tDwrDw, tDwgDw, pred=tDwpDw)
+ row = tXcX[None, None, None, 0][0][0]
+ if row > 0:
+ cute.autovec_copy(tXrdW, tXsdW)
+ cute.arch.barrier()
+ if row == 0:
+ for i in cutlass.range_constexpr(1, cutlass.const_expr(tiler_mn[0])):
+ tXrdW_other = cute.make_fragment_like(tXrdW)
+ tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
+ cute.autovec_copy(tXsdW_other, tXrdW_other)
+ tXrdW.store(tXrdW.load() + tXrdW_other.load())
+ cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
+ else:
+ cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
 
 
  def _rmsnorm_backward(
@@ -578,8 +642,19 @@ def _rmsnorm_backward(
 
  device = x.device
 
- sm_count = torch.cuda.get_device_properties(device).multi_processor_count * 8
- dw_partial = torch.zeros((sm_count, N), device=device, dtype=weight.dtype)
+ # This should be tuned on how many CTAs can be launched on each SM
+ sm_count_multiple = (
+ 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+ )
+ sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+ # By right, if we're using cluster, this should be cluster_count not sm_count.
+ # But for cluster >= 4, due to quantization we would need to query active max cluster.
+ # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
+ # avoid wave quantization.
+ sm_count = (
+ sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+ )
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=weight.dtype)
 
  dtype = torch2cute_dtype_map[x.dtype]
 
@@ -622,6 +697,7 @@ def _rmsnorm_backward(
  rstd_tensor,
  dx_tensor,
  dw_partial_tensor,
+ sm_count,
  current_stream,
  )
 
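The `_rmsnorm_backward` host changes above replace the fixed `sm_count * 8` grid with an N-dependent multiple of the SM count, scaling the number of persistent CTAs with how many are expected to fit per SM and over-subscribing slightly for clustered sizes to avoid wave quantization (per the comments in the hunk). Restated as a standalone helper so the schedule is easy to inspect (thresholds copied from the diff; the function name and the 132-SM example are ours, not part of quack):

```python
# Standalone restatement of the grid-size heuristic in _rmsnorm_backward.
def rmsnorm_bwd_grid_size(sm_count: int, N: int) -> int:
    multiple = 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
    if N <= 8192:
        return sm_count * multiple
    return sm_count // 2 if N <= 16384 else sm_count * 2

for n in (256, 2048, 8192, 16384, 65536):
    print(n, rmsnorm_bwd_grid_size(132, n))   # e.g. 256 -> 2112, 65536 -> 264
```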
quack/softmax.py CHANGED
@@ -98,7 +98,10 @@ class Softmax(ReductionBase):
  shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+ # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+ mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+ gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 
  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
@@ -130,9 +133,7 @@ class Softmax(ReductionBase):
 
  is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
  tXpX = (
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
- if cutlass.const_expr(not is_even_N)
- else None
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
  )
  if tXcX[0][0] < shape[0]:
  cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
@@ -312,9 +313,11 @@ class SoftmaxBackward(ReductionBase):
  shape = mdY.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gdY, gY, gdX, cX = [
- cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
+ mdY, mY, mdX = [
+ utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX)
  ]
+ gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 
  smem = cutlass.utils.SmemAllocator()
  sdY = smem.allocate_tensor(
quack/utils.py CHANGED
@@ -120,12 +120,20 @@ def cluster_reduce(
  reduction_buffer: cute.Tensor,
  mbar_ptr: cute.Pointer,
  init_val: cute.Numeric = 0.0,
+ phase: Optional[cutlass.Int32] = None,
  ) -> cute.Numeric:
  """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
  lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
- warps_per_row, cluster_n = reduction_buffer.shape[1]
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
  row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
  if lane_idx < cluster_n:
  store_shared_remote(
  val,
@@ -133,7 +141,7 @@
  mbar_ptr,
  peer_cta_rank_in_cluster=lane_idx,
  )
- cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
  block_reduce_val = init_val
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
  for i in cutlass.range_constexpr(num_iter):
@@ -149,13 +157,14 @@ def block_or_cluster_reduce(
  op: Callable,
  reduction_buffer: cute.Tensor,
  mbar_ptr: Optional[cute.Pointer],
+ phase: Optional[cutlass.Int32] = None,
  init_val: cute.Numeric = 0.0,
  ) -> cute.Numeric:
  """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
  if cutlass.const_expr(mbar_ptr is None):
  return block_reduce(val, op, reduction_buffer, init_val=init_val)
  else:
- return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val)
+ return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
 
 
  @cute.jit
@@ -165,6 +174,7 @@ def row_reduce(
  threads_per_row: cutlass.Constexpr[int],
  reduction_buffer: Optional[cute.Tensor] = None,
  mbar_ptr: Optional[cute.Pointer] = None,
+ phase: Optional[cutlass.Int32] = None,
  init_val: cute.Numeric = 0.0,
  hook_fn: Optional[Callable] = None,
  ) -> cute.Numeric:
@@ -193,7 +203,7 @@
  ), "mbar_ptr must be provided for cluster reduction"
  if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
  val = block_or_cluster_reduce(
- val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
+ val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
  )
  return val
 
@@ -205,6 +215,7 @@ def online_softmax_reduce(
  reduction_buffer: Optional[cute.Tensor] = None,
  mbar_ptr: Optional[cute.Pointer] = None,
  hook_fn: Optional[Callable] = None,
+ phase: Optional[cutlass.Int32] = None,
  return_exp_x: bool = False,
  ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
  assert x.dtype == Float32, "x must be of type Float32"
@@ -225,7 +236,7 @@
  if cutlass.const_expr(hook_fn is not None):
  hook_fn()
  if cutlass.const_expr(reduction_buffer is not None):
- warps_per_row, cluster_n = reduction_buffer.shape[1]
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
  assert (
  cluster_n == 1 or mbar_ptr is not None
  ), "mbar_ptr must be provided for cluster reduction"
@@ -251,6 +262,13 @@
  max_x = max_x_final
  else:
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
  if lane_idx < cluster_n:
  store_shared_remote(
  f32x2_to_i64(max_x, sum_exp_x),
@@ -258,7 +276,7 @@
  mbar_ptr,
  peer_cta_rank_in_cluster=lane_idx,
  )
- cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
  max_x_single_warp = cute.make_fragment(num_iter, Float32)
  max_x_single_warp.fill(-Float32.inf)
@@ -351,7 +369,7 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
 
 
  @cute.jit
- def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
+ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
  """Fill out-of-bounds values in shared memory tensor.
 
  Args:
@@ -361,9 +379,12 @@ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) ->
  """
  tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
  tXrX_fill.fill(fill_value)
- for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
- for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
- if not tXpX[rest_v, 0, rest_k]:
+ for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
+ for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
+ if cutlass.const_expr(tXpX is not None):
+ if not tXpX[rest_v, 0, rest_k]:
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+ else:
  cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
 
 
@@ -390,3 +411,38 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
  vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
  )
  return res0, res1
+
+
+ @dsl_user_op
+ def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+ flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
+ flat_stride = cute.flatten_to_tuple(tensor.stride)
+ assert len(flat_coord_i64) == len(
+ flat_stride
+ ), "Coordinate and stride must have the same length"
+ offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
+ assert isinstance(tensor.iterator, cute.Pointer)
+ # HACK: we assume that applying the offset does not change the pointer alignment
+ new_ptr = cute.make_ptr(
+ tensor.element_type,
+ tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+ tensor.memspace,
+ assumed_align=tensor.iterator.max_alignment,
+ )
+ return cute.make_tensor(new_ptr, tensor.layout)
+
+
+ @dsl_user_op
+ def coord_offset_i64(
+ idx: cute.typing.Int, tensor: cute.Tensor, dim: int, *, loc=None, ip=None
+ ) -> cute.Tensor:
+ offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
+ assert isinstance(tensor.iterator, cute.Pointer)
+ # HACK: we assume that applying the offset does not change the pointer alignment
+ new_ptr = cute.make_ptr(
+ tensor.element_type,
+ tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+ tensor.memspace,
+ assumed_align=tensor.iterator.max_alignment,
+ )
+ return cute.make_tensor(new_ptr, tensor.layout)
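Both new helpers rebase the tensor's pointer by an element offset computed in Int64 and scaled to bytes by the element width, leaving the layout untouched. A host-side model of the arithmetic (pure Python stand-in; the real helpers build a `cute.Pointer`, and the shapes below are made up):

```python
# What domain_offset_i64 computes, modeled on the host: dot(coord, stride) in
# 64-bit, then a byte offset for the new base pointer.
def byte_offset_i64(coord, stride, elem_bits):
    assert len(coord) == len(stride), "coordinate and stride must have the same length"
    elem_offset = sum(int(c) * int(s) for c, s in zip(coord, stride))  # Python ints never wrap
    return elem_offset * elem_bits // 8

# Row-major (M, N) bf16 tensor: advance the base by bidx * tile_m rows.
tile_m, bidx, N = 4, 600_000, 262_144
print(byte_offset_i64((bidx * tile_m, 0), (N, 1), 16))   # 1_258_291_200_000 bytes
```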
quack_kernels-0.1.7.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.5
+ Version: 0.1.7
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
quack_kernels-0.1.7.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
+ quack/__init__.py,sha256=R9cZd_vslI5oZjjS-ojfWAd9tCZAqsLUiFVqEbUaGnw,203
+ quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
+ quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+ quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
+ quack/rmsnorm.py,sha256=3jiwWhVmaG0n5vuUnGGrpg3StAB4lnzziNF97QVMLGQ,28870
+ quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
+ quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
+ quack_kernels-0.1.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.1.7.dist-info/METADATA,sha256=9RlqUmX3-7BI2aZk88r84B8o2FzZkQgkfV1UxwN8GlE,289
+ quack_kernels-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.1.7.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.1.7.dist-info/RECORD,,
quack_kernels-0.1.5.dist-info/RECORD DELETED
@@ -1,11 +0,0 @@
- quack/__init__.py,sha256=GPoImcynY5-OkMep5RhQhXrnZyxgqZG3RoHhsYQFSL4,203
- quack/cross_entropy.py,sha256=WkngPY8uk4RCjCFtHtB7h9GF_8xt4NnyvDzvw73gIL4,19320
- quack/reduction_base.py,sha256=fFuGXPR3lDq2yw_m86ujmkni6R51jzNAzy_r9R6C8tA,3563
- quack/rmsnorm.py,sha256=N9NavrR85ws4cZgkfpeRLjYkVSq2yfyzJQWvfKf98pY,23935
- quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
- quack/utils.py,sha256=6EyWgf0z3wcbhGUivHmWB8hVBnEzMyOhmAuZ2Te82k0,15226
- quack_kernels-0.1.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.1.5.dist-info/METADATA,sha256=WI-2CP1mRH05V9Fjdx7HsErNOkrc6fUhheoH4ynlo-U,289
- quack_kernels-0.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.1.5.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.1.5.dist-info/RECORD,,