quack-kernels 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.7"
+ __version__ = "0.1.9"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
quack/cross_entropy.py CHANGED
@@ -1,16 +1,16 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

  import math
- import torch
  from typing import Optional, Type

  import cuda.bindings.driver as cuda

  import cutlass
  import cutlass.cute as cute
- from cutlass.cute.runtime import from_dlpack

  import quack.utils as utils
+ import torch
+ from cutlass.cute.runtime import from_dlpack
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map


@@ -79,7 +79,7 @@ class CrossEntropy(ReductionBase):
  self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+ cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -111,7 +111,9 @@ class CrossEntropy(ReductionBase):

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

@@ -166,7 +168,9 @@ class CrossEntropy(ReductionBase):
  reduction_buffer[None, None, 0],
  mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
  init_val=-cutlass.Float32.inf,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=(
+ cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+ ),
  )
  if cutlass.const_expr(self.reload_from == "smem"):
  cute.autovec_copy(tXsX, tXrX)
@@ -191,7 +195,9 @@ class CrossEntropy(ReductionBase):
  threads_per_row,
  reduction_buffer[None, None, 0],
  mbar_ptr,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=(
+ cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+ ),
  )

  if (
@@ -225,7 +231,11 @@ def _cross_entropy(
  assert target.dim() == 1, "Target must be 1D"
  assert x.shape[0] == target.shape[0], "Batch dimensions must match"
  assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported input dtype"
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
  M, N = x.shape
  device = x.device
@@ -314,13 +324,16 @@ class CrossEntropyBackward:
  num_threads = cute.size(tv_layout, mode=[0])

  mDLoss = cute.make_tensor(
- mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+ mDLoss.iterator,
+ cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
  )
  mTarget = cute.make_tensor(
- mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+ mTarget.iterator,
+ cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
  )
  mLSE = cute.make_tensor(
- mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+ mLSE.iterator,
+ cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
  )

  smem_size = cute.size_in_bytes(
@@ -364,7 +377,9 @@ class CrossEntropyBackward:

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )

  idX = cute.make_identity_tensor(shape)
@@ -474,7 +489,11 @@ def _cross_entropy_backward(
  assert (
  x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
  ), "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported input dtype"
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"

  M, N = x.shape
@@ -532,15 +551,37 @@ class CrossEntropyFunction(torch.autograd.Function):


  def cross_entropy(
- x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+ x: torch.Tensor,
+ target: torch.Tensor,
+ inplace_backward: bool = True,
+ reduction: str = "none",
  ) -> torch.Tensor:
  """Cross entropy loss with automatic differentiation support.

  Args:
  x: Input logits tensor of shape (M, N)
  target: Target class indices tensor of shape (M,)
+ inplace_backward: Whether to perform backward pass in-place
+ reduction: Specifies the reduction to apply to the output:
+ 'none': no reduction will be applied (default)
+ 'mean': the sum of the output will be divided by the number of elements
+ 'sum': the output will be summed

  Returns:
- Cross entropy loss tensor of shape (M,)
+ Cross entropy loss tensor:
+ - If reduction='none': tensor of shape (M,) with per-example losses
+ - If reduction='mean': scalar tensor with mean loss
+ - If reduction='sum': scalar tensor with sum of losses
  """
- return CrossEntropyFunction.apply(x, target, inplace_backward)
+ loss = CrossEntropyFunction.apply(x, target, inplace_backward)
+
+ if reduction == "mean":
+ return loss.mean()
+ elif reduction == "sum":
+ return loss.sum()
+ elif reduction == "none":
+ return loss
+ else:
+ raise ValueError(
+ f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
+ )
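
The hunk above changes cross_entropy so that inplace_backward now defaults to True and adds a reduction argument accepting 'none', 'mean', and 'sum'. A minimal usage sketch of the new signature follows; the CUDA device, shapes, and dtypes are illustrative assumptions, not taken from the package.

# Sketch only: assumes a CUDA GPU and the 0.1.9 signature shown in the hunk above.
import torch
from quack.cross_entropy import cross_entropy

logits = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 4096, (8,), device="cuda", dtype=torch.int64)

per_example = cross_entropy(logits, target)                  # reduction="none" (default): shape (8,)
mean_loss = cross_entropy(logits, target, reduction="mean")  # scalar tensor
mean_loss.backward()                                         # inplace_backward defaults to True in 0.1.9
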
quack/rmsnorm.py CHANGED
@@ -1,14 +1,16 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

- import torch
  from typing import Optional

  import cuda.bindings.driver as cuda

  import cutlass
  import cutlass.cute as cute
+ from cutlass import Float32, Int32
  from cutlass.cute.runtime import from_dlpack
+
  import quack.utils as utils
+ import torch
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map


@@ -19,41 +21,55 @@ class RMSNorm(ReductionBase):
  self.delay_w_load = False

  def _calculate_threads_per_row(self):
+ """Calculate the number of threads per row for the RMSNorm kernel."""
  N = self.N
- return (
- 8
- if N <= 64
- else (
- 16
- if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
- )
- )
+ if N <= 64:
+ return 8
+ elif N <= 128:
+ return 16
+ elif N <= 3072:
+ return 32
+ elif N <= 6144:
+ return 64
+ elif N <= 16384:
+ return 128
+ else:
+ return 256

  def _set_cluster_n(self):
+ """
+ Set the number of clusters for the RMSNorm kernel.
+ Stored in self.cluster_n.
+ """
  N = self.N
+
  # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
  # Similarly cluster_n = 8 is faster for N=128k
  if cutlass.const_expr(self.dtype.width == 16):
- cluster_n = (
- 1
- if N <= 16 * 1024
- else (
- 2
- if N <= 32 * 1024
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
- )
- )
- else: # fp32
- cluster_n = (
- 1
- if N <= 32 * 1024
- else (
- 2
- if N <= 64 * 1024
- else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
- )
- )
+ # 16-bit types (fp16, bf16)
+ if N <= 16 * 1024:
+ cluster_n = 1
+ elif N <= 32 * 1024:
+ cluster_n = 2
+ elif N <= 64 * 1024:
+ cluster_n = 4
+ elif N <= 128 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+ else:
+ # 32-bit types (fp32)
+ if N <= 32 * 1024:
+ cluster_n = 1
+ elif N <= 64 * 1024:
+ cluster_n = 2
+ elif N <= 128 * 1024:
+ cluster_n = 4
+ elif N <= 256 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+
  self.cluster_n = cluster_n

  @cute.jit
@@ -64,8 +80,17 @@ class RMSNorm(ReductionBase):
  mO: cute.Tensor,
  mRstd: Optional[cute.Tensor],
  stream: cuda.CUstream,
- eps: cutlass.Float32 = 1e-6,
+ eps: Float32 = 1e-6,
  ):
+ semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
+ new_stride = lambda t: (
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
+ t.stride[1],
+ )
+ mX, mO = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ for t in (mX, mO)
+ ]
  assert mX.element_type == self.dtype
  assert mO.element_type == self.dtype
  self._set_cluster_n()
@@ -82,7 +107,7 @@ class RMSNorm(ReductionBase):
  self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+ cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -109,7 +134,9 @@ class RMSNorm(ReductionBase):

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

@@ -134,30 +161,33 @@ class RMSNorm(ReductionBase):
  copy_atom_load_X_async = cute.make_copy_atom(
  cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
  )
+ num_bits_per_copy_W = cutlass.const_expr(
+ min(128, 128 // mX.element_type.width * mW.element_type.width)
+ )
  copy_atom_load_W = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+ cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
+ )
+ num_bits_per_copy_O = cutlass.const_expr(
+ min(128, 128 // mX.element_type.width * mO.element_type.width)
  )
  copy_atom_store_O = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
+ cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_bits_per_copy_O
  )

  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
  tidx
  )
- thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)

- tWgW = thr_copy_W.partition_S(gW)
+ tXgW = thr_copy_X.partition_S(gW)
  tXgX = thr_copy_X.partition_S(gX)
  tXsX = thr_copy_X.partition_D(sX)
- tXgO = thr_copy_O.partition_D(gO)
- tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+ tXgO = thr_copy_X.partition_D(gO)
+ tXrRstd = thr_copy_X.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]

  # allocate fragments for gmem->rmem
- tWrW = cute.make_fragment_like(tWgW)
- tWrW.fill(0.0)
- tXrW = thr_copy_X.retile(tWrW)
+ tXrW = cute.make_fragment_like(tXgW)
+ tXrW.fill(0.0)
  tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]

  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
@@ -169,9 +199,9 @@ class RMSNorm(ReductionBase):
  cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
  cute.arch.cp_async_commit_group()

- tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
+ tXpW = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
  if cutlass.const_expr(not delay_w_load):
- cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+ cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)

  cute.arch.cp_async_wait_group(0)
  cute.autovec_copy(tXsX, tXrX)
@@ -184,7 +214,7 @@ class RMSNorm(ReductionBase):
  reduction_buffer[None, None, 0],
  mbar_ptr,
  init_val=0.0,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=(cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None),
  )
  rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
  if cutlass.const_expr(mRstd is not None):
@@ -196,7 +226,7 @@ class RMSNorm(ReductionBase):
  ):
  tXrRstd[0] = rstd
  if cutlass.const_expr(delay_w_load):
- cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+ cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
  if cutlass.const_expr(reload_from == "smem"):
  cute.autovec_copy(tXsX, tXrX)
  x = tXrX.load().to(cute.Float32)
@@ -207,9 +237,9 @@ class RMSNorm(ReductionBase):
  w = tXrW.load().to(cute.Float32)
  y = x_hat * w
  tXrO.store(y.to(tXrO.element_type))
- tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+ tXpO = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
  if row < shape[0]:
- cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+ cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpO)


  def _rmsnorm_fwd(
@@ -232,25 +262,36 @@ def _rmsnorm_fwd(
  assert weight.dim() == 1, "Weight must be 1D"
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
- assert weight.dtype == torch.float32, "Weight must be float32"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported dtype"
+
+ assert weight.dtype in [
+ torch.float32,
+ torch.bfloat16,
+ torch.float16,
+ ], "Weight must be float32, float16 or bfloat16"
+
  M, N = x.shape
  device = x.device
  out = torch.empty_like(x)
  rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
  dtype = torch2cute_dtype_map[x.dtype]
+ # convert_from_dlpack = lambda x: (
+ # from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ # mode=0, divisibility=128 // dtype.width
+ # )
+ # )
  convert_from_dlpack = lambda x: (
- from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
- mode=0, stride_order=(0, 1)
- )
+ from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
  )
- x_tensor, out_tensor = [
- # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
- convert_from_dlpack(t)
- for t in (x, out)
- ]
+ x_tensor, out_tensor = [convert_from_dlpack(t) for t in (x, out)]
+ # handle weight divisibility based on weight dtype
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
  weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
  )
  rstd_tensor = (
  from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
@@ -258,7 +299,7 @@ def _rmsnorm_fwd(
  else None
  )
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
- compile_key = (dtype, N, rstd is not None)
+ compile_key = (dtype, N, rstd is not None, weight.dtype)
  if compile_key not in _rmsnorm_fwd.compile_cache:
  rmsnorm_op = RMSNorm(dtype, N)
  _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
@@ -301,7 +342,8 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
  class RMSNormBackward(ReductionBase):
  def __init__(self, dtype: cutlass.Numeric, N: int):
  # 2 stages for double buffering when computing mean of x_hat * wdy
- super().__init__(dtype, N, stage=2, reduction_dtype=cutlass.Float32)
+ super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
+ self.reload_wdy = None if N <= 16 * 1024 else "smem"
  if self.N > 128 * 1024 and self.dtype.width >= 32:
  # Not enough smem
  raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
@@ -348,9 +390,18 @@ class RMSNormBackward(ReductionBase):
  mRstd: cute.Tensor,
  mdX: cute.Tensor,
  mdW: cute.Tensor,
- sm_count: cutlass.Int32,
+ sm_count: Int32,
  stream: cuda.CUstream,
  ):
+ semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
+ new_stride = lambda t: (
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
+ t.stride[1],
+ )
+ mX, mdOut, mdX = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ for t in (mX, mdOut, mdX)
+ ]
  self._set_cluster_n()
  tiler_mn, tv_layout = self._get_tv_layout()
  num_threads = cute.size(tv_layout, mode=[0])
@@ -412,39 +463,41 @@ class RMSNormBackward(ReductionBase):
  copy_atom_load_X_async = cute.make_copy_atom(
  cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
  )
+ num_bits_per_copy_W = cutlass.const_expr(
+ min(128, 128 // mX.element_type.width * mW.element_type.width)
+ )
  copy_atom_load_W = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+ cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
+ )
+ num_bits_per_copy_dX = cutlass.const_expr(
+ min(128, 128 // mX.element_type.width * mdX.element_type.width)
  )
  copy_atom_store_dX = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=128
+ cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_bits_per_copy_dX
+ )
+ num_bits_per_copy_dW = cutlass.const_expr(
+ min(128, 128 // mX.element_type.width * mdW.element_type.width)
  )
  copy_atom_store_dW = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=128
+ cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_bits_per_copy_dW
  )

  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_X_async = cute.make_tiled_copy(
- copy_atom_load_X_async, tv_layout, tiler_mn
- ).get_slice(tidx)
- thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_dW = cute.make_tiled_copy(copy_atom_store_dW, tv_layout, tiler_mn).get_slice(tidx)
- thr_store_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)

  gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
- tWgW = thr_copy_W.partition_S(gW)
- tWrW = cute.make_fragment_like(tWgW)
+ tXgW = thr_copy_X.partition_S(gW)
+ tXrW = cute.make_fragment_like(tXgW)
  # Need this, otherwise rW can have arbitrary values that changes the reduction
  if not is_even_N:
- tWrW.fill(0.0)
- tXrW = thr_copy_X.retile(tWrW)
+ tXrW.fill(0.0)

  gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
- tWpW = (
- utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+ tXpW = (
+ utils.predicate_k(thr_copy_X.partition_S(gW_coord), limit=shape[1])
  if not is_even_N
  else None
  )
- cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+ cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
  weight = tXrW.load().to(cute.Float32)

  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
@@ -452,16 +505,16 @@ class RMSNormBackward(ReductionBase):
  self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)

  dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
- tdWpdW = (
- utils.predicate_k(thr_copy_dW.partition_S(dw_coord), limit=shape[1])
+ tXpdW = (
+ utils.predicate_k(thr_copy_X.partition_S(dw_coord), limit=shape[1])
  if not is_even_N
  else None
  )

  gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
- tdWgdW = thr_copy_dW.partition_D(gdW)
- tdWrdW = cute.make_fragment_like(tdWgdW, cutlass.Float32)
- tXrdW = thr_copy_X.retile(tdWrdW)
+ tXgdW = thr_copy_X.partition_S(gdW)
+ # Always compute partial weight gradients in fp32
+ tXrdW = cute.make_fragment_like(tXgdW, Float32)

  gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
  gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
@@ -471,7 +524,7 @@ class RMSNormBackward(ReductionBase):
  tXsX = thr_copy_X.partition_D(sX)
  tXgdOut = thr_copy_X.partition_S(gdOut)
  tXsdOut = thr_copy_X.partition_D(sdOut)
- tXgdX = thr_store_dX.partition_D(gdX)
+ tXgdX = thr_copy_X.partition_D(gdX)
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
  # This doesn't change across iterations
  tXpX = (
@@ -513,9 +566,9 @@ class RMSNormBackward(ReductionBase):

  threads_per_row = tv_layout.shape[0][0]
  tXrdW.fill(0.0)
- stage = cutlass.Int32(0)
- producer_phase = cutlass.Int32(1)
- consumer_phase = cutlass.Int32(0)
+ stage = Int32(0)
+ producer_phase = Int32(1)
+ consumer_phase = Int32(0)
  for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
  row = tXcX[None, None, None, bidx][0][0]
  rstd = cutlass.Float.zero
@@ -538,10 +591,14 @@ class RMSNormBackward(ReductionBase):
  )
  elif tiler_mn[0] > 1:
  utils.fill_oob(
- tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
+ tXsX[None, None, None, stage ^ 1],
+ None,
+ fill_value=mX.element_type.zero,
  )
  utils.fill_oob(
- tXsdOut[None, None, None, stage ^ 1], None, fill_value=mdOut.element_type.zero
+ tXsdOut[None, None, None, stage ^ 1],
+ None,
+ fill_value=mdOut.element_type.zero,
  )
  cute.arch.cp_async_commit_group()
  if row < M or tiler_mn[0] == 1:
@@ -561,12 +618,13 @@ class RMSNormBackward(ReductionBase):
  cute.ReductionOp.ADD,
  threads_per_row,
  reduction_buffer[None, None, stage],
- mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None,
+ (mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None),
  phase=consumer_phase,
  init_val=0.0,
  )
  / shape[1]
  )
+
  if cutlass.const_expr(self.cluster_n > 1):
  # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
  # Requires adjusting the thread_count when initializing the mbar
@@ -576,12 +634,20 @@ class RMSNormBackward(ReductionBase):
  cute.arch.mbarrier_arrive(
  mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
  )
+
+ if cutlass.const_expr(self.reload_wdy == "smem"):
+ cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+ dout = tXrdOut.load().to(cute.Float32)
+ wdy = dout * weight
+
  dx = (wdy - x_hat * mean_xhat_wdy) * rstd
  tXrdX.store(dx.to(tXrdOut.element_type))
  if row < M or tiler_mn[0] == 1:
  tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
  cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+ # Accumulate weight gradients in fp32
  tXrdW.store(tXrdW.load() + dout * x_hat)
+
  stage ^= 1
  if stage == 0:
  consumer_phase ^= 1
@@ -608,9 +674,10 @@ class RMSNormBackward(ReductionBase):
  tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
  cute.autovec_copy(tXsdW_other, tXrdW_other)
  tXrdW.store(tXrdW.load() + tXrdW_other.load())
- cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
+ cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
  else:
- cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
+ # dw is already in fp32, so we can directly copy to global memory
+ cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)


  def _rmsnorm_backward(
@@ -634,8 +701,17 @@ def _rmsnorm_backward(
  assert weight.dim() == 1, "Weight must be 1D"
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
- assert weight.dtype == torch.float32, "Weight must be float32"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported dtype"
+
+ assert weight.dtype in [
+ torch.float32,
+ torch.bfloat16,
+ torch.float16,
+ ], "Weight must be float32, float16 or bfloat16"

  M, N = x.shape
  dx = torch.empty_like(x)
@@ -654,28 +730,29 @@ def _rmsnorm_backward(
  sm_count = (
  sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
  )
- dw_partial = torch.empty(sm_count, N, device=device, dtype=weight.dtype)
+
+ # Always store partial gradients in fp32 for numerical accuracy
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)

  dtype = torch2cute_dtype_map[x.dtype]

- convert_from_dlpack = lambda tensor: (
- from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
- mode=0, stride_order=(0, 1)
- )
+ convert_from_dlpack = lambda x: (
+ from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
  )
-
  x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]

+ # Handle weight div based on weight dtype
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
  weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
  )

- dw_partial_tensor = convert_from_dlpack(dw_partial)
+ dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
  rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)

  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

- compile_key = (dtype, N)
+ compile_key = (dtype, N, weight.dtype)
  if compile_key not in _rmsnorm_backward.compile_cache:
  rmsnorm_backward_op = RMSNormBackward(dtype, N)
  _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
@@ -700,7 +777,7 @@ def _rmsnorm_backward(
  sm_count,
  current_stream,
  )
-
+ # we have summed the partial gradients in fp32, now we convert back to the weight dtype
  dw = dw_partial.sum(dim=0).to(weight.dtype)
  return dx, dw

@@ -711,16 +788,29 @@ _rmsnorm_backward.compile_cache = {}
  class RMSNormFunction(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x, weight, eps):
+ x_shape_start = x.shape
+
+ # Flatten input
+ x = x.view(-1, x.shape[-1])
+
  out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
  ctx.save_for_backward(x, weight, rstd)
  ctx.eps = eps
- return out
+ ctx.x_shape_start = x_shape_start
+
+ return out.reshape(x_shape_start)

  @staticmethod
  def backward(ctx, dout):
  x, weight, rstd = ctx.saved_tensors
+ x_shape_start = ctx.x_shape_start
+ # Reshape dout to match the flattened shape used in forward
+ dout = dout.view(-1, dout.shape[-1])
  dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
- # dw is returned for weight gradient, None for eps gradient
+ dx = dx.view(x_shape_start)
+ # dx is returned for input gradient,
+ # dw is returned for weight gradient,
+ # None for eps gradient
  return dx, dw, None


@@ -736,3 +826,39 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.T
  Normalized output tensor of same shape as x
  """
  return RMSNormFunction.apply(x, weight, eps)
+
+
+ class QuackRMSNorm(torch.nn.Module):
+ """RMSNorm module that behaves like torch.nn.RMSNorm.
+
+ This class provides a drop-in replacement for torch.nn.RMSNorm that uses
+ the quack.rmsnorm implementation under the hood.
+
+ Args:
+ dim (int): The dimension to normalize over
+ eps (float, optional): A small constant for numerical stability. Default: 1e-6
+
+ Attributes:
+ weight (torch.nn.Parameter): The learnable weight parameter
+ eps (float): A small constant for numerical stability
+ """
+
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply RMSNorm to the input tensor.
+
+ Args:
+ x (torch.Tensor): Input tensor
+
+ Returns:
+ torch.Tensor: Normalized tensor
+ """
+ return rmsnorm(x, self.weight, self.eps)
+
+ def reset_parameters(self):
+ """Reset the weight parameter to ones."""
+ torch.nn.init.ones_(self.weight)
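
The rmsnorm hunks above relax the weight dtype check to allow fp16/bf16 weights, flatten inputs with more than two dimensions inside RMSNormFunction, and add the QuackRMSNorm module as a drop-in wrapper around rmsnorm. A minimal usage sketch follows; the CUDA device, shapes, and dtypes are illustrative assumptions, not taken from the package.

# Sketch only: assumes a CUDA GPU with the quack kernels installed and the 0.1.9 API shown above.
import torch
from quack.rmsnorm import rmsnorm, QuackRMSNorm

x = torch.randn(4, 512, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Functional form; 0.1.9 accepts bf16/fp16 weights in addition to fp32.
w = torch.ones(2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = rmsnorm(x, w, eps=1e-6)

# Module form, intended as a drop-in for torch.nn.RMSNorm (weight parameter stays fp32).
norm = QuackRMSNorm(2048, eps=1e-6).to("cuda")
out = norm(x)
out.sum().backward()  # gradients flow to x and norm.weight through the custom backward
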
quack_kernels-0.1.9.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.7
+ Version: 0.1.9
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
quack_kernels-0.1.9.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
+ quack/__init__.py,sha256=CT76CeRNh5bzQ9f13yVuRz9Sj7V3MvwzHH4fB1iQIf0,203
+ quack/cross_entropy.py,sha256=VYSAd28GmtnMoKQwLrorvySDtJfRhoqVd-aeM52FmsI,20866
+ quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+ quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
+ quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
+ quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
+ quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
+ quack_kernels-0.1.9.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.1.9.dist-info/METADATA,sha256=vOnpbShNHRiUXKAnOUxzfRM7zkpW3RmjW4hIgvYda08,289
+ quack_kernels-0.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.1.9.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.1.9.dist-info/RECORD,,
quack_kernels-0.1.7.dist-info/RECORD DELETED
@@ -1,12 +0,0 @@
- quack/__init__.py,sha256=R9cZd_vslI5oZjjS-ojfWAd9tCZAqsLUiFVqEbUaGnw,203
- quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
- quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
- quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
- quack/rmsnorm.py,sha256=3jiwWhVmaG0n5vuUnGGrpg3StAB4lnzziNF97QVMLGQ,28870
- quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
- quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
- quack_kernels-0.1.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.1.7.dist-info/METADATA,sha256=9RlqUmX3-7BI2aZk88r84B8o2FzZkQgkfV1UxwN8GlE,289
- quack_kernels-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.1.7.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.1.7.dist-info/RECORD,,