quack-kernels 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.1.7"
1
+ __version__ = "0.1.8"
2
2
 
3
3
  from quack.rmsnorm import rmsnorm
4
4
  from quack.softmax import softmax
quack/cross_entropy.py CHANGED
@@ -1,16 +1,16 @@
1
1
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
2
 
3
3
  import math
4
- import torch
5
4
  from typing import Optional, Type
6
5
 
7
6
  import cuda.bindings.driver as cuda
8
7
 
9
8
  import cutlass
10
9
  import cutlass.cute as cute
11
- from cutlass.cute.runtime import from_dlpack
12
10
 
13
11
  import quack.utils as utils
12
+ import torch
13
+ from cutlass.cute.runtime import from_dlpack
14
14
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map
15
15
 
16
16
 
@@ -79,7 +79,7 @@ class CrossEntropy(ReductionBase):
79
79
  self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
80
80
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
81
81
  block=[num_threads, 1, 1],
82
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
82
+ cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
83
83
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
84
84
  stream=stream,
85
85
  )
@@ -111,7 +111,9 @@ class CrossEntropy(ReductionBase):
111
111
 
112
112
  smem = cutlass.utils.SmemAllocator()
113
113
  sX = smem.allocate_tensor(
114
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
114
+ mX.element_type,
115
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
116
+ byte_alignment=16,
115
117
  )
116
118
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
117
119
 
@@ -166,7 +168,9 @@ class CrossEntropy(ReductionBase):
166
168
  reduction_buffer[None, None, 0],
167
169
  mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
168
170
  init_val=-cutlass.Float32.inf,
169
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
171
+ hook_fn=(
172
+ cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
173
+ ),
170
174
  )
171
175
  if cutlass.const_expr(self.reload_from == "smem"):
172
176
  cute.autovec_copy(tXsX, tXrX)
@@ -191,7 +195,9 @@ class CrossEntropy(ReductionBase):
191
195
  threads_per_row,
192
196
  reduction_buffer[None, None, 0],
193
197
  mbar_ptr,
194
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
198
+ hook_fn=(
199
+ cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
200
+ ),
195
201
  )
196
202
 
