quack-kernels 0.1.7__tar.gz → 0.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {quack_kernels-0.1.7/quack_kernels.egg-info → quack_kernels-0.1.8}/PKG-INFO +1 -1
  2. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/__init__.py +1 -1
  3. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/cross_entropy.py +56 -15
  4. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/rmsnorm.py +191 -68
  5. {quack_kernels-0.1.7 → quack_kernels-0.1.8/quack_kernels.egg-info}/PKG-INFO +1 -1
  6. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack_kernels.egg-info/top_level.txt +1 -0
  7. quack_kernels-0.1.8/tests/test_rmsnorm.py +392 -0
  8. quack_kernels-0.1.7/tests/test_rmsnorm.py +0 -183
  9. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/LICENSE +0 -0
  10. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/README.md +0 -0
  11. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/pyproject.toml +0 -0
  12. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/layernorm.py +0 -0
  13. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/reduction_base.py +0 -0
  14. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/softmax.py +0 -0
  15. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack/utils.py +0 -0
  16. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack_kernels.egg-info/SOURCES.txt +0 -0
  17. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack_kernels.egg-info/dependency_links.txt +0 -0
  18. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/quack_kernels.egg-info/requires.txt +0 -0
  19. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/setup.cfg +0 -0
  20. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/setup.py +0 -0
  21. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/tests/test_cross_entropy.py +0 -0
  22. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/tests/test_layernorm.py +0 -0
  23. {quack_kernels-0.1.7 → quack_kernels-0.1.8}/tests/test_softmax.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.7
+ Version: 0.1.8
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0

quack/__init__.py
@@ -1,4 +1,4 @@
- __version__ = "0.1.7"
+ __version__ = "0.1.8"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax

quack/cross_entropy.py
@@ -1,16 +1,16 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

  import math
- import torch
  from typing import Optional, Type

  import cuda.bindings.driver as cuda

  import cutlass
  import cutlass.cute as cute
- from cutlass.cute.runtime import from_dlpack

  import quack.utils as utils
+ import torch
+ from cutlass.cute.runtime import from_dlpack
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map


@@ -79,7 +79,7 @@ class CrossEntropy(ReductionBase):
  self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+ cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -111,7 +111,9 @@ class CrossEntropy(ReductionBase):

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

@@ -166,7 +168,9 @@ class CrossEntropy(ReductionBase):
  reduction_buffer[None, None, 0],
  mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
  init_val=-cutlass.Float32.inf,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=(
+ cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+ ),
  )
  if cutlass.const_expr(self.reload_from == "smem"):
  cute.autovec_copy(tXsX, tXrX)
@@ -191,7 +195,9 @@ class CrossEntropy(ReductionBase):
  threads_per_row,
  reduction_buffer[None, None, 0],
  mbar_ptr,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=(
+ cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+ ),
  )

  if (
@@ -225,7 +231,11 @@ def _cross_entropy(
  assert target.dim() == 1, "Target must be 1D"
  assert x.shape[0] == target.shape[0], "Batch dimensions must match"
  assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported input dtype"
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
  M, N = x.shape
  device = x.device
@@ -314,13 +324,16 @@ class CrossEntropyBackward:
  num_threads = cute.size(tv_layout, mode=[0])

  mDLoss = cute.make_tensor(
- mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+ mDLoss.iterator,
+ cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
  )
  mTarget = cute.make_tensor(
- mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+ mTarget.iterator,
+ cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
  )
  mLSE = cute.make_tensor(
- mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+ mLSE.iterator,
+ cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
  )

  smem_size = cute.size_in_bytes(
@@ -364,7 +377,9 @@ class CrossEntropyBackward:

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )

  idX = cute.make_identity_tensor(shape)
@@ -474,7 +489,11 @@ def _cross_entropy_backward(
  assert (
  x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
  ), "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported input dtype"
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"

  M, N = x.shape
@@ -532,15 +551,37 @@ class CrossEntropyFunction(torch.autograd.Function):


  def cross_entropy(
- x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+ x: torch.Tensor,
+ target: torch.Tensor,
+ inplace_backward: bool = True,
+ reduction: str = "none",
  ) -> torch.Tensor:
  """Cross entropy loss with automatic differentiation support.

  Args:
  x: Input logits tensor of shape (M, N)
  target: Target class indices tensor of shape (M,)
+ inplace_backward: Whether to perform backward pass in-place
+ reduction: Specifies the reduction to apply to the output:
+ 'none': no reduction will be applied (default)
+ 'mean': the sum of the output will be divided by the number of elements
+ 'sum': the output will be summed

  Returns:
- Cross entropy loss tensor of shape (M,)
+ Cross entropy loss tensor:
+ - If reduction='none': tensor of shape (M,) with per-example losses
+ - If reduction='mean': scalar tensor with mean loss
+ - If reduction='sum': scalar tensor with sum of losses
  """
- return CrossEntropyFunction.apply(x, target, inplace_backward)
+ loss = CrossEntropyFunction.apply(x, target, inplace_backward)
+
+ if reduction == "mean":
+ return loss.mean()
+ elif reduction == "sum":
+ return loss.sum()
+ elif reduction == "none":
+ return loss
+ else:
+ raise ValueError(
+ f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
+ )
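
As the hunks above show, cross_entropy now accepts a reduction argument ('none', 'mean', or 'sum') and inplace_backward defaults to True. A minimal usage sketch of that signature follows; the shapes, dtypes, and values are illustrative assumptions, not taken from the package:

import torch
from quack.cross_entropy import cross_entropy  # module path per the file list above

# Hypothetical batch of logits and integer class targets on a CUDA device
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 4096, (8,), device="cuda", dtype=torch.int64)

loss = cross_entropy(x, target, reduction="mean")  # "none" (default) gives per-row losses of shape (8,); "sum" is also accepted
loss.backward()                                    # backward pass; inplace_backward=True is the new default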

quack/rmsnorm.py
@@ -1,14 +1,16 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

- import torch
  from typing import Optional

  import cuda.bindings.driver as cuda

  import cutlass
  import cutlass.cute as cute
+ from cutlass import Float32, Int32
  from cutlass.cute.runtime import from_dlpack
+
  import quack.utils as utils
+ import torch
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map


@@ -19,41 +21,55 @@ class RMSNorm(ReductionBase):
  self.delay_w_load = False

  def _calculate_threads_per_row(self):
+ """Calculate the number of threads per row for the RMSNorm kernel."""
  N = self.N
- return (
- 8
- if N <= 64
- else (
- 16
- if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
- )
- )
+ if N <= 64:
+ return 8
+ elif N <= 128:
+ return 16
+ elif N <= 3072:
+ return 32
+ elif N <= 6144:
+ return 64
+ elif N <= 16384:
+ return 128
+ else:
+ return 256

  def _set_cluster_n(self):
+ """
+ Set the number of clusters for the RMSNorm kernel.
+ Stored in self.cluster_n.
+ """
  N = self.N
+
  # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
  # Similarly cluster_n = 8 is faster for N=128k
  if cutlass.const_expr(self.dtype.width == 16):
- cluster_n = (
- 1
- if N <= 16 * 1024
- else (
- 2
- if N <= 32 * 1024
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
- )
- )
- else: # fp32
- cluster_n = (
- 1
- if N <= 32 * 1024
- else (
- 2
- if N <= 64 * 1024
- else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
- )
- )
+ # 16-bit types (fp16, bf16)
+ if N <= 16 * 1024:
+ cluster_n = 1
+ elif N <= 32 * 1024:
+ cluster_n = 2
+ elif N <= 64 * 1024:
+ cluster_n = 4
+ elif N <= 128 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+ else:
+ # 32-bit types (fp32)
+ if N <= 32 * 1024:
+ cluster_n = 1
+ elif N <= 64 * 1024:
+ cluster_n = 2
+ elif N <= 128 * 1024:
+ cluster_n = 4
+ elif N <= 256 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+
  self.cluster_n = cluster_n

  @cute.jit
@@ -64,8 +80,17 @@ class RMSNorm(ReductionBase):
  mO: cute.Tensor,
  mRstd: Optional[cute.Tensor],
  stream: cuda.CUstream,
- eps: cutlass.Float32 = 1e-6,
+ eps: Float32 = 1e-6,
  ):
+ semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
+ new_stride = lambda t: (
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
+ t.stride[1],
+ )
+ mX, mO = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ for t in (mX, mO)
+ ]
  assert mX.element_type == self.dtype
  assert mO.element_type == self.dtype
  self._set_cluster_n()
@@ -82,7 +107,7 @@ class RMSNorm(ReductionBase):
  self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+ cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -109,7 +134,9 @@ class RMSNorm(ReductionBase):

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
  )
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

@@ -184,7 +211,7 @@ class RMSNorm(ReductionBase):
  reduction_buffer[None, None, 0],
  mbar_ptr,
  init_val=0.0,
- hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ hook_fn=(cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None),
  )
  rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
  if cutlass.const_expr(mRstd is not None):
@@ -232,25 +259,36 @@ def _rmsnorm_fwd(
  assert weight.dim() == 1, "Weight must be 1D"
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
- assert weight.dtype == torch.float32, "Weight must be float32"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported dtype"
+
+ assert weight.dtype in [
+ torch.float32,
+ torch.bfloat16,
+ torch.float16,
+ ], "Weight must be float32, float16 or bfloat16"
+
  M, N = x.shape
  device = x.device
  out = torch.empty_like(x)
  rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
  dtype = torch2cute_dtype_map[x.dtype]
+ # convert_from_dlpack = lambda x: (
+ # from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ # mode=0, divisibility=128 // dtype.width
+ # )
+ # )
  convert_from_dlpack = lambda x: (
- from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
- mode=0, stride_order=(0, 1)
- )
+ from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
  )
- x_tensor, out_tensor = [
- # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
- convert_from_dlpack(t)
- for t in (x, out)
- ]
+ x_tensor, out_tensor = [convert_from_dlpack(t) for t in (x, out)]
+ # handle weight divisibility based on weight dtype
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
  weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
  )
  rstd_tensor = (
  from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
@@ -258,7 +296,7 @@ def _rmsnorm_fwd(
  else None
  )
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
- compile_key = (dtype, N, rstd is not None)
+ compile_key = (dtype, N, rstd is not None, weight.dtype)
  if compile_key not in _rmsnorm_fwd.compile_cache:
  rmsnorm_op = RMSNorm(dtype, N)
  _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
@@ -301,7 +339,8 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
  class RMSNormBackward(ReductionBase):
  def __init__(self, dtype: cutlass.Numeric, N: int):
  # 2 stages for double buffering when computing mean of x_hat * wdy
- super().__init__(dtype, N, stage=2, reduction_dtype=cutlass.Float32)
+ super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
+ self.reload_wdy = None if N <= 16 * 1024 else "smem"
  if self.N > 128 * 1024 and self.dtype.width >= 32:
  # Not enough smem
  raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
@@ -348,9 +387,18 @@ class RMSNormBackward(ReductionBase):
  mRstd: cute.Tensor,
  mdX: cute.Tensor,
  mdW: cute.Tensor,
- sm_count: cutlass.Int32,
+ sm_count: Int32,
  stream: cuda.CUstream,
  ):
+ semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
+ new_stride = lambda t: (
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
+ t.stride[1],
+ )
+ mX, mdOut, mdX = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ for t in (mX, mdOut, mdX)
+ ]
  self._set_cluster_n()
  tiler_mn, tv_layout = self._get_tv_layout()
  num_threads = cute.size(tv_layout, mode=[0])
@@ -460,7 +508,8 @@ class RMSNormBackward(ReductionBase):

  gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
  tdWgdW = thr_copy_dW.partition_D(gdW)
- tdWrdW = cute.make_fragment_like(tdWgdW, cutlass.Float32)
+ # Always compute partial weight gradients in fp32
+ tdWrdW = cute.make_fragment_like(tdWgdW, Float32)
  tXrdW = thr_copy_X.retile(tdWrdW)

  gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
@@ -513,9 +562,9 @@ class RMSNormBackward(ReductionBase):

  threads_per_row = tv_layout.shape[0][0]
  tXrdW.fill(0.0)
- stage = cutlass.Int32(0)
- producer_phase = cutlass.Int32(1)
- consumer_phase = cutlass.Int32(0)
+ stage = Int32(0)
+ producer_phase = Int32(1)
+ consumer_phase = Int32(0)
  for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
  row = tXcX[None, None, None, bidx][0][0]
  rstd = cutlass.Float.zero
@@ -538,10 +587,14 @@ class RMSNormBackward(ReductionBase):
  )
  elif tiler_mn[0] > 1:
  utils.fill_oob(
- tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
+ tXsX[None, None, None, stage ^ 1],
+ None,
+ fill_value=mX.element_type.zero,
  )
  utils.fill_oob(
- tXsdOut[None, None, None, stage ^ 1], None, fill_value=mdOut.element_type.zero
+ tXsdOut[None, None, None, stage ^ 1],
+ None,
+ fill_value=mdOut.element_type.zero,
  )
  cute.arch.cp_async_commit_group()
  if row < M or tiler_mn[0] == 1:
@@ -561,12 +614,13 @@ class RMSNormBackward(ReductionBase):
  cute.ReductionOp.ADD,
  threads_per_row,
  reduction_buffer[None, None, stage],
- mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None,
+ (mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None),
  phase=consumer_phase,
  init_val=0.0,
  )
  / shape[1]
  )
+
  if cutlass.const_expr(self.cluster_n > 1):
  # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
  # Requires adjusting the thread_count when initializing the mbar
@@ -576,12 +630,20 @@ class RMSNormBackward(ReductionBase):
  cute.arch.mbarrier_arrive(
  mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
  )
+
+ if cutlass.const_expr(self.reload_wdy == "smem"):
+ cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+ dout = tXrdOut.load().to(cute.Float32)
+ wdy = dout * weight
+
  dx = (wdy - x_hat * mean_xhat_wdy) * rstd
  tXrdX.store(dx.to(tXrdOut.element_type))
  if row < M or tiler_mn[0] == 1:
  tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
  cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+ # Accumulate weight gradients in fp32
  tXrdW.store(tXrdW.load() + dout * x_hat)
+
  stage ^= 1
  if stage == 0:
  consumer_phase ^= 1
@@ -609,7 +671,9 @@ class RMSNormBackward(ReductionBase):
  cute.autovec_copy(tXsdW_other, tXrdW_other)
  tXrdW.store(tXrdW.load() + tXrdW_other.load())
  cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
+
  else:
+ # dw is already in fp32, so we can directly copy to global memory
  cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)


@@ -634,8 +698,17 @@ def _rmsnorm_backward(
  assert weight.dim() == 1, "Weight must be 1D"
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
- assert weight.dtype == torch.float32, "Weight must be float32"
+ assert x.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ], "Unsupported dtype"
+
+ assert weight.dtype in [
+ torch.float32,
+ torch.bfloat16,
+ torch.float16,
+ ], "Weight must be float32, float16 or bfloat16"

  M, N = x.shape
  dx = torch.empty_like(x)
@@ -654,28 +727,29 @@ def _rmsnorm_backward(
  sm_count = (
  sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
  )
- dw_partial = torch.empty(sm_count, N, device=device, dtype=weight.dtype)
+
+ # Always store partial gradients in fp32 for numerical accuracy
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)

  dtype = torch2cute_dtype_map[x.dtype]

- convert_from_dlpack = lambda tensor: (
- from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
- mode=0, stride_order=(0, 1)
- )
+ convert_from_dlpack = lambda x: (
+ from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
  )
-
  x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]

+ # Handle weight div based on weight dtype
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
  weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
  )

- dw_partial_tensor = convert_from_dlpack(dw_partial)
+ dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
  rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)

  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

- compile_key = (dtype, N)
+ compile_key = (dtype, N, weight.dtype)
  if compile_key not in _rmsnorm_backward.compile_cache:
  rmsnorm_backward_op = RMSNormBackward(dtype, N)
  _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
@@ -700,7 +774,7 @@ def _rmsnorm_backward(
  sm_count,
  current_stream,
  )
-
+ # we have summed the partial gradients in fp32, now we convert back to the weight dtype
  dw = dw_partial.sum(dim=0).to(weight.dtype)
  return dx, dw

@@ -711,16 +785,29 @@ _rmsnorm_backward.compile_cache = {}
  class RMSNormFunction(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x, weight, eps):
+ x_shape_start = x.shape
+
+ # Flatten input
+ x = x.view(-1, x.shape[-1])
+
  out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
  ctx.save_for_backward(x, weight, rstd)
  ctx.eps = eps
- return out
+ ctx.x_shape_start = x_shape_start
+
+ return out.reshape(x_shape_start)

  @staticmethod
  def backward(ctx, dout):
  x, weight, rstd = ctx.saved_tensors
+ x_shape_start = ctx.x_shape_start
+ # Reshape dout to match the flattened shape used in forward
+ dout = dout.view(-1, dout.shape[-1])
  dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
- # dw is returned for weight gradient, None for eps gradient
+ dx = dx.view(x_shape_start)
+ # dx is returned for input gradient,
+ # dw is returned for weight gradient,
+ # None for eps gradient
  return dx, dw, None


@@ -736,3 +823,39 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.T
  Normalized output tensor of same shape as x
  """
  return RMSNormFunction.apply(x, weight, eps)
+
+
+ class QuackRMSNorm(torch.nn.Module):
+ """RMSNorm module that behaves like torch.nn.RMSNorm.
+
+ This class provides a drop-in replacement for torch.nn.RMSNorm that uses
+ the quack.rmsnorm implementation under the hood.
+
+ Args:
+ dim (int): The dimension to normalize over
+ eps (float, optional): A small constant for numerical stability. Default: 1e-6
+
+ Attributes:
+ weight (torch.nn.Parameter): The learnable weight parameter
+ eps (float): A small constant for numerical stability
+ """
+
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(dim))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply RMSNorm to the input tensor.
+
+ Args:
+ x (torch.Tensor): Input tensor
+
+ Returns:
+ torch.Tensor: Normalized tensor
+ """
+ return rmsnorm(x, self.weight, self.eps)
+
+ def reset_parameters(self):
+ """Reset the weight parameter to ones."""
+ torch.nn.init.ones_(self.weight)
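
The rmsnorm changes above add three user-visible things: inputs with more than two dimensions are flattened in RMSNormFunction.forward and reshaped back, weights may be float32, float16, or bfloat16, and QuackRMSNorm wraps the functional API as a torch.nn.Module. A minimal usage sketch based on that added code; shapes, dtypes, and tensor values are illustrative assumptions:

import torch
from quack.rmsnorm import rmsnorm, QuackRMSNorm

# Functional API: 3D input with a bf16 weight, both newly supported in 0.1.8
x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = rmsnorm(x, w, eps=1e-6)   # same shape and dtype as x
out.sum().backward()            # partial weight gradients are accumulated in fp32, then cast back to w.dtype

# Module API, intended as a drop-in replacement in the spirit of torch.nn.RMSNorm
norm = QuackRMSNorm(4096, eps=1e-6).cuda()   # weight is initialized to ones (fp32)
y = norm(x.detach())                         # bf16 activations with an fp32 weight also work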

quack_kernels.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.7
+ Version: 0.1.8
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0

quack_kernels.egg-info/top_level.txt
@@ -1,3 +1,4 @@
+ benchmarks
  dist
  media
  quack

quack_kernels-0.1.8/tests/test_rmsnorm.py (new file)
@@ -0,0 +1,392 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ import pytest
+ import torch
+
+ from quack.rmsnorm import rmsnorm, rmsnorm_ref, rstd_ref, _rmsnorm_fwd
+
+ @pytest.mark.parametrize("eps", [1e-5, 1e-6])
+ # @pytest.mark.parametrize("eps", [1e-5])
+ @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
+ # @pytest.mark.parametrize("input_dtype", [torch.float16])
+ @pytest.mark.parametrize(
+ "N",
+ [
+ 192,
+ 256,
+ 512,
+ 760,
+ 1024,
+ 1128,
+ 2048,
+ 4096,
+ 8192,
+ 16384,
+ 32768,
+ 65536,
+ 131072,
+ 262144,
+ ],
+ # [262144]
+ )
+ @pytest.mark.parametrize("M", [1, 37, 199, 8 * 1024])
+ # @pytest.mark.parametrize("M", [1])
+ def test_rmsnorm_forward_backward(M, N, input_dtype, eps):
+ """Test RMSNorm forward pass against reference implementation."""
+ if N >= 256 * 1024 and input_dtype == torch.float32 and M >= 8 * 1024:
+ pytest.skip("Skipping large tensor test for float32 to avoid OOM")
+ device = "cuda"
+ # Set tolerance based on dtype
+ if input_dtype == torch.bfloat16:
+ atol = 1e-1
+ elif input_dtype == torch.float16:
+ atol = 1e-2
+ else:
+ atol = 1e-4
+ torch.random.manual_seed(0)
+ x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
+ weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
+ x_ref = x.detach().clone().requires_grad_()
+ weight_ref = weight.detach().clone().requires_grad_()
+ out = rmsnorm(x, weight, eps=eps)
+ out_ref = rmsnorm_ref(x_ref, weight_ref, eps=eps)
+ # rstd_ref_val = rstd_ref(x_ref, eps=eps)
+ assert out.shape == x.shape
+ assert out.dtype == input_dtype
+ torch.testing.assert_close(out, out_ref, atol=atol, rtol=1e-3)
+ # torch.testing.assert_close(rstd, rstd_ref_val, atol=atol, rtol=1e-3)
+ # Backward pass
+ if N > 128 * 1024 and input_dtype == torch.float32:
+ # Skip backward pass for due to not enough smem
+ return
+ grad_out = torch.randn_like(out)
+ torch.cuda.synchronize()
+ out_ref.backward(grad_out)
+ out.backward(grad_out)
+ torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-3)
+ torch.testing.assert_close(weight.grad, weight_ref.grad, atol=atol, rtol=1e-3)
+
+
+ def test_rmsnorm_strided_tensor():
+ """Test RMSNorm with strided tensor input where shape is (8, 4096, 512) and stride is (sth, 576, 1)."""
+ device = "cuda"
+ dtype = torch.bfloat16
+ atol = 1e-1
+ eps = 1e-5
+ # Create a larger tensor with 576 features
+ full_tensor = torch.randn(8, 4096, 576, device=device, dtype=dtype)
+ # Take a slice of the top 512 dimensions - this creates a strided view
+ x = full_tensor[:, :, :512].detach().requires_grad_()
+ # Create weight tensor
+ weight = torch.randn(512, device=device, dtype=torch.float32, requires_grad=True)
+ # Reference implementation
+ x_ref = x.detach().clone().requires_grad_()
+ weight_ref = weight.detach().clone().requires_grad_()
+ out = rmsnorm(x, weight, eps=eps)
+ out_ref = rmsnorm_ref(x_ref, weight_ref, eps=eps)
+ assert out.shape == x.shape
+ torch.testing.assert_close(out, out_ref, atol=atol, rtol=1e-3)
+ grad_out = torch.randn_like(out)
+ torch.cuda.synchronize()
+ out_ref.backward(grad_out)
+ out.backward(grad_out)
+ torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-3)
+ torch.testing.assert_close(weight.grad, weight_ref.grad, atol=atol, rtol=1e-3)
+
+
+ @pytest.mark.parametrize("eps", [1e-5])
+ @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
+ @pytest.mark.parametrize(
+ "N",
+ [131072, 262144],
+ # [262144]
+ )
+ @pytest.mark.parametrize("M", [32 * 1024])
+ def test_rmsnorm_large_tensor(M, N, input_dtype, eps):
+ """Test RMSNorm forward pass against reference implementation."""
+ device = "cuda"
+ # Set tolerance based on dtype
+ if input_dtype == torch.bfloat16:
+ atol = 1e-1
+ elif input_dtype == torch.float16:
+ atol = 1e-2
+ else:
+ atol = 1e-4
+ torch.random.manual_seed(0)
+ torch.cuda.empty_cache()
+ x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=False)
+ weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=False)
+ out = rmsnorm(x, weight, eps=eps)
+ # Need to compile, otherwise it OOMs
+ rmsnorm_compiled = torch.compile(rmsnorm_ref)
+ # Run once with smaller input to avoid OOMs
+ rmsnorm_compiled(x[:32], weight, eps=eps)
+ out_ref = rmsnorm_compiled(x, weight, eps=eps)
+ # Need to chunk, otherwise it OOMs
+ assert all(
+ (out_c - out_ref_c).abs().max() < atol
+ for out_c, out_ref_c in zip(out.chunk(16), out_ref.chunk(16))
+ )
+
+
+ @pytest.mark.parametrize("return_rstd", [True, False])
+ def test_rmsnorm_return_rstd_option(return_rstd):
+ """Test that return_rstd option works correctly."""
+ device = "cuda"
+ M, N = 32, 1024
+ eps = 1e-6
+
+ x = torch.randn(M, N, device=device, dtype=torch.float16)
+ weight = torch.randn(N, device=device, dtype=torch.float32)
+
+ if return_rstd:
+ out, rstd = _rmsnorm_fwd(x, weight, eps=eps, return_rstd=True)
+ assert out.shape == (M, N)
+ assert rstd.shape == (M,)
+ assert rstd.dtype == torch.float32
+ else:
+ out = _rmsnorm_fwd(x, weight, eps=eps, return_rstd=False)
+ assert out.shape == (M, N)
+ assert isinstance(out, torch.Tensor)
+
+
+ def test_rmsnorm_input_validation():
+ """Test input validation and error handling."""
+ device = "cuda"
+
+ # Test 3D input (should now work since rmsnorm was updated to accept 3D inputs)
+ x_3d = torch.randn(2, 32, 1024, device=device, dtype=torch.float16)
+ weight = torch.randn(1024, device=device, dtype=torch.float32)
+
+ # This should not raise an exception now
+ out = rmsnorm(x_3d, weight)
+ # Verify output shape matches input shape
+ assert out.shape == x_3d.shape
+ # Verify output dtype matches input dtype
+ assert out.dtype == x_3d.dtype
+
+ # Test weight dimension mismatch
+ x = torch.randn(32, 1024, device=device, dtype=torch.float16)
+ weight_wrong = torch.randn(512, device=device, dtype=torch.float32)
+
+ with pytest.raises(AssertionError, match="Last dimension of input must match weight dimension"):
+ rmsnorm(x, weight_wrong)
+
+ # Test CPU tensors (should fail)
+ x_cpu = torch.randn(32, 1024, dtype=torch.float16)
+ weight_cpu = torch.randn(1024, dtype=torch.float32)
+
+ with pytest.raises(AssertionError, match="Tensors must be on CUDA device"):
+ rmsnorm(x_cpu, weight_cpu)
+
+ # Test unsupported dtype
+ x = torch.randn(32, 1024, device=device, dtype=torch.float64)
+ weight = torch.randn(1024, device=device, dtype=torch.float32)
+
+ with pytest.raises(AssertionError, match="Unsupported dtype"):
+ rmsnorm(x, weight)
+
+ # Test wrong weight dtype
+ x = torch.randn(32, 1024, device=device, dtype=torch.float16)
+ weight_wrong_dtype = torch.randn(1024, device=device, dtype=torch.float64)
+
+ with pytest.raises(AssertionError, match="Weight must be float32, float16 or bfloat16"):
+ rmsnorm(x, weight_wrong_dtype)
+
+
+ def test_rmsnorm_bf16_weights():
+ """Test that bfloat16 weights work correctly with rmsnorm."""
+ device = "cuda"
+ M, N = 32, 1024
+ eps = 1e-6
+
+ # Test with bfloat16 input and weights
+ x = torch.randn(M, N, device=device, dtype=torch.bfloat16)
+ weight_bf16 = torch.randn(N, device=device, dtype=torch.bfloat16)
+
+ # Run rmsnorm with bfloat16 weights
+ out_bf16 = rmsnorm(x, weight_bf16, eps=eps)
+
+ # Verify output shape and dtype
+ assert out_bf16.shape == x.shape
+ assert out_bf16.dtype == torch.bfloat16
+
+ # Convert to float32 for reference comparison
+ x_fp32 = x.to(torch.float32)
+ weight_fp32 = weight_bf16.to(torch.float32)
+
+ # Run reference implementation with float32
+ out_ref = rmsnorm_ref(x_fp32, weight_fp32, eps=eps).to(torch.bfloat16)
+
+ # Verify output values match reference implementation
+ torch.testing.assert_close(out_bf16, out_ref, atol=1e-1, rtol=1e-2)
+
+
+ def test_rmsnorm_bf16_weights_backward():
+ """Test that bfloat16 weights work correctly with rmsnorm backward pass."""
+ device = "cuda"
+ M, N = 32, 1024
+ eps = 1e-6
+ atol = 1e-1 # Higher tolerance for bfloat16
+
+ # Create tensors with gradients
+ x = torch.randn(M, N, device=device, dtype=torch.bfloat16, requires_grad=True)
+ weight_bf16 = torch.randn(N, device=device, dtype=torch.bfloat16, requires_grad=True)
+
+ # Create reference tensors with float32 weights for comparison
+ x_ref = x.detach().clone().requires_grad_()
+ weight_fp32 = weight_bf16.to(torch.float32).detach().requires_grad_()
+
+ # Forward pass
+ out_bf16 = rmsnorm(x, weight_bf16, eps=eps)
+ out_ref = rmsnorm(x_ref, weight_fp32, eps=eps)
+
+ # Create gradient for backward pass
+ grad_out = torch.randn_like(out_bf16)
+ grad_out_ref = grad_out.clone()
+
+ # Backward pass
+ torch.cuda.synchronize()
+ out_bf16.backward(grad_out)
+ out_ref.backward(grad_out_ref)
+
+ # Verify gradients
+ torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-2)
+ torch.testing.assert_close(
+ weight_bf16.grad, weight_fp32.grad.to(torch.bfloat16), atol=atol, rtol=1e-2
+ )
+
+ # Test with mixed precision: bfloat16 input and float32 weights
+ x = torch.randn(M, N, device=device, dtype=torch.bfloat16, requires_grad=True)
+ weight_fp32 = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
+
+ # Forward pass
+ out_mixed = rmsnorm(x, weight_fp32, eps=eps)
+
+ # Create gradient for backward pass
+ grad_out = torch.randn_like(out_mixed)
+
+ # Backward pass
+ torch.cuda.synchronize()
+ out_mixed.backward(grad_out)
+
+ # Just verify that backward pass completes without errors
+ assert x.grad is not None
+ assert weight_fp32.grad is not None
+
+
+ def test_rmsnorm_fp16_weights():
+ """Test that float16 weights work correctly with rmsnorm."""
+ device = "cuda"
+ M, N = 32, 1024
+ eps = 1e-6
+
+ # Test with float16 input and weights
+ x = torch.randn(M, N, device=device, dtype=torch.float16)
+ weight_fp16 = torch.randn(N, device=device, dtype=torch.float16)
+
+ # Run rmsnorm with float16 weights
+ out_fp16 = rmsnorm(x, weight_fp16, eps=eps)
+
+ # Verify output shape and dtype
+ assert out_fp16.shape == x.shape
+ assert out_fp16.dtype == torch.float16
+
+ # Convert to float32 for reference comparison
+ x_fp32 = x.to(torch.float32)
+ weight_fp32 = weight_fp16.to(torch.float32)
+
+ # Run reference implementation with float32
+ out_ref = rmsnorm_ref(x_fp32, weight_fp32, eps=eps).to(torch.float16)
+
+ # Verify output values match reference implementation
+ torch.testing.assert_close(out_fp16, out_ref, atol=1e-2, rtol=1e-2)
+
+
+ def test_rmsnorm_fp16_weights_backward():
+ """Test that float16 weights work correctly with rmsnorm backward pass."""
+ device = "cuda"
+ M, N = 32, 1024
+ eps = 1e-6
+ atol = 1e-2 # Tolerance for float16
+
+ # Create tensors with gradients
+ x = torch.randn(M, N, device=device, dtype=torch.float16, requires_grad=True)
+ weight_fp16 = torch.randn(N, device=device, dtype=torch.float16, requires_grad=True)
+
+ # Create reference tensors with float32 weights for comparison
+ x_ref = x.detach().clone().requires_grad_()
+ weight_fp32 = weight_fp16.to(torch.float32).detach().requires_grad_()
+
+ # Forward pass
+ out_fp16 = rmsnorm(x, weight_fp16, eps=eps)
+ out_ref = rmsnorm(x_ref, weight_fp32, eps=eps)
+
+ # Create gradient for backward pass
+ grad_out = torch.randn_like(out_fp16)
+ grad_out_ref = grad_out.clone()
+
+ # Backward pass
+ torch.cuda.synchronize()
+ out_fp16.backward(grad_out)
+ out_ref.backward(grad_out_ref)
+
+ # Verify gradients
+ torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-2)
+ torch.testing.assert_close(
+ weight_fp16.grad, weight_fp32.grad.to(torch.float16), atol=atol, rtol=1e-2
+ )
+
+ # Test with mixed precision: float16 input and float32 weights
+ x = torch.randn(M, N, device=device, dtype=torch.float16, requires_grad=True)
+ weight_fp32 = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
+
+ # Forward pass
+ out_mixed = rmsnorm(x, weight_fp32, eps=eps)
+
+ # Create gradient for backward pass
+ grad_out = torch.randn_like(out_mixed)
+
+ # Backward pass
+ torch.cuda.synchronize()
+ out_mixed.backward(grad_out)
+
+ # Just verify that backward pass completes without errors
+ assert x.grad is not None
+ assert weight_fp32.grad is not None
+
+
+ def test_rmsnorm_compile_cache():
+ """Test that compile cache works correctly for repeated calls."""
+ device = "cuda"
+ M, N = 32, 1024
+ eps = 1e-6
+
+ # Clear cache
+ _rmsnorm_fwd.compile_cache.clear()
+ assert len(_rmsnorm_fwd.compile_cache) == 0
+
+ x1 = torch.randn(M, N, device=device, dtype=torch.float16)
+ weight1 = torch.randn(N, device=device, dtype=torch.float32)
+
+ # First call should compile
+ out1 = _rmsnorm_fwd(x1, weight1, eps=eps)
+ assert len(_rmsnorm_fwd.compile_cache) == 1
+
+ # Same shape should reuse cache
+ x2 = torch.randn(M, N, device=device, dtype=torch.float16)
+ weight2 = torch.randn(N, device=device, dtype=torch.float32)
+ out2 = _rmsnorm_fwd(x2, weight2, eps=eps)
+ assert len(_rmsnorm_fwd.compile_cache) == 1
+
+ # Different shape should create new cache entry
+ x3 = torch.randn(M, N * 2, device=device, dtype=torch.float16)
+ weight3 = torch.randn(N * 2, device=device, dtype=torch.float32)
+ out3 = _rmsnorm_fwd(x3, weight3, eps=eps)
+ assert len(_rmsnorm_fwd.compile_cache) == 2
+
+ # Different dtype should create new cache entry
+ x4 = torch.randn(M, N, device=device, dtype=torch.float32)
+ weight4 = torch.randn(N, device=device, dtype=torch.float32)
+ out4 = _rmsnorm_fwd(x4, weight4, eps=eps)
+ assert len(_rmsnorm_fwd.compile_cache) == 3

quack_kernels-0.1.7/tests/test_rmsnorm.py (removed)
@@ -1,183 +0,0 @@
- # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
-
- import pytest
- import torch
-
- from quack.rmsnorm import rmsnorm, rmsnorm_ref, rstd_ref
-
-
- @pytest.mark.parametrize("eps", [1e-5, 1e-6])
- # @pytest.mark.parametrize("eps", [1e-5])
- @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
- # @pytest.mark.parametrize("input_dtype", [torch.float16])
- @pytest.mark.parametrize(
- "N",
- [192, 256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144]
- # [262144]
- )
- @pytest.mark.parametrize("M", [1, 37, 199, 8 * 1024])
- # @pytest.mark.parametrize("M", [1])
- def test_rmsnorm_forward_backward(M, N, input_dtype, eps):
- """Test RMSNorm forward pass against reference implementation."""
- if N >= 256 * 1024 and input_dtype == torch.float32 and M >= 8 * 1024:
- pytest.skip("Skipping large tensor test for float32 to avoid OOM")
- device = "cuda"
- # Set tolerance based on dtype
- if input_dtype == torch.bfloat16:
- atol = 1e-1
- elif input_dtype == torch.float16:
- atol = 1e-2
- else:
- atol = 1e-4
- torch.random.manual_seed(0)
- x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
- weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
- x_ref = x.detach().clone().requires_grad_()
- weight_ref = weight.detach().clone().requires_grad_()
- out = rmsnorm(x, weight, eps=eps)
- out_ref = rmsnorm_ref(x_ref, weight_ref, eps=eps)
- # rstd_ref_val = rstd_ref(x_ref, eps=eps)
- assert out.shape == x.shape
- assert out.dtype == input_dtype
- torch.testing.assert_close(out, out_ref, atol=atol, rtol=1e-3)
- # torch.testing.assert_close(rstd, rstd_ref_val, atol=atol, rtol=1e-3)
- # Backward pass
- if N > 128 * 1024 and input_dtype == torch.float32:
- # Skip backward pass for due to not enough smem
- return
- grad_out = torch.randn_like(out)
- torch.cuda.synchronize()
- out_ref.backward(grad_out)
- out.backward(grad_out)
- torch.testing.assert_close(x.grad, x_ref.grad, atol=atol, rtol=1e-3)
- torch.testing.assert_close(weight.grad, weight_ref.grad, atol=atol, rtol=1e-3)
-
-
- @pytest.mark.parametrize("eps", [1e-5])
- @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
- @pytest.mark.parametrize(
- "N",
- [131072, 262144]
- # [262144]
- )
- @pytest.mark.parametrize("M", [32 * 1024])
- def test_rmsnorm_large_tensor(M, N, input_dtype, eps):
- """Test RMSNorm forward pass against reference implementation."""
- device = "cuda"
- # Set tolerance based on dtype
- if input_dtype == torch.bfloat16:
- atol = 1e-1
- elif input_dtype == torch.float16:
- atol = 1e-2
- else:
- atol = 1e-4
- torch.random.manual_seed(0)
- torch.cuda.empty_cache()
- x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=False)
- weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=False)
- out = rmsnorm(x, weight, eps=eps)
- # Need to compile, otherwise it OOMs
- rmsnorm_compiled = torch.compile(rmsnorm_ref)
- # Run once with smaller input to avoid OOMs
- rmsnorm_compiled(x[:32], weight, eps=eps)
- out_ref = rmsnorm_compiled(x, weight, eps=eps)
- # Need to chunk, otherwise it OOMs
- assert all((out_c - out_ref_c).abs().max() < atol
- for out_c, out_ref_c in zip(out.chunk(16), out_ref.chunk(16)))
-
-
- @pytest.mark.parametrize("return_rstd", [True, False])
- def test_rmsnorm_return_rstd_option(return_rstd):
- """Test that return_rstd option works correctly."""
- device = "cuda"
- M, N = 32, 1024
- eps = 1e-6
-
- x = torch.randn(M, N, device=device, dtype=torch.float16)
- weight = torch.randn(N, device=device, dtype=torch.float32)
-
- if return_rstd:
- out, rstd = rmsnorm(x, weight, eps=eps, return_rstd=True)
- assert out.shape == (M, N)
- assert rstd.shape == (M,)
- assert rstd.dtype == torch.float32
- else:
- out = rmsnorm(x, weight, eps=eps, return_rstd=False)
- assert out.shape == (M, N)
- assert isinstance(out, torch.Tensor)
-
-
- def test_rmsnorm_input_validation():
- """Test input validation and error handling."""
- device = "cuda"
-
- # Test 3D input (should fail)
- x_3d = torch.randn(2, 32, 1024, device=device, dtype=torch.float16)
- weight = torch.randn(1024, device=device, dtype=torch.float32)
-
- with pytest.raises(AssertionError, match="Input must be 2D"):
- rmsnorm(x_3d, weight)
-
- # Test weight dimension mismatch
- x = torch.randn(32, 1024, device=device, dtype=torch.float16)
- weight_wrong = torch.randn(512, device=device, dtype=torch.float32)
-
- with pytest.raises(AssertionError, match="Last dimension of input must match weight dimension"):
- rmsnorm(x, weight_wrong)
-
- # Test CPU tensors (should fail)
- x_cpu = torch.randn(32, 1024, dtype=torch.float16)
- weight_cpu = torch.randn(1024, dtype=torch.float32)
-
- with pytest.raises(AssertionError, match="Tensors must be on CUDA device"):
- rmsnorm(x_cpu, weight_cpu)
-
- # Test unsupported dtype
- x = torch.randn(32, 1024, device=device, dtype=torch.float64)
- weight = torch.randn(1024, device=device, dtype=torch.float32)
-
- with pytest.raises(AssertionError, match="Unsupported dtype"):
- rmsnorm(x, weight)
-
- # Test wrong weight dtype
- x = torch.randn(32, 1024, device=device, dtype=torch.float16)
- weight_wrong_dtype = torch.randn(1024, device=device, dtype=torch.float16)
-
- with pytest.raises(AssertionError, match="Weight must be float32"):
- rmsnorm(x, weight_wrong_dtype)
-
-
- def test_rmsnorm_compile_cache():
- """Test that compile cache works correctly for repeated calls."""
- device = "cuda"
- M, N = 32, 1024
- eps = 1e-6
-
- # Clear cache
- rmsnorm.compile_cache.clear()
- assert len(rmsnorm.compile_cache) == 0
-
- x1 = torch.randn(M, N, device=device, dtype=torch.float16)
- weight1 = torch.randn(N, device=device, dtype=torch.float32)
-
- # First call should compile
- out1 = rmsnorm(x1, weight1, eps=eps)
- assert len(rmsnorm.compile_cache) == 1
-
- # Same shape should reuse cache
- x2 = torch.randn(M, N, device=device, dtype=torch.float16)
- weight2 = torch.randn(N, device=device, dtype=torch.float32)
- out2 = rmsnorm(x2, weight2, eps=eps)
- assert len(rmsnorm.compile_cache) == 1
-
- # Different shape should create new cache entry
- x3 = torch.randn(M, N * 2, device=device, dtype=torch.float16)
- weight3 = torch.randn(N * 2, device=device, dtype=torch.float32)
- out3 = rmsnorm(x3, weight3, eps=eps)
- assert len(rmsnorm.compile_cache) == 2
-
- # Different dtype should create new cache entry
- x4 = torch.randn(M, N, device=device, dtype=torch.float32)
- weight4 = torch.randn(N, device=device, dtype=torch.float32)
- out4 = rmsnorm(x4, weight4, eps=eps)
- assert len(rmsnorm.compile_cache) == 3