quack-kernels 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.6"
+ __version__ = "0.1.8"
 
  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
quack/cross_entropy.py CHANGED
@@ -1,16 +1,16 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
  import math
- import torch
  from typing import Optional, Type
 
  import cuda.bindings.driver as cuda
 
  import cutlass
  import cutlass.cute as cute
- from cutlass.cute.runtime import from_dlpack
 
  import quack.utils as utils
+ import torch
+ from cutlass.cute.runtime import from_dlpack
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
 
@@ -79,7 +79,7 @@ class CrossEntropy(ReductionBase):
  self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+ cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -111,7 +111,9 @@ class CrossEntropy(ReductionBase):
 
  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
@@ -166,7 +168,9 @@ class CrossEntropy(ReductionBase):
  reduction_buffer[None, None, 0],
  mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
  init_val=-cutlass.Float32.inf,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=(
+ cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+ ),
  )
  if cutlass.const_expr(self.reload_from == "smem"):
  cute.autovec_copy(tXsX, tXrX)
@@ -191,7 +195,9 @@ class CrossEntropy(ReductionBase):
  threads_per_row,
  reduction_buffer[None, None, 0],
  mbar_ptr,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=(
+ cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+ ),
  )
 
  if (
@@ -225,7 +231,11 @@ def _cross_entropy(
  assert target.dim() == 1, "Target must be 1D"
  assert x.shape[0] == target.shape[0], "Batch dimensions must match"
  assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported input dtype"
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
  M, N = x.shape
  device = x.device
@@ -314,13 +324,16 @@ class CrossEntropyBackward:
  num_threads = cute.size(tv_layout, mode=[0])
 
  mDLoss = cute.make_tensor(
- mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+ mDLoss.iterator,
+ cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
  )
  mTarget = cute.make_tensor(
- mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+ mTarget.iterator,
+ cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
  )
  mLSE = cute.make_tensor(
- mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+ mLSE.iterator,
+ cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
  )
 
  smem_size = cute.size_in_bytes(
@@ -364,7 +377,9 @@ class CrossEntropyBackward:
 
  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )
 
  idX = cute.make_identity_tensor(shape)
@@ -474,7 +489,11 @@ def _cross_entropy_backward(
  assert (
  x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
  ), "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported input dtype"
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
 
  M, N = x.shape
@@ -532,15 +551,37 @@ class CrossEntropyFunction(torch.autograd.Function):
 
 
  def cross_entropy(
- x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+ x: torch.Tensor,
+ target: torch.Tensor,
+ inplace_backward: bool = True,
+ reduction: str = "none",
  ) -> torch.Tensor:
  """Cross entropy loss with automatic differentiation support.
 
  Args:
  x: Input logits tensor of shape (M, N)
  target: Target class indices tensor of shape (M,)
+ inplace_backward: Whether to perform backward pass in-place
+ reduction: Specifies the reduction to apply to the output:
+ 'none': no reduction will be applied (default)
+ 'mean': the sum of the output will be divided by the number of elements
+ 'sum': the output will be summed
 
  Returns:
- Cross entropy loss tensor of shape (M,)
+ Cross entropy loss tensor:
+ - If reduction='none': tensor of shape (M,) with per-example losses
+ - If reduction='mean': scalar tensor with mean loss
+ - If reduction='sum': scalar tensor with sum of losses
  """
- return CrossEntropyFunction.apply(x, target, inplace_backward)
+ loss = CrossEntropyFunction.apply(x, target, inplace_backward)
+
+ if reduction == "mean":
+ return loss.mean()
+ elif reduction == "sum":
+ return loss.sum()
+ elif reduction == "none":
+ return loss
+ else:
+ raise ValueError(
+ f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
+ )
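The reworked `cross_entropy` wrapper above reduces the per-row losses on the PyTorch side. A minimal usage sketch of the new `reduction` argument (the sizes, dtype, and device below are illustrative, and it assumes the quack kernels are installed on a supported CUDA GPU):

    import torch
    from quack.cross_entropy import cross_entropy

    logits = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    targets = torch.randint(0, 4096, (8,), device="cuda")

    per_row = cross_entropy(logits, targets, reduction="none")  # shape (8,), the previous behavior
    loss = cross_entropy(logits, targets, reduction="mean")     # scalar
    loss.backward()

Note that the default for `inplace_backward` also changed from False to True in this version; callers that need the old behavior should pass `inplace_backward=False` explicitly.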
quack/reduction_base.py CHANGED
@@ -68,7 +68,7 @@ class ReductionBase:
  )
 
  def _allocate_reduction_buffer_and_mbar(
- self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+ self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False
  ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
  reduction_buffer = smem.allocate_tensor(
  self.reduction_dtype,
@@ -76,20 +76,28 @@
  byte_alignment=4,
  )
  if cutlass.const_expr(self.cluster_n > 1):
- mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+ mbar_ptr = smem.allocate_array(
+ cutlass.Int64, num_elems=self.stage if not is_persistent else self.stage * 2
+ )
  else:
  mbar_ptr = None
  return reduction_buffer, mbar_ptr
 
  @cute.jit
- def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
+ def _initialize_cluster(
+ self,
+ tidx: cutlass.Int32,
+ mbar_ptr: cute.Pointer,
+ num_warps: int,
+ is_persistent: bool = False,
+ ):
  if cutlass.const_expr(self.cluster_n > 1):
- if tidx < self.stage:
+ if tidx < self.stage:  # Initialize full barrier
  cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+ if cutlass.const_expr(is_persistent):  # Initialize empty barrier
+ cute.arch.mbarrier_init(
+ mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
+ )
  cute.arch.mbarrier_init_fence()
- if tidx < self.stage:
- cute.arch.mbarrier_arrive_and_expect_tx(
- mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
- )
  # Cluster arrive after barrier init
  cute.arch.cluster_arrive_relaxed()
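With `is_persistent=True`, `_allocate_reduction_buffer_and_mbar` now carves out twice as many mbarriers (a "full" set followed by an "empty" set), and the arrive-and-expect-tx that used to live in `_initialize_cluster` moves into the reduction helpers in `quack/utils.py`. The persistent RMSNormBackward kernel in the next file cycles through the two stages with alternating phase bits; a plain-Python illustration of that bookkeeping (not CUTE DSL code, loop bound arbitrary):

    # stage picks one of the two buffers/barrier pairs; the phase bits flip each
    # time both stages have been used once, matching the mbarrier wait protocol.
    stage, consumer_phase, producer_phase = 0, 0, 1
    for _ in range(6):  # one iteration per row tile handled by a persistent CTA
        # wait on empty[stage] with producer_phase, reduce through full[stage]
        # with consumer_phase, then signal empty[stage] for the next round
        stage ^= 1
        if stage == 0:
            consumer_phase ^= 1
            producer_phase ^= 1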
quack/rmsnorm.py CHANGED
@@ -1,15 +1,16 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
-
- import torch
  from typing import Optional
 
  import cuda.bindings.driver as cuda
 
  import cutlass
  import cutlass.cute as cute
+ from cutlass import Float32, Int32
  from cutlass.cute.runtime import from_dlpack
+
  import quack.utils as utils
+ import torch
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
 
@@ -20,41 +21,55 @@ class RMSNorm(ReductionBase):
  self.delay_w_load = False
 
  def _calculate_threads_per_row(self):
+ """Calculate the number of threads per row for the RMSNorm kernel."""
  N = self.N
- return (
- 8
- if N <= 64
- else (
- 16
- if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
- )
- )
+ if N <= 64:
+ return 8
+ elif N <= 128:
+ return 16
+ elif N <= 3072:
+ return 32
+ elif N <= 6144:
+ return 64
+ elif N <= 16384:
+ return 128
+ else:
+ return 256
 
  def _set_cluster_n(self):
+ """
+ Set the number of clusters for the RMSNorm kernel.
+ Stored in self.cluster_n.
+ """
  N = self.N
+
  # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
  # Similarly cluster_n = 8 is faster for N=128k
  if cutlass.const_expr(self.dtype.width == 16):
- cluster_n = (
- 1
- if N <= 16 * 1024
- else (
- 2
- if N <= 32 * 1024
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
- )
- )
- else: # fp32
- cluster_n = (
- 1
- if N <= 32 * 1024
- else (
- 2
- if N <= 64 * 1024
- else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
- )
- )
+ # 16-bit types (fp16, bf16)
+ if N <= 16 * 1024:
+ cluster_n = 1
+ elif N <= 32 * 1024:
+ cluster_n = 2
+ elif N <= 64 * 1024:
+ cluster_n = 4
+ elif N <= 128 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+ else:
+ # 32-bit types (fp32)
+ if N <= 32 * 1024:
+ cluster_n = 1
+ elif N <= 64 * 1024:
+ cluster_n = 2
+ elif N <= 128 * 1024:
+ cluster_n = 4
+ elif N <= 256 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+
  self.cluster_n = cluster_n
 
  @cute.jit
@@ -65,8 +80,17 @@ class RMSNorm(ReductionBase):
  mO: cute.Tensor,
  mRstd: Optional[cute.Tensor],
  stream: cuda.CUstream,
- eps: cutlass.Float32 = 1e-6,
+ eps: Float32 = 1e-6,
  ):
+ semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
+ new_stride = lambda t: (
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
+ t.stride[1],
+ )
+ mX, mO = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ for t in (mX, mO)
+ ]
  assert mX.element_type == self.dtype
  assert mO.element_type == self.dtype
  self._set_cluster_n()
@@ -83,7 +107,7 @@
  self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+ cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -110,7 +134,9 @@ class RMSNorm(ReductionBase):
 
  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
@@ -157,6 +183,7 @@ class RMSNorm(ReductionBase):
 
  # allocate fragments for gmem->rmem
  tWrW = cute.make_fragment_like(tWgW)
+ tWrW.fill(0.0)
  tXrW = thr_copy_X.retile(tWrW)
  tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
 
@@ -184,7 +211,7 @@ class RMSNorm(ReductionBase):
  reduction_buffer[None, None, 0],
  mbar_ptr,
  init_val=0.0,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=(cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None),
  )
  rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
  if cutlass.const_expr(mRstd is not None):
@@ -232,25 +259,36 @@ def _rmsnorm_fwd(
  assert weight.dim() == 1, "Weight must be 1D"
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
- assert weight.dtype == torch.float32, "Weight must be float32"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported dtype"
+
+ assert weight.dtype in [
+ torch.float32,
+ torch.bfloat16,
+ torch.float16,
+ ], "Weight must be float32, float16 or bfloat16"
+
  M, N = x.shape
  device = x.device
  out = torch.empty_like(x)
  rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
  dtype = torch2cute_dtype_map[x.dtype]
+ # convert_from_dlpack = lambda x: (
+ # from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ # mode=0, divisibility=128 // dtype.width
+ # )
+ # )
  convert_from_dlpack = lambda x: (
- from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
- mode=0, stride_order=(0, 1)
- )
+ from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
  )
- x_tensor, out_tensor = [
- # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
- convert_from_dlpack(t)
- for t in (x, out)
- ]
+ x_tensor, out_tensor = [convert_from_dlpack(t) for t in (x, out)]
+ # handle weight divisibility based on weight dtype
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
  weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
  )
  rstd_tensor = (
  from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
@@ -258,7 +296,7 @@ def _rmsnorm_fwd(
  else None
  )
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
- compile_key = (dtype, N, rstd is not None)
+ compile_key = (dtype, N, rstd is not None, weight.dtype)
  if compile_key not in _rmsnorm_fwd.compile_cache:
  rmsnorm_op = RMSNorm(dtype, N)
  _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
@@ -300,8 +338,15 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
 
  class RMSNormBackward(ReductionBase):
  def __init__(self, dtype: cutlass.Numeric, N: int):
- # 1 stage for computing mean of x_hat * wdy
- super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+ # 2 stages for double buffering when computing mean of x_hat * wdy
+ super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
+ self.reload_wdy = None if N <= 16 * 1024 else "smem"
+ if self.N > 128 * 1024 and self.dtype.width >= 32:
+ # Not enough smem
+ raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
+
+ def _get_num_threads(self):
+ return 128 if self.N <= 4096 else 256
 
  def _calculate_threads_per_row(self):
  N = self.N
@@ -311,46 +356,49 @@ class RMSNormBackward(ReductionBase):
  else (
  16
  if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
  )
  )
 
  def _set_cluster_n(self):
  N = self.N
- if cutlass.const_expr(self.dtype.width == 16):
- cluster_n = (
- 1
- if N <= 16 * 1024
- else (
- 2
- if N <= 32 * 1024
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
- )
- )
- else: # fp32
- cluster_n = (
- 1
- if N <= 32 * 1024
- else (
- 2
- if N <= 64 * 1024
- else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
- )
- )
+ cluster_n = (
+ 1
+ if N <= 8 * 1024
+ else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)))
+ )
  self.cluster_n = cluster_n
 
+ def _smem_size_in_bytes(self, tiler_mn, num_warps):
+ return (
+ # Multiply by 2 since we need space for X and dOut,
+ # and multiply by another 2 due to double buffering
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 * 2
+ + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8) * 2 # mult 2 as we need 2 mbar per stage
+ )
+
  @cute.jit
  def __call__(
  self,
  mX: cute.Tensor,
  mW: cute.Tensor,
- mDout: cute.Tensor,
+ mdOut: cute.Tensor,
  mRstd: cute.Tensor,
- mDx: cute.Tensor,
- mDw: cute.Tensor,
- sm_count: cutlass.Constexpr,
+ mdX: cute.Tensor,
+ mdW: cute.Tensor,
+ sm_count: Int32,
  stream: cuda.CUstream,
  ):
+ semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
+ new_stride = lambda t: (
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
+ t.stride[1],
+ )
+ mX, mdOut, mdX = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ for t in (mX, mdOut, mdX)
+ ]
  self._set_cluster_n()
  tiler_mn, tv_layout = self._get_tv_layout()
  num_threads = cute.size(tv_layout, mode=[0])
@@ -359,14 +407,8 @@ class RMSNormBackward(ReductionBase):
  mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
  mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
 
- mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
- mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-
- num_blocks = (
- sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
- )
-
- self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+ num_blocks = sm_count
+ self.kernel(mX, mW, mdOut, mRstd, mdX, mdW, tv_layout, tiler_mn).launch(
  grid=[num_blocks, self.cluster_n, 1],
  block=[num_threads, 1, 1],
  cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
@@ -379,177 +421,260 @@ class RMSNormBackward(ReductionBase):
  self,
  mX: cute.Tensor,
  mW: cute.Tensor,
- mDout: cute.Tensor,
+ mdOut: cute.Tensor,
  mRstd: cute.Tensor,
- mDx: cute.Tensor,
- mDw: cute.Tensor,
- sm_count: cutlass.Constexpr,
+ mdX: cute.Tensor,
+ mdW: cute.Tensor,
  tv_layout: cute.Layout,
  tiler_mn: cute.Shape,
  ):
  tidx, _, _ = cute.arch.thread_idx()
- bidx, cluster_y, _ = cute.arch.block_idx()
+ bidx_start, _, _ = cute.arch.block_idx()
  gdim, _, _ = cute.arch.grid_dim()
+ if cutlass.const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = cutlass.const_expr(0)
 
  shape = mX.shape
  M, N = shape[0], shape[1]
+ is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
 
  idX = cute.make_identity_tensor(shape)
 
  smem = cutlass.utils.SmemAllocator()
- reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+ smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
+ sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
+ sdOut = smem.allocate_tensor(mdOut.element_type, smem_layout, byte_alignment=16)
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout, is_persistent=True
+ )
+ if cutlass.const_expr(mbar_ptr is not None):
+ mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
+ else:
+ mbar_full_ptr, mbar_empty_ptr = None, None
 
  copy_atom_load_X = cute.make_copy_atom(
  cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
  )
-
+ copy_atom_load_X_async = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+ )
  copy_atom_load_W = cute.make_copy_atom(
  cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
  )
-
  copy_atom_store_dX = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mDx.element_type, num_bits_per_copy=128
+ cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=128
  )
-
- copy_atom_dw = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+ copy_atom_store_dW = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=128
  )
 
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_X_async = cute.make_tiled_copy(
+ copy_atom_load_X_async, tv_layout, tiler_mn
+ ).get_slice(tidx)
  thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_dw = cute.make_tiled_copy(copy_atom_dw, tv_layout, tiler_mn).get_slice(tidx)
- thr_store_dx = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_dW = cute.make_tiled_copy(copy_atom_store_dW, tv_layout, tiler_mn).get_slice(tidx)
+ thr_store_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
 
- gW = cute.local_tile(mW, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
  tWgW = thr_copy_W.partition_S(gW)
  tWrW = cute.make_fragment_like(tWgW)
+ # Need this, otherwise rW can have arbitrary values that changes the reduction
+ if not is_even_N:
+ tWrW.fill(0.0)
  tXrW = thr_copy_X.retile(tWrW)
 
- gW_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
-
- tWpW = utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+ gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+ tWpW = (
+ utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+ if not is_even_N
+ else None
+ )
  cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
  weight = tXrW.load().to(cute.Float32)
 
  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
 
- self._initialize_cluster(tidx, mbar_ptr, num_warps)
-
- dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
- tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+ self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
 
- gDw = cute.local_tile(mDw, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
- tDwgDw = thr_copy_dw.partition_D(gDw)
- tDwrDw = cute.make_fragment_like(tDwgDw)
- dw_accumulator = thr_copy_X.retile(tDwrDw)
- dw_accumulator.fill(0.0)
-
- M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+ dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+ tdWpdW = (
+ utils.predicate_k(thr_copy_dW.partition_S(dw_coord), limit=shape[1])
+ if not is_even_N
+ else None
+ )
 
- jump = sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+ gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
+ tdWgdW = thr_copy_dW.partition_D(gdW)
+ # Always compute partial weight gradients in fp32
+ tdWrdW = cute.make_fragment_like(tdWgdW, Float32)
+ tXrdW = thr_copy_X.retile(tdWrdW)
 
- if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
+ gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
+ gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
+ gdX = cute.local_tile(mdX, tiler_mn, (None, cluster_y))
+ cX = cute.local_tile(idX, tiler_mn, (None, cluster_y))
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXgdOut = thr_copy_X.partition_S(gdOut)
+ tXsdOut = thr_copy_X.partition_D(sdOut)
+ tXgdX = thr_store_dX.partition_D(gdX)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
+ # This doesn't change across iterations
+ tXpX = (
+ utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+ if not is_even_N
+ else None
+ )
 
- ## need to update range_dynamic since it will be deprecated soon
- for row_offset in cutlass.range_dynamic(bidx, M_pad, jump):
- gX = cute.local_tile(
- mX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
- )
- gDout = cute.local_tile(
- mDout, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
- )
- gRstd = cute.local_tile(
- mRstd, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ tXrX, tXrdOut, tXrdX = [
+ cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdOut, tXgdX)
+ ]
+
+ # Prefetch the first batch
+ row = tXcX[None, None, None, bidx_start][0][0]
+ if row < M:
+ tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
+ tXgdOut_cur = utils.coord_offset_i64(bidx_start, tXgdOut, dim=3)[None, None, None, 0]
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgX_cur,
+ tXsX[None, None, None, 0],
+ pred=tXpX,
  )
- gDx = cute.local_tile(
- mDx, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgdOut_cur,
+ tXsdOut[None, None, None, 0],
+ pred=tXpX,
  )
- cX = cute.local_tile(
- idX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
- )
-
- tXgX = thr_copy_X.partition_S(gX)
- thrDout = thr_copy_X.partition_S(gDout)
- tXrRstd = thr_copy_W.partition_S(gRstd)
- thrDx = thr_store_dx.partition_D(gDx)
- tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
- tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
-
- tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ elif tiler_mn[0] > 1:
+ # Fill with zero, otherwise smem will be uninitialized, and we could read this back
+ # later into registers, causing wrong dW.
+ utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+ utils.fill_oob(tXsdOut[None, None, None, 0], None, fill_value=mdOut.element_type.zero)
+ cute.arch.cp_async_commit_group()
 
- if tXcX[0][0] < shape[0]:
- cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
- cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_wait()
 
+ threads_per_row = tv_layout.shape[0][0]
+ tXrdW.fill(0.0)
+ stage = Int32(0)
+ producer_phase = Int32(1)
+ consumer_phase = Int32(0)
+ for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
+ row = tXcX[None, None, None, bidx][0][0]
+ rstd = cutlass.Float.zero
+ if row + gdim * tiler_mn[0] < M: # Prefetch the next batch
+ tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
+ tXgdOut_cur = utils.coord_offset_i64(bidx + gdim, tXgdOut, dim=3)[
+ None, None, None, 0
+ ]
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgX_cur,
+ tXsX[None, None, None, stage ^ 1],
+ pred=tXpX,
+ )
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgdOut_cur,
+ tXsdOut[None, None, None, stage ^ 1],
+ pred=tXpX,
+ )
+ elif tiler_mn[0] > 1:
+ utils.fill_oob(
+ tXsX[None, None, None, stage ^ 1],
+ None,
+ fill_value=mX.element_type.zero,
+ )
+ utils.fill_oob(
+ tXsdOut[None, None, None, stage ^ 1],
+ None,
+ fill_value=mdOut.element_type.zero,
+ )
+ cute.arch.cp_async_commit_group()
+ if row < M or tiler_mn[0] == 1:
+ rstd = mRstd[row]
+ cute.arch.cp_async_wait_group(1)
+ cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
  x = tXrX.load().to(cute.Float32)
- dout = frgDout.load().to(cute.Float32)
-
- rstd = tXrRstd[0]
+ cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+ dout = tXrdOut.load().to(cute.Float32)
  x_hat = x * rstd
  wdy = dout * weight
-
- threads_per_row = tv_layout.shape[0][0]
-
- row = tXcX[0][0]
  if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
- cute.arch.barrier()
-
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
  mean_xhat_wdy = (
  utils.row_reduce(
  x_hat * wdy,
  cute.ReductionOp.ADD,
  threads_per_row,
- reduction_buffer[None, None, 0],
- mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+ reduction_buffer[None, None, stage],
+ (mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None),
+ phase=consumer_phase,
  init_val=0.0,
- hook_fn=cute.arch.cluster_wait
- if cutlass.const_expr(self.cluster_n > 1)
- else None,
  )
  / shape[1]
  )
 
- dx = (wdy - x_hat * mean_xhat_wdy) * rstd
- frgDx.store(dx.to(frgDout.element_type))
-
- if row < M:
- cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
-
  if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
- cute.arch.barrier()
-
- if row < M:
- dw_row = dout * x_hat
- current_dw = dw_accumulator.load().to(cute.Float32)
- updated_dw = current_dw + dw_row
- dw_accumulator.store(updated_dw.to(dw_accumulator.element_type))
+ # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
+ # Requires adjusting the thread_count when initializing the mbar
+ cute.arch.sync_warp()
+ lane_idx = cute.arch.lane_idx()
+ if lane_idx < self.cluster_n:
+ cute.arch.mbarrier_arrive(
+ mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
+ )
+
+ if cutlass.const_expr(self.reload_wdy == "smem"):
+ cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+ dout = tXrdOut.load().to(cute.Float32)
+ wdy = dout * weight
 
- """
- if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
- cute.arch.barrier()
- """
- """
- if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
+ dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+ tXrdX.store(dx.to(tXrdOut.element_type))
+ if row < M or tiler_mn[0] == 1:
+ tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
+ cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+ # Accumulate weight gradients in fp32
+ tXrdW.store(tXrdW.load() + dout * x_hat)
+
+ stage ^= 1
+ if stage == 0:
+ consumer_phase ^= 1
+ producer_phase ^= 1
+
+ if cutlass.const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
+
+ if cutlass.const_expr(tiler_mn[0] > 1):
+ # reduction of dw_partial within the same threadblock
+ sdW = cute.make_tensor(
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ tXsdW = thr_copy_X.partition_D(sdW)
  cute.arch.barrier()
- """
+ row = tXcX[None, None, None, 0][0][0]
+ if row > 0:
+ cute.autovec_copy(tXrdW, tXsdW)
+ cute.arch.barrier()
+ if row == 0:
+ for i in cutlass.range_constexpr(1, cutlass.const_expr(tiler_mn[0])):
+ tXrdW_other = cute.make_fragment_like(tXrdW)
+ tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
+ cute.autovec_copy(tXsdW_other, tXrdW_other)
+ tXrdW.store(tXrdW.load() + tXrdW_other.load())
+ cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
 
- cute.autovec_copy(dw_accumulator, tDwrDw)
- cute.copy(copy_atom_dw, tDwrDw, tDwgDw, pred=tDwpDw)
+ else:
+ # dw is already in fp32, so we can directly copy to global memory
+ cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
 
 
  def _rmsnorm_backward(
@@ -573,37 +698,58 @@ def _rmsnorm_backward(
  assert weight.dim() == 1, "Weight must be 1D"
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
- assert weight.dtype == torch.float32, "Weight must be float32"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported dtype"
+
+ assert weight.dtype in [
+ torch.float32,
+ torch.bfloat16,
+ torch.float16,
+ ], "Weight must be float32, float16 or bfloat16"
 
  M, N = x.shape
  dx = torch.empty_like(x)
 
  device = x.device
 
- sm_count = torch.cuda.get_device_properties(device).multi_processor_count * 8
- dw_partial = torch.zeros((sm_count, N), device=device, dtype=weight.dtype)
+ # This should be tuned on how many CTAs can be launched on each SM
+ sm_count_multiple = (
+ 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+ )
+ sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+ # By right, if we're using cluster, this should be cluster_count not sm_count.
+ # But for cluster >= 4, due to quantization we would need to query active max cluster.
+ # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
+ # avoid wave quantization.
+ sm_count = (
+ sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+ )
+
+ # Always store partial gradients in fp32 for numerical accuracy
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
 
  dtype = torch2cute_dtype_map[x.dtype]
 
- convert_from_dlpack = lambda tensor: (
- from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
- mode=0, stride_order=(0, 1)
- )
+ convert_from_dlpack = lambda x: (
+ from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
  )
-
  x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
 
+ # Handle weight div based on weight dtype
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
  weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
  )
 
- dw_partial_tensor = convert_from_dlpack(dw_partial)
+ dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
  rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
 
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
- compile_key = (dtype, N)
+ compile_key = (dtype, N, weight.dtype)
  if compile_key not in _rmsnorm_backward.compile_cache:
  rmsnorm_backward_op = RMSNormBackward(dtype, N)
  _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
@@ -625,9 +771,10 @@ def _rmsnorm_backward(
  rstd_tensor,
  dx_tensor,
  dw_partial_tensor,
+ sm_count,
  current_stream,
  )
-
+ # we have summed the partial gradients in fp32, now we convert back to the weight dtype
  dw = dw_partial.sum(dim=0).to(weight.dtype)
  return dx, dw
 
@@ -638,16 +785,29 @@ _rmsnorm_backward.compile_cache = {}
  class RMSNormFunction(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x, weight, eps):
+ x_shape_start = x.shape
+
+ # Flatten input
+ x = x.view(-1, x.shape[-1])
+
  out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
  ctx.save_for_backward(x, weight, rstd)
  ctx.eps = eps
- return out
+ ctx.x_shape_start = x_shape_start
+
+ return out.reshape(x_shape_start)
 
  @staticmethod
  def backward(ctx, dout):
  x, weight, rstd = ctx.saved_tensors
+ x_shape_start = ctx.x_shape_start
+ # Reshape dout to match the flattened shape used in forward
+ dout = dout.view(-1, dout.shape[-1])
  dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
- # dw is returned for weight gradient, None for eps gradient
+ dx = dx.view(x_shape_start)
+ # dx is returned for input gradient,
+ # dw is returned for weight gradient,
+ # None for eps gradient
  return dx, dw, None
 
 
@@ -663,3 +823,39 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.T
  Normalized output tensor of same shape as x
  """
  return RMSNormFunction.apply(x, weight, eps)
+
+
+ class QuackRMSNorm(torch.nn.Module):
+ """RMSNorm module that behaves like torch.nn.RMSNorm.
+
+ This class provides a drop-in replacement for torch.nn.RMSNorm that uses
+ the quack.rmsnorm implementation under the hood.
+
+ Args:
+ dim (int): The dimension to normalize over
+ eps (float, optional): A small constant for numerical stability. Default: 1e-6
+
+ Attributes:
+ weight (torch.nn.Parameter): The learnable weight parameter
+ eps (float): A small constant for numerical stability
+ """
+
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply RMSNorm to the input tensor.
+
+ Args:
+ x (torch.Tensor): Input tensor
+
+ Returns:
+ torch.Tensor: Normalized tensor
+ """
+ return rmsnorm(x, self.weight, self.eps)
+
+ def reset_parameters(self):
+ """Reset the weight parameter to ones."""
+ torch.nn.init.ones_(self.weight)
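For reference, a short usage sketch of the new `QuackRMSNorm` module and the relaxed weight dtypes (sizes and device are illustrative; it assumes the quack kernels are installed on a supported CUDA GPU):

    import torch
    from quack.rmsnorm import QuackRMSNorm, rmsnorm

    norm = QuackRMSNorm(dim=4096, eps=1e-6).cuda()   # weight is an fp32 Parameter of ones
    x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.bfloat16)
    y = norm(x)            # input is flattened to 2D internally and reshaped back
    y.sum().backward()     # dW is accumulated in fp32 partials, then cast to weight.dtype

    # Functional form; the weight may now be fp32, fp16, or bf16 (previously fp32 only).
    w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
    y2 = rmsnorm(x, w, eps=1e-6)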
quack/softmax.py CHANGED
@@ -133,9 +133,7 @@ class Softmax(ReductionBase):
 
  is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
  tXpX = (
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
- if cutlass.const_expr(not is_even_N)
- else None
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
  )
  if tXcX[0][0] < shape[0]:
  cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
quack/utils.py CHANGED
@@ -120,12 +120,20 @@ def cluster_reduce(
  reduction_buffer: cute.Tensor,
  mbar_ptr: cute.Pointer,
  init_val: cute.Numeric = 0.0,
+ phase: Optional[cutlass.Int32] = None,
  ) -> cute.Numeric:
  """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
  lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
- warps_per_row, cluster_n = reduction_buffer.shape[1]
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
  row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
  if lane_idx < cluster_n:
  store_shared_remote(
  val,
@@ -133,7 +141,7 @@ def cluster_reduce(
  mbar_ptr,
  peer_cta_rank_in_cluster=lane_idx,
  )
- cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
  block_reduce_val = init_val
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
  for i in cutlass.range_constexpr(num_iter):
@@ -149,13 +157,14 @@ def block_or_cluster_reduce(
  op: Callable,
  reduction_buffer: cute.Tensor,
  mbar_ptr: Optional[cute.Pointer],
+ phase: Optional[cutlass.Int32] = None,
  init_val: cute.Numeric = 0.0,
  ) -> cute.Numeric:
  """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
  if cutlass.const_expr(mbar_ptr is None):
  return block_reduce(val, op, reduction_buffer, init_val=init_val)
  else:
- return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val)
+ return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
 
 
  @cute.jit
@@ -165,6 +174,7 @@ def row_reduce(
  threads_per_row: cutlass.Constexpr[int],
  reduction_buffer: Optional[cute.Tensor] = None,
  mbar_ptr: Optional[cute.Pointer] = None,
+ phase: Optional[cutlass.Int32] = None,
  init_val: cute.Numeric = 0.0,
  hook_fn: Optional[Callable] = None,
  ) -> cute.Numeric:
@@ -193,7 +203,7 @@ def row_reduce(
  ), "mbar_ptr must be provided for cluster reduction"
  if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
  val = block_or_cluster_reduce(
- val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
+ val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
  )
  return val
 
@@ -205,6 +215,7 @@ def online_softmax_reduce(
  reduction_buffer: Optional[cute.Tensor] = None,
  mbar_ptr: Optional[cute.Pointer] = None,
  hook_fn: Optional[Callable] = None,
+ phase: Optional[cutlass.Int32] = None,
  return_exp_x: bool = False,
  ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
  assert x.dtype == Float32, "x must be of type Float32"
@@ -225,7 +236,7 @@ def online_softmax_reduce(
  if cutlass.const_expr(hook_fn is not None):
  hook_fn()
  if cutlass.const_expr(reduction_buffer is not None):
- warps_per_row, cluster_n = reduction_buffer.shape[1]
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
  assert (
  cluster_n == 1 or mbar_ptr is not None
  ), "mbar_ptr must be provided for cluster reduction"
@@ -251,6 +262,13 @@ def online_softmax_reduce(
  max_x = max_x_final
  else:
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
  if lane_idx < cluster_n:
  store_shared_remote(
  f32x2_to_i64(max_x, sum_exp_x),
@@ -258,7 +276,7 @@ def online_softmax_reduce(
  mbar_ptr,
  peer_cta_rank_in_cluster=lane_idx,
  )
- cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
  max_x_single_warp = cute.make_fragment(num_iter, Float32)
  max_x_single_warp.fill(-Float32.inf)
@@ -351,7 +369,7 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
 
 
  @cute.jit
- def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
+ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
  """Fill out-of-bounds values in shared memory tensor.
 
  Args:
@@ -361,9 +379,12 @@ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) ->
  """
  tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
  tXrX_fill.fill(fill_value)
- for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
- for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
- if not tXpX[rest_v, 0, rest_k]:
+ for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
+ for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
+ if cutlass.const_expr(tXpX is not None):
+ if not tXpX[rest_v, 0, rest_k]:
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+ else:
  cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
 
 
@@ -396,6 +417,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
  def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
  flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
  flat_stride = cute.flatten_to_tuple(tensor.stride)
+ assert len(flat_coord_i64) == len(
+ flat_stride
+ ), "Coordinate and stride must have the same length"
  offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
  assert isinstance(tensor.iterator, cute.Pointer)
  # HACK: we assume that applying the offset does not change the pointer alignment
@@ -406,3 +430,19 @@ def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=No
  assumed_align=tensor.iterator.max_alignment,
  )
  return cute.make_tensor(new_ptr, tensor.layout)
+
+
+ @dsl_user_op
+ def coord_offset_i64(
+ idx: cute.typing.Int, tensor: cute.Tensor, dim: int, *, loc=None, ip=None
+ ) -> cute.Tensor:
+ offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
+ assert isinstance(tensor.iterator, cute.Pointer)
+ # HACK: we assume that applying the offset does not change the pointer alignment
+ new_ptr = cute.make_ptr(
+ tensor.element_type,
+ tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+ tensor.memspace,
+ assumed_align=tensor.iterator.max_alignment,
+ )
+ return cute.make_tensor(new_ptr, tensor.layout)
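The new `coord_offset_i64` helper advances a tensor's base pointer by `idx` strides along one dimension, doing the index arithmetic in 64-bit (as the `_i64` suffix suggests) and converting the element offset to bytes via the element width. A worked example of that arithmetic in plain Python, with hypothetical numbers:

    # Advancing dim=0 of a row-major (M, N) fp16 tensor by idx rows.
    idx, N, element_bits = 3, 8192, 16       # fp16 is 16 bits per element
    stride_dim0 = N                          # elements between consecutive rows
    offset_elems = idx * stride_dim0         # 24576, computed as an Int64 in the DSL
    offset_bytes = offset_elems * element_bits // 8
    assert offset_bytes == 49152             # the pointer is advanced by this many bytes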
quack_kernels-0.1.8.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.6
+ Version: 0.1.8
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
quack_kernels-0.1.8.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
+ quack/__init__.py,sha256=tDgX5MF1ttfEyDVFWi47DA8tDooYcBQlkuzvabGUoQI,203
+ quack/cross_entropy.py,sha256=VYSAd28GmtnMoKQwLrorvySDtJfRhoqVd-aeM52FmsI,20866
+ quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+ quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
+ quack/rmsnorm.py,sha256=-qrKqPKk0fUuq0a5-vJmZZ7nQsHgyaqTg0EKhWT44r0,32738
+ quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
+ quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
+ quack_kernels-0.1.8.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.1.8.dist-info/METADATA,sha256=b_2PxFEoVqWJbT2FtuP9FJyF-jpL2Z3q9OHoOEipqo4,289
+ quack_kernels-0.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.1.8.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.1.8.dist-info/RECORD,,
quack_kernels-0.1.6.dist-info/RECORD REMOVED
@@ -1,12 +0,0 @@
- quack/__init__.py,sha256=EV_43VfcxQUsFu_Nfq0944ImloZ2T9X94-8T9IQM0I8,203
- quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
- quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
- quack/reduction_base.py,sha256=fFuGXPR3lDq2yw_m86ujmkni6R51jzNAzy_r9R6C8tA,3563
- quack/rmsnorm.py,sha256=rbRP_O-EJYvrvxYPFwfy0yCbG7Qs_DgbzgacxTLvih4,24159
- quack/softmax.py,sha256=b-QEQiGjOEPidolE2--K21kwEaLJ-3wQoGIz_BsEcSI,16742
- quack/utils.py,sha256=laz_lqeggiIOY_xCs3c3VvCPP8bJmSKBfjFpFR1Al80,15946
- quack_kernels-0.1.6.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.1.6.dist-info/METADATA,sha256=NKhSoudW9lNdYryNiMkyjKXYl5n37cuFMbyLv3zr5js,289
- quack_kernels-0.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.1.6.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.1.6.dist-info/RECORD,,