197
203
  if (
@@ -225,7 +231,11 @@ def _cross_entropy(
225
231
  assert target.dim() == 1, "Target must be 1D"
226
232
  assert x.shape[0] == target.shape[0], "Batch dimensions must match"
227
233
  assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
228
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
234
+ assert x.dtype in [
235
+ torch.float16,
236
+ torch.bfloat16,
237
+ torch.float32,
238
+ ], "Unsupported input dtype"
229
239
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
230
240
  M, N = x.shape
231
241
  device = x.device
@@ -314,13 +324,16 @@ class CrossEntropyBackward:
314
324
  num_threads = cute.size(tv_layout, mode=[0])
315
325
 
316
326
  mDLoss = cute.make_tensor(
317
- mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
327
+ mDLoss.iterator,
328
+ cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
318
329
  )
319
330
  mTarget = cute.make_tensor(
320
- mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
331
+ mTarget.iterator,
332
+ cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
321
333
  )
322
334
  mLSE = cute.make_tensor(
323
- mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
335
+ mLSE.iterator,
336
+ cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
324
337
  )
325
338
 
326
339
  smem_size = cute.size_in_bytes(
@@ -364,7 +377,9 @@ class CrossEntropyBackward:
364
377
 
365
378
  smem = cutlass.utils.SmemAllocator()
366
379
  sX = smem.allocate_tensor(
367
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
380
+ mX.element_type,
381
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
382
+ byte_alignment=16,
368
383
  )
369
384
 
370
385
  idX = cute.make_identity_tensor(shape)
@@ -474,7 +489,11 @@ def _cross_entropy_backward(
474
489
  assert (
475
490
  x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
476
491
  ), "Tensors must be on CUDA device"
477
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
492
+ assert x.dtype in [
493
+ torch.float16,
494
+ torch.bfloat16,
495
+ torch.float32,
496
+ ], "Unsupported input dtype"
478
497
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
479
498
 
480
499
  M, N = x.shape
@@ -532,15 +551,37 @@ class CrossEntropyFunction(torch.autograd.Function):
532
551
 
533
552
 
534
553
  def cross_entropy(
535
- x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
554
+ x: torch.Tensor,
555
+ target: torch.Tensor,
556
+ inplace_backward: bool = True,
557
+ reduction: str = "none",
536
558
  ) -> torch.Tensor:
537
559
  """Cross entropy loss with automatic differentiation support.
538
560
 
539
561
  Args:
540
562
  x: Input logits tensor of shape (M, N)
541
563
  target: Target class indices tensor of shape (M,)
564
+ inplace_backward: Whether to perform backward pass in-place
565
+ reduction: Specifies the reduction to apply to the output:
566
+ 'none': no reduction will be applied (default)
567
+ 'mean': the sum of the output will be divided by the number of elements
568
+ 'sum': the output will be summed
542
569
 
543
570
  Returns:
544
- Cross entropy loss tensor of shape (M,)
571
+ Cross entropy loss tensor:
572
+ - If reduction='none': tensor of shape (M,) with per-example losses
573
+ - If reduction='mean': scalar tensor with mean loss
574
+ - If reduction='sum': scalar tensor with sum of losses
545
575
  """
546
- return CrossEntropyFunction.apply(x, target, inplace_backward)
576
+ loss = CrossEntropyFunction.apply(x, target, inplace_backward)
577
+
578
+ if reduction == "mean":
579
+ return loss.mean()
580
+ elif reduction == "sum":
581
+ return loss.sum()
582
+ elif reduction == "none":
583
+ return loss
584
+ else:
585
+ raise ValueError(
586
+ f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
587
+ )
quack/rmsnorm.py CHANGED
@@ -1,14 +1,16 @@
1
1
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
2
 
3
- import torch
4
3
  from typing import Optional
5
4
 
6
5
  import cuda.bindings.driver as cuda
7
6
 
8
7
  import cutlass
9
8
  import cutlass.cute as cute
9
+ from cutlass import Float32, Int32
10
10
  from cutlass.cute.runtime import from_dlpack
11
+
11
12
  import quack.utils as utils
13
+ import torch
12
14
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map
13
15
 
14
16
 
@@ -19,41 +21,55 @@ class RMSNorm(ReductionBase):
19
21
  self.delay_w_load = False
20
22
 
21
23
  def _calculate_threads_per_row(self):
24
+ """Calculate the number of threads per row for the RMSNorm kernel."""
22
25
  N = self.N
23
- return (
24
- 8
25
- if N <= 64
26
- else (
27
- 16
28
- if N <= 128
29
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
30
- )
31
- )
26
+ if N <= 64:
27
+ return 8
28
+ elif N <= 128:
29
+ return 16
30
+ elif N <= 3072:
31
+ return 32
32
+ elif N <= 6144:
33
+ return 64
34
+ elif N <= 16384:
35
+ return 128
36
+ else:
37
+ return 256
32
38
 
33
39
  def _set_cluster_n(self):
40
+ """
41
+ Set the number of clusters for the RMSNorm kernel.
42
+ Stored in self.cluster_n.
43
+ """
34
44
  N = self.N
45
+
35
46
  # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
36
47
  # Similarly cluster_n = 8 is faster for N=128k
37
48
  if cutlass.const_expr(self.dtype.width == 16):
38
- cluster_n = (
39
- 1
40
- if N <= 16 * 1024
41
- else (
42
- 2
43
- if N <= 32 * 1024
44
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
45
- )
46
- )
47
- else: # fp32
48
- cluster_n = (
49
- 1
50
- if N <= 32 * 1024
51
- else (
52
- 2
53
- if N <= 64 * 1024
54
- else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
55
- )
56
- )
49
+ # 16-bit types (fp16, bf16)
50
+ if N <= 16 * 1024:
51
+ cluster_n = 1
52
+ elif N <= 32 * 1024:
53
+ cluster_n = 2
54
+ elif N <= 64 * 1024:
55
+ cluster_n = 4
56
+ elif N <= 128 * 1024:
57
+ cluster_n = 8
58
+ else:
59
+ cluster_n = 16
60
+ else:
61
+ # 32-bit types (fp32)
62
+ if N <= 32 * 1024:
63
+ cluster_n = 1
64
+ elif N <= 64 * 1024:
65
+ cluster_n = 2
66
+ elif N <= 128 * 1024:
67
+ cluster_n = 4
68
+ elif N <= 256 * 1024:
69
+ cluster_n = 8
70
+ else:
71
+ cluster_n = 16
72
+
57
73
  self.cluster_n = cluster_n
58
74
 
59
75
  @cute.jit
@@ -64,8 +80,17 @@ class RMSNorm(ReductionBase):
64
80
  mO: cute.Tensor,
65
81
  mRstd: Optional[cute.Tensor],
66
82
  stream: cuda.CUstream,
67
- eps: cutlass.Float32 = 1e-6,
83
+ eps: Float32 = 1e-6,
68
84
  ):
85
+ semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
86
+ new_stride = lambda t: (
87
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
88
+ t.stride[1],
89
+ )
90
+ mX, mO = [
91
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
92
+ for t in (mX, mO)
93
+ ]
69
94
  assert mX.element_type == self.dtype
70
95
  assert mO.element_type == self.dtype
71
96
  self._set_cluster_n()
@@ -82,7 +107,7 @@ class RMSNorm(ReductionBase):
82
107
  self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
83
108
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
84
109
  block=[num_threads, 1, 1],
85
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
110
+ cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
86
111
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
87
112
  stream=stream,
88
113
  )
@@ -109,7 +134,9 @@ class RMSNorm(ReductionBase):
109
134
 
110
135
  smem = cutlass.utils.SmemAllocator()
111
136
  sX = smem.allocate_tensor(
112
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
137
+ mX.element_type,
138
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
139
+ byte_alignment=16,
113
140
  )
114
141
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
115
142
 
@@ -184,7 +211,7 @@ class RMSNorm(ReductionBase):
184
211
  reduction_buffer[None, None, 0],
185
212
  mbar_ptr,
186
213
  init_val=0.0,
187
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
214
+ hook_fn=(cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None),
188
215
  )
189
216
  rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
190
217
  if cutlass.const_expr(mRstd is not None):
@@ -232,25 +259,36 @@ def _rmsnorm_fwd(
232
259
  assert weight.dim() == 1, "Weight must be 1D"
233
260
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
234
261
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
235
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
236
- assert weight.dtype == torch.float32, "Weight must be float32"
262
+ assert x.dtype in [
263
+ torch.float16,
264
+ torch.bfloat16,
265
+ torch.float32,
266
+ ], "Unsupported dtype"
267
+
268
+ assert weight.dtype in [
269
+ torch.float32,
270
+ torch.bfloat16,
271
+ torch.float16,
272
+ ], "Weight must be float32, float16 or bfloat16"
273
+
237
274
  M, N = x.shape
238
275
  device = x.device
239
276
  out = torch.empty_like(x)
240
277
  rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
241
278
  dtype = torch2cute_dtype_map[x.dtype]
279
+ # convert_from_dlpack = lambda x: (
280
+ # from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
281
+ # mode=0, divisibility=128 // dtype.width
282
+ # )
283
+ # )
242
284
  convert_from_dlpack = lambda x: (
243
- from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
244
- mode=0, stride_order=(0, 1)
245
- )
285
+ from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
246
286
  )
247
- x_tensor, out_tensor = [
248
- # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
249
- convert_from_dlpack(t)
250
- for t in (x, out)
251
- ]
287
+ x_tensor, out_tensor = [convert_from_dlpack(t) for t in (x, out)]
288
+ # handle weight divisibility based on weight dtype
289
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
252
290
  weight_tensor = utils.convert_from_dlpack(
253
- weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
291
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
254
292
  )
255
293
  rstd_tensor = (
256
294
  from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
@@ -258,7 +296,7 @@ def _rmsnorm_fwd(
258
296
  else None
259
297
  )
260
298
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
261
- compile_key = (dtype, N, rstd is not None)
299
+ compile_key = (dtype, N, rstd is not None, weight.dtype)
262
300
  if compile_key not in _rmsnorm_fwd.compile_cache:
263
301
  rmsnorm_op = RMSNorm(dtype, N)
264
302
  _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
@@ -301,7 +339,8 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
301
339
  class RMSNormBackward(ReductionBase):
302
340
  def __init__(self, dtype: cutlass.Numeric, N: int):
303
341
  # 2 stages for double buffering when computing mean of x_hat * wdy
304
- super().__init__(dtype, N, stage=2, reduction_dtype=cutlass.Float32)
342
+ super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
343
+ self.reload_wdy = None if N <= 16 * 1024 else "smem"
305
344
  if self.N > 128 * 1024 and self.dtype.width >= 32:
306
345
  # Not enough smem
307
346
  raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
@@ -348,9 +387,18 @@ class RMSNormBackward(ReductionBase):
348
387
  mRstd: cute.Tensor,
349
388
  mdX: cute.Tensor,
350
389
  mdW: cute.Tensor,
351
- sm_count: cutlass.Int32,
390
+ sm_count: Int32,
352
391
  stream: cuda.CUstream,
353
392
  ):
393
+ semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
394
+ new_stride = lambda t: (
395
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
396
+ t.stride[1],
397
+ )
398
+ mX, mdOut, mdX = [
399
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
400
+ for t in (mX, mdOut, mdX)
401
+ ]
354
402
  self._set_cluster_n()
355
403
  tiler_mn, tv_layout = self._get_tv_layout()
356
404
  num_threads = cute.size(tv_layout, mode=[0])
@@ -460,7 +508,8 @@ class RMSNormBackward(ReductionBase):
460
508
 
461
509
  gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
462
510
  tdWgdW = thr_copy_dW.partition_D(gdW)
463
- tdWrdW = cute.make_fragment_like(tdWgdW, cutlass.Float32)
511
+ # Always compute partial weight gradients in fp32
512
+ tdWrdW = cute.make_fragment_like(tdWgdW, Float32)
464
513
  tXrdW = thr_copy_X.retile(tdWrdW)
465
514
 
466
515
  gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
@@ -513,9 +562,9 @@ class RMSNormBackward(ReductionBase):
513
562
 
514
563
  threads_per_row = tv_layout.shape[0][0]
515
564
  tXrdW.fill(0.0)
516
- stage = cutlass.Int32(0)
517
- producer_phase = cutlass.Int32(1)
518
- consumer_phase = cutlass.Int32(0)
565
+ stage = Int32(0)
566
+ producer_phase = Int32(1)
567
+ consumer_phase = Int32(0)
519
568
  for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
520
569
  row = tXcX[None, None, None, bidx][0][0]
521
570
  rstd = cutlass.Float.zero
@@ -538,10 +587,14 @@ class RMSNormBackward(ReductionBase):
538
587
  )
539
588
  elif tiler_mn[0] > 1:
540
589
  utils.fill_oob(
541
- tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
590
+ tXsX[None, None, None, stage ^ 1],
591
+ None,
592
+ fill_value=mX.element_type.zero,
542
593
  )
543
594
  utils.fill_oob(
544
- tXsdOut[None, None, None, stage ^ 1], None, fill_value=mdOut.element_type.zero
595
+ tXsdOut[None, None, None, stage ^ 1],
596
+ None,
597
+ fill_value=mdOut.element_type.zero,
545
598
  )
546
599
  cute.arch.cp_async_commit_group()
547
600
  if row < M or tiler_mn[0] == 1:
@@ -561,12 +614,13 @@ class RMSNormBackward(ReductionBase):
561
614
  cute.ReductionOp.ADD,
562
615
  threads_per_row,
563
616
  reduction_buffer[None, None, stage],
564
- mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None,
617
+ (mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None),
565
618
  phase=consumer_phase,
566
619
  init_val=0.0,
567
620
  )
568
621
  / shape[1]
569
622
  )
623
+
570
624
  if cutlass.const_expr(self.cluster_n > 1):
571
625
  # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
572
626
  # Requires adjusting the thread_count when initializing the mbar
@@ -576,12 +630,20 @@ class RMSNormBackward(ReductionBase):
576
630
  cute.arch.mbarrier_arrive(
577
631
  mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
578
632
  )
633
+
634
+ if cutlass.const_expr(self.reload_wdy == "smem"):
635
+ cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
636
+ dout = tXrdOut.load().to(cute.Float32)
637
+ wdy = dout * weight
638
+
579
639
  dx = (wdy - x_hat * mean_xhat_wdy) * rstd
580
640
  tXrdX.store(dx.to(tXrdOut.element_type))
581
641
  if row < M or tiler_mn[0] == 1:
582
642
  tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
583
643
  cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
644
+ # Accumulate weight gradients in fp32
584
645
  tXrdW.store(tXrdW.load() + dout * x_hat)
646
+
585
647
  stage ^= 1
586
648
  if stage == 0:
587
649
  consumer_phase ^= 1
@@ -609,7 +671,9 @@ class RMSNormBackward(ReductionBase):
609
671
  cute.autovec_copy(tXsdW_other, tXrdW_other)
610
672
  tXrdW.store(tXrdW.load() + tXrdW_other.load())
611
673
  cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
674
+
612
675
  else:
676
+ # dw is already in fp32, so we can directly copy to global memory
613
677
  cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
614
678
 
615
679
 
@@ -634,8 +698,17 @@ def _rmsnorm_backward(
634
698
  assert weight.dim() == 1, "Weight must be 1D"
635
699
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
636
700
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
637
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
638
- assert weight.dtype == torch.float32, "Weight must be float32"
701
+ assert x.dtype in [
702
+ torch.float16,
703
+ torch.bfloat16,
704
+ torch.float32,
705
+ ], "Unsupported dtype"
706
+
707
+ assert weight.dtype in [
708
+ torch.float32,
709
+ torch.bfloat16,
710
+ torch.float16,
711
+ ], "Weight must be float32, float16 or bfloat16"
639
712
 
640
713
  M, N = x.shape
641
714
  dx = torch.empty_like(x)
@@ -654,28 +727,29 @@ def _rmsnorm_backward(
654
727
  sm_count = (
655
728
  sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
656
729
  )
657
- dw_partial = torch.empty(sm_count, N, device=device, dtype=weight.dtype)
730
+
731
+ # Always store partial gradients in fp32 for numerical accuracy
732
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
658
733
 
659
734
  dtype = torch2cute_dtype_map[x.dtype]
660
735
 
661
- convert_from_dlpack = lambda tensor: (
662
- from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
663
- mode=0, stride_order=(0, 1)
664
- )
736
+ convert_from_dlpack = lambda x: (
737
+ from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
665
738
  )
666
-
667
739
  x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
668
740
 
741
+ # Handle weight div based on weight dtype
742
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
669
743
  weight_tensor = utils.convert_from_dlpack(
670
- weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
744
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
671
745
  )
672
746
 
673
- dw_partial_tensor = convert_from_dlpack(dw_partial)
747
+ dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
674
748
  rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
675
749
 
676
750
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
677
751
 
678
- compile_key = (dtype, N)
752
+ compile_key = (dtype, N, weight.dtype)
679
753
  if compile_key not in _rmsnorm_backward.compile_cache:
680
754
  rmsnorm_backward_op = RMSNormBackward(dtype, N)
681
755
  _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
@@ -700,7 +774,7 @@ def _rmsnorm_backward(
700
774
  sm_count,
701
775
  current_stream,
702
776
  )
703
-
777
+ # we have summed the partial gradients in fp32, now we convert back to the weight dtype
704
778
  dw = dw_partial.sum(dim=0).to(weight.dtype)
705
779
  return dx, dw
706
780
 
@@ -711,16 +785,29 @@ _rmsnorm_backward.compile_cache = {}
711
785
  class RMSNormFunction(torch.autograd.Function):
712
786
  @staticmethod
713
787
  def forward(ctx, x, weight, eps):
788
+ x_shape_start = x.shape
789
+
790
+ # Flatten input
791
+ x = x.view(-1, x.shape[-1])
792
+
714
793
  out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
715
794
  ctx.save_for_backward(x, weight, rstd)
716
795
  ctx.eps = eps
717
- return out
796
+ ctx.x_shape_start = x_shape_start
797
+
798
+ return out.reshape(x_shape_start)
718
799
 
719
800
  @staticmethod
720
801
  def backward(ctx, dout):
721
802
  x, weight, rstd = ctx.saved_tensors
803
+ x_shape_start = ctx.x_shape_start
804
+ # Reshape dout to match the flattened shape used in forward
805
+ dout = dout.view(-1, dout.shape[-1])
722
806
  dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
723
- # dw is returned for weight gradient, None for eps gradient
807
+ dx = dx.view(x_shape_start)
808
+ # dx is returned for input gradient,
809
+ # dw is returned for weight gradient,
810
+ # None for eps gradient
724
811
  return dx, dw, None
725
812
 
726
813
 
@@ -736,3 +823,39 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.T
736
823
  Normalized output tensor of same shape as x
737
824
  """
738
825
  return RMSNormFunction.apply(x, weight, eps)
826
+
827
+
828
+ class QuackRMSNorm(torch.nn.Module):
829
+ """RMSNorm module that behaves like torch.nn.RMSNorm.
830
+
831
+ This class provides a drop-in replacement for torch.nn.RMSNorm that uses
832
+ the quack.rmsnorm implementation under the hood.
833
+
834
+ Args:
835
+ dim (int): The dimension to normalize over
836
+ eps (float, optional): A small constant for numerical stability. Default: 1e-6
837
+
838
+ Attributes:
839
+ weight (torch.nn.Parameter): The learnable weight parameter
840
+ eps (float): A small constant for numerical stability
841
+ """
842
+
843
+ def __init__(self, dim: int, eps: float = 1e-6):
844
+ super().__init__()
845
+ self.weight = torch.nn.Parameter(torch.ones(dim))
846
+ self.eps = eps
847
+
848
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
849
+ """Apply RMSNorm to the input tensor.
850
+
851
+ Args:
852
+ x (torch.Tensor): Input tensor
853
+
854
+ Returns:
855
+ torch.Tensor: Normalized tensor
856
+ """
857
+ return rmsnorm(x, self.weight, self.eps)
858
+
859
+ def reset_parameters(self):
860
+ """Reset the weight parameter to ones."""
861
+ torch.nn.init.ones_(self.weight)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quack-kernels
3
- Version: 0.1.7
3
+ Version: 0.1.8
4
4
  Requires-Python: >=3.9
5
5
  License-File: LICENSE
6
6
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
@@ -0,0 +1,12 @@
1
+ quack/__init__.py,sha256=tDgX5MF1ttfEyDVFWi47DA8tDooYcBQlkuzvabGUoQI,203
2
+ quack/cross_entropy.py,sha256=VYSAd28GmtnMoKQwLrorvySDtJfRhoqVd-aeM52FmsI,20866
3
+ quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
4
+ quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
5
+ quack/rmsnorm.py,sha256=-qrKqPKk0fUuq0a5-vJmZZ7nQsHgyaqTg0EKhWT44r0,32738
6
+ quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
7
+ quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
8
+ quack_kernels-0.1.8.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
9
+ quack_kernels-0.1.8.dist-info/METADATA,sha256=b_2PxFEoVqWJbT2FtuP9FJyF-jpL2Z3q9OHoOEipqo4,289
10
+ quack_kernels-0.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
11
+ quack_kernels-0.1.8.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
12
+ quack_kernels-0.1.8.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- quack/__init__.py,sha256=R9cZd_vslI5oZjjS-ojfWAd9tCZAqsLUiFVqEbUaGnw,203
2
- quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
3
- quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
4
- quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
5
- quack/rmsnorm.py,sha256=3jiwWhVmaG0n5vuUnGGrpg3StAB4lnzziNF97QVMLGQ,28870
6
- quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
7
- quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
8
- quack_kernels-0.1.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
9
- quack_kernels-0.1.7.dist-info/METADATA,sha256=9RlqUmX3-7BI2aZk88r84B8o2FzZkQgkfV1UxwN8GlE,289
10
- quack_kernels-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
11
- quack_kernels-0.1.7.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
12
- quack_kernels-0.1.7.dist-info/RECORD,,