quack-kernels 0.1.5__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {quack_kernels-0.1.5/quack_kernels.egg-info → quack_kernels-0.1.6}/PKG-INFO +1 -1
  2. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/README.md +16 -5
  3. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/__init__.py +1 -1
  4. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/cross_entropy.py +11 -7
  5. quack_kernels-0.1.6/quack/layernorm.py +351 -0
  6. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/rmsnorm.py +4 -1
  7. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/softmax.py +8 -3
  8. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/utils.py +16 -0
  9. {quack_kernels-0.1.5 → quack_kernels-0.1.6/quack_kernels.egg-info}/PKG-INFO +1 -1
  10. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack_kernels.egg-info/SOURCES.txt +2 -0
  11. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack_kernels.egg-info/top_level.txt +1 -0
  12. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/tests/test_cross_entropy.py +13 -49
  13. quack_kernels-0.1.6/tests/test_layernorm.py +162 -0
  14. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/tests/test_rmsnorm.py +36 -5
  15. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/tests/test_softmax.py +2 -3
  16. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/LICENSE +0 -0
  17. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/pyproject.toml +0 -0
  18. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack/reduction_base.py +0 -0
  19. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack_kernels.egg-info/dependency_links.txt +0 -0
  20. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/quack_kernels.egg-info/requires.txt +0 -0
  21. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/setup.cfg +0 -0
  22. {quack_kernels-0.1.5 → quack_kernels-0.1.6}/setup.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.5
+ Version: 0.1.6
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
@@ -19,6 +19,7 @@ pip install quack-kernels
  - 🦆 RMSNorm forward
  - 🦆 Softmax forward + backward
  - 🦆 Cross entropy forward + backward
+ - 🦆 Layernorm forward

  Upcoming:
  - 🦆 RMSNorm backward
@@ -30,13 +31,23 @@ Upcoming:
  from quack import rmsnorm, softmax, cross_entropy
  ```

- ## Caveats 🦆⚠️
+ ## Documentations

- **Tensor Size Limitation**: We currently only support tensors ≤ 4GB due to CuTe-DSL using int32 for indexing.
+ [2025-07-10] We have a comprehensive
+ [blogpost](media/2025-07-10-membound-sol.md) on how to get memory-bound kernels
+ to speed-of-light, right in the comfort of Python thanks to the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html).

- 🦆 **Workaround**: For larger tensors, split your input tensors into chunks of
- size ≤ 4GB each. We will implement this automatic chunking in the pytorch part
- of the code in the near future, but if you need it in the meantime, we welcome contributions!
+ ## Performance
+
+ <div align="center">
+ <figure>
+ <img
+ src="media/bf16_kernel_benchmarks_single_row.svg"
+ >
+ </figure>
+ </div>
+
+ See our [blogpost](media/2025-07-10-membound-sol.md) for the details.

  ## Development

@@ -1,4 +1,4 @@
- __version__ = "0.1.5"
+ __version__ = "0.1.6"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
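For orientation, the top-level `quack` exports shown above still list only `rmsnorm`, `softmax`, and `cross_entropy`; the new LayerNorm forward is imported from its module directly, as the new tests in this release do. A minimal usage sketch, assuming a CUDA device and the signature of `quack/layernorm.py` shown later in this diff:

```python
import torch
from quack.layernorm import layernorm  # new in 0.1.6; import path taken from tests/test_layernorm.py

x = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, device="cuda", dtype=torch.float32)  # weight must be fp32 per the asserts in layernorm()

y = layernorm(x, w, eps=1e-6)  # output only
y, rstd, mean = layernorm(x, w, eps=1e-6, return_rstd=True, return_mean=True)  # with per-row stats
```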
@@ -104,7 +104,10 @@ class CrossEntropy(ReductionBase):
  shape: cute.Shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
+ # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+ mX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+ gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
@@ -150,7 +153,9 @@ class CrossEntropy(ReductionBase):

  target_logit = cute.Float32.zero
  if row < shape[0] and tXcX[0][1] == 0:
- target_logit = cute.Float32(mX[row, target])
+ # Use Int64 for indexing to deal with large tensors
+ mX_off = utils.domain_offset_i64((row, 0), mX)
+ target_logit = cute.Float32(mX_off[0, target])

  threads_per_row = tv_layout.shape[0][0]
  if cutlass.const_expr(not self.online_softmax):
@@ -363,11 +368,10 @@ class CrossEntropyBackward:
  )

  idX = cute.make_identity_tensor(shape)
-
- gX, gdX, cX, gTarget, gDLoss, gLse = [
- cute.local_tile(mT, tiler_mn, (bidx, bidy))
- for mT in (mX, mdX, idX, mTarget, mDLoss, mLSE)
- ]
+ # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+ mX, mdX = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX)]
+ gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))

  copy_atom_load_X = cute.make_copy_atom(
  cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
@@ -0,0 +1,351 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+
+ import torch
+ from typing import Optional
+
+ import cuda.bindings.driver as cuda
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.runtime import from_dlpack
+ import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+ class LayerNorm(ReductionBase):
+ def __init__(self, dtype: cutlass.Numeric, N: int):
+ super().__init__(dtype, N, stage=2) # 2 stages for mean and var
+ self.reload_from = None if N <= 16384 else "smem"
+ self.delay_w_load = False
+
+ def _calculate_threads_per_row(self):
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ )
+ )
+
+ def _set_cluster_n(self):
+ N = self.N
+ # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
+ # Similarly cluster_n = 8 is faster for N=128k
+ if cutlass.const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else: # fp32
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ stream: cuda.CUstream,
+ eps: cutlass.Float32 = 1e-6,
+ ):
+ assert mX.element_type == self.dtype
+ assert mO.element_type == self.dtype
+ self._set_cluster_n()
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = cute.size(tv_layout, mode=[0])
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+ if cutlass.const_expr(mRstd is not None):
+ mRstd_expanded_layout = cute.append(
+ mRstd.layout, cute.make_layout((self.N,), stride=(0,))
+ )
+ mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+ if cutlass.const_expr(mMean is not None):
+ mMean_expanded_layout = cute.append(
+ mMean.layout, cute.make_layout((self.N,), stride=(0,))
+ )
+ mMean = cute.make_tensor(mMean.iterator, mMean_expanded_layout)
+ self.kernel(mX, mW, mO, mRstd, mMean, eps, tv_layout, tiler_mn, self.reload_from).launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ eps: cute.Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ reload_from: cutlass.Constexpr = None,
+ delay_w_load: cutlass.Constexpr = False,
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if cutlass.const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = cutlass.const_expr(0)
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+ # slice for CTAs
+ # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+ mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+ gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+ gRstd = (
+ cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+ if cutlass.const_expr(mRstd is not None)
+ else None
+ )
+ gMean = (
+ cute.local_tile(mMean, tiler_mn, (bidx, cluster_y))
+ if cutlass.const_expr(mMean is not None)
+ else None
+ )
+
+ # declare the atoms which will be used later for memory copy
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+ )
+ copy_atom_load_X_async = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+ )
+ copy_atom_load_W = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+ )
+ copy_atom_store_O = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
+ )
+
+ thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
+ tidx
+ )
+ thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+ tWgW = thr_copy_W.partition_S(gW)
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXgO = thr_copy_O.partition_D(gO)
+ tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+ tXrMean = thr_copy_O.partition_D(gMean) if cutlass.const_expr(mMean is not None) else None
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+ # allocate fragments for gmem->rmem
+ tWrW = cute.make_fragment_like(tWgW)
+ tXrW = thr_copy_X.retile(tWrW)
+ tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ row = tXcX[0][0]
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+
+ tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
+ if cutlass.const_expr(not delay_w_load):
+ cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+
+ cute.arch.cp_async_wait_group(0)
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ threads_per_row = tv_layout.shape[0][0]
+ sum_x = utils.row_reduce(
+ x,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+ )
+ mean = sum_x / shape[1]
+ if cutlass.const_expr(reload_from == "smem"):
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ elif cutlass.const_expr(reload_from == "gmem"):
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ x = tXrX.load().to(cute.Float32)
+
+ sum_sq_x_sub_mean = utils.row_reduce(
+ (x - mean) * (x - mean),
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 1],
+ mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ )
+ rstd = utils.rsqrt(sum_sq_x_sub_mean / shape[1] + eps)
+ if cutlass.const_expr(mRstd is not None):
+ # Only the thread corresponding to column 0 writes out the rstd to gmem
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXrRstd[0] = rstd
+ if cutlass.const_expr(mMean is not None):
+ # Only the thread corresponding to column 0 writes out the mean to gmem
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXrMean[0] = mean
+ if cutlass.const_expr(delay_w_load):
+ cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+ if cutlass.const_expr(reload_from == "smem"):
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ elif cutlass.const_expr(reload_from == "gmem"):
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ x = tXrX.load().to(cute.Float32)
+ x_hat = (x - mean) * rstd
+ w = tXrW.load().to(cute.Float32)
+ y = x_hat * w
+ tXrO.store(y.to(tXrO.element_type))
+ tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+ if row < shape[0]:
+ cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+ def layernorm(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+ return_rstd: bool = False,
+ return_mean: bool = False,
+ ) -> torch.Tensor:
+ """LayerNorm forward pass.
+
+ Args:
+ x: Input tensor of shape (M, N)
+ weight: Weight tensor of shape (N,)
+ eps: Small value for numerical stability
+ return_rstd: Whether to return the reciprocal standard deviation
+ return_mean: Whether to return the mean
+
+ Returns:
+ Normalized output tensor of same shape as x
+ If return_rstd is True, also returns rstd tensor of shape (M,)
+ If return_mean is True, also returns mean tensor of shape (M,)
+ """
+ assert x.dim() == 2, "Input must be 2D"
+ assert weight.dim() == 1, "Weight must be 1D"
+ assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+ assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+ assert weight.dtype == torch.float32, "Weight must be float32"
+ M, N = x.shape
+ device = x.device
+ out = torch.empty_like(x)
+ rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
+ mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
+ dtype = torch2cute_dtype_map[x.dtype]
+ convert_from_dlpack = lambda x: (
+ from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0, stride_order=(0, 1)
+ )
+ )
+ x_tensor, out_tensor = [
+ # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
+ convert_from_dlpack(t)
+ for t in (x, out)
+ ]
+ weight_tensor = utils.convert_from_dlpack(
+ weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+ )
+ rstd_tensor = (
+ from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ if rstd is not None
+ else None
+ )
+ mean_tensor = (
+ from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ if mean is not None
+ else None
+ )
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+ compile_key = (dtype, N, rstd is not None, mean is not None)
+ if compile_key not in layernorm.compile_cache:
+ rmsnorm_op = LayerNorm(dtype, N)
+ layernorm.compile_cache[compile_key] = cute.compile(
+ rmsnorm_op,
+ x_tensor,
+ weight_tensor,
+ out_tensor,
+ rstd_tensor,
+ mean_tensor,
+ current_stream,
+ )
+ layernorm.compile_cache[compile_key](
+ x_tensor, weight_tensor, out_tensor, rstd_tensor, mean_tensor, current_stream, eps
+ )
+ return (
+ (out, rstd, mean)
+ if return_mean and return_rstd
+ else (
+ (out, rstd)
+ if return_rstd and not return_mean
+ else ((out, mean) if return_mean and not return_rstd else (out))
+ )
+ )
+
+
+ layernorm.compile_cache = {}
+
+
+ def layernorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
+ x_f32 = x.float()
+ return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
+
+
+ def rstd_ref(x: torch.Tensor, eps: float = 1e-6):
+ x_f32 = x.float()
+ mean = x_f32.mean(dim=-1, keepdim=True)
+ var = ((x_f32 - mean) ** 2).mean(dim=-1)
+ return 1.0 / torch.sqrt(var + eps)
+
+
+ def mean_ref(x: torch.Tensor) -> torch.Tensor:
+ return x.float().mean(dim=-1)
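The new module above also bundles pure-PyTorch references (`layernorm_ref`, `rstd_ref`, `mean_ref`). A sanity-check sketch against them, assuming a CUDA-capable GPU and reusing the fp16 tolerances from the new test file:

```python
import torch
from quack.layernorm import layernorm, layernorm_ref, rstd_ref, mean_ref

x = torch.randn(199, 8192, device="cuda", dtype=torch.float16)
w = torch.randn(8192, device="cuda", dtype=torch.float32)

out, rstd, mean = layernorm(x, w, eps=1e-6, return_rstd=True, return_mean=True)

# Compare against the bundled references (tolerances mirror tests/test_layernorm.py for fp16).
torch.testing.assert_close(out, layernorm_ref(x, w, eps=1e-6), atol=1e-3, rtol=1e-3)
torch.testing.assert_close(rstd, rstd_ref(x, eps=1e-6), atol=6e-4, rtol=6e-4)
torch.testing.assert_close(mean, mean_ref(x), atol=6e-4, rtol=6e-4)
```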
@@ -117,7 +117,10 @@ class RMSNorm(ReductionBase):
  shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+ # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+ mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+ gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
  gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
  gRstd = (
  cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
@@ -98,7 +98,10 @@ class Softmax(ReductionBase):
  shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+ # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+ mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+ gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
@@ -312,9 +315,11 @@ class SoftmaxBackward(ReductionBase):
  shape = mdY.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gdY, gY, gdX, cX = [
- cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
+ mdY, mY, mdX = [
+ utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX)
  ]
+ gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

  smem = cutlass.utils.SmemAllocator()
  sdY = smem.allocate_tensor(
@@ -390,3 +390,19 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
  vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
  )
  return res0, res1
+
+
+ @dsl_user_op
+ def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+ flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
+ flat_stride = cute.flatten_to_tuple(tensor.stride)
+ offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
+ assert isinstance(tensor.iterator, cute.Pointer)
+ # HACK: we assume that applying the offset does not change the pointer alignment
+ new_ptr = cute.make_ptr(
+ tensor.element_type,
+ tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+ tensor.memspace,
+ assumed_align=tensor.iterator.max_alignment,
+ )
+ return cute.make_tensor(new_ptr, tensor.layout)
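The point of `domain_offset_i64` is that the base-pointer offset is accumulated in 64-bit before tiling, so `row * stride` cannot wrap the way a 32-bit index would once a tensor exceeds 2^31 elements. A standalone illustration of that arithmetic follows; `byte_offset_i64` is a hypothetical helper for exposition, not part of the CuTe-DSL API:

```python
# Hypothetical illustration of the 64-bit offset arithmetic behind domain_offset_i64.
def byte_offset_i64(coord, strides, elem_bits):
    # Python ints don't overflow, mirroring the Int64 accumulation in the kernel;
    # an int32 index would wrap once coord * stride exceeds 2**31 - 1.
    elem_offset = sum(c * s for c, s in zip(coord, strides))
    return elem_offset * elem_bits // 8

# Example: row 40_000 of a bf16 tensor with row stride 262_144 is already past 2**31 elements.
assert byte_offset_i64((40_000, 0), (262_144, 1), elem_bits=16) == 40_000 * 262_144 * 2
```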
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.5
+ Version: 0.1.6
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
@@ -4,6 +4,7 @@ pyproject.toml
  setup.py
  quack/__init__.py
  quack/cross_entropy.py
+ quack/layernorm.py
  quack/reduction_base.py
  quack/rmsnorm.py
  quack/softmax.py
@@ -14,5 +15,6 @@ quack_kernels.egg-info/dependency_links.txt
  quack_kernels.egg-info/requires.txt
  quack_kernels.egg-info/top_level.txt
  tests/test_cross_entropy.py
+ tests/test_layernorm.py
  tests/test_rmsnorm.py
  tests/test_softmax.py
@@ -16,18 +16,19 @@ import cutlass
  )
  @pytest.mark.parametrize("M", [1, 77, 289])
  # @pytest.mark.parametrize("M", [1])
- def test_cross_entropy_forward(M, N, input_dtype):
+ def test_cross_entropy(M, N, input_dtype):
  """Test Cross Entropy forward pass against reference implementation."""
  device = "cuda"
- atol, rtol = 1e-5, 1e-5
+ atol, rtol = 5e-5, 1e-5
  torch.random.manual_seed(0)
+ cutlass.cuda.initialize_cuda_context()
  # Create input tensors (scale down to avoid overflow)
- x = 0.1 * torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=False)
+ x = (0.1 * torch.randn(M, N, device=device, dtype=input_dtype)).requires_grad_()
  target = torch.randint(0, N, (M,), device=device, dtype=torch.int64)
- x_ref = x.detach().clone()
+ x_ref = x.detach().clone().requires_grad_()
  target_ref = target.detach().clone()
  # Forward pass
- loss = _cross_entropy(x, target)
+ loss = cross_entropy(x, target)
  loss_ref = F.cross_entropy(x_ref.float(), target_ref, reduction='none')
  # Check output shape and dtype
  assert loss.shape == (M,)
@@ -40,6 +41,13 @@ def test_cross_entropy_forward(M, N, input_dtype):
  # Check that loss is reasonable (not inf or nan)
  assert not torch.isnan(loss).any()
  assert not torch.isinf(loss).any()
+ # Test backward pass
+ dloss = torch.randn_like(loss)
+ torch.cuda.synchronize()
+ dx_ref, = torch.autograd.grad(loss_ref, x_ref, grad_outputs=dloss)
+ dx, = torch.autograd.grad(loss, x, grad_outputs=dloss)
+ assert dx.shape == x.shape
+ torch.testing.assert_close(dx, dx_ref.to(input_dtype), atol=atol, rtol=rtol)


  @pytest.mark.parametrize("input_dtype", [torch.float16, torch.float32])
@@ -99,47 +107,3 @@ def test_cross_entropy_edge_targets():
  loss_last = _cross_entropy(x, target_last)
  loss_ref_last = F.cross_entropy(x, target_last, reduction='none')
  torch.testing.assert_close(loss_last, loss_ref_last, atol=1e-4, rtol=1e-4)
-
-
-
-
-
- @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
- @pytest.mark.parametrize(
- "N",
- [192, 256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 128256, 131072, 256128, 262144] # A representative subset to keep compile time reasonable
- )
- @pytest.mark.parametrize("M", [1, 37, 77])
- def test_cross_entropy_autograd_backward(M, N, input_dtype):
- device = "cuda"
-
- if input_dtype == torch.bfloat16:
- atol = 1e-3
- rtol = 1e-3
- else:
- atol = 1e-5
- rtol = 1e-5
-
- torch.random.manual_seed(0)
-
- x = 0.1 * torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
- target = torch.randint(0, N, (M,), device=device, dtype=torch.int64)
-
- x_ref = x.detach().clone().requires_grad_(True)
- target_ref = target.detach().clone()
-
- cutlass.cuda.initialize_cuda_context()
-
- loss = cross_entropy(x, target) # our autograd-enabled op
- loss_ref = F.cross_entropy(x_ref.float(), target_ref, reduction='none')
-
- torch.testing.assert_close(loss, loss_ref, atol=atol, rtol=rtol)
-
- dloss = torch.randn_like(loss)
-
- dx_ref, = torch.autograd.grad(loss_ref, x_ref, grad_outputs=dloss)
-
- dx, = torch.autograd.grad(loss, x, grad_outputs=dloss)
-
- assert dx.shape == x.shape
- torch.testing.assert_close(dx, dx_ref.to(input_dtype), atol=atol, rtol=rtol)
@@ -0,0 +1,162 @@
+ # tests/test_layernorm.py
+
+ import pytest
+ import torch
+
+ from quack.layernorm import layernorm, layernorm_ref, rstd_ref, mean_ref
+
+
+ @pytest.mark.parametrize("eps", [1e-5, 1e-6])
+ @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
+ @pytest.mark.parametrize("M", [1, 37, 199])
+ @pytest.mark.parametrize(
+ "N", [256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144]
+ ) # , 32768])
+ def test_layernorm_forward(M, N, input_dtype, eps):
+ """Test LayerNorm forward pass against reference implementation."""
+ device = "cuda"
+
+ # tolerance depends on precision
+ if input_dtype == torch.bfloat16:
+ atol = 1e-2
+ rtol = 1e-2
+ elif input_dtype == torch.float16:
+ atol = 1e-3
+ rtol = 1e-3
+ else:
+ atol = 1e-4
+ rtol = 1e-4
+
+ torch.random.manual_seed(0)
+ x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
+ weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
+
+ # pure‐PyTorch refs
+ x_ref = x.detach().clone().requires_grad_()
+ weight_ref = weight.detach().clone().requires_grad_()
+
+ out, rstd, mean = layernorm(x, weight, eps=eps, return_rstd=True, return_mean=True)
+ out_ref = layernorm_ref(x_ref, weight_ref, eps=eps)
+ rstd_ref_val = rstd_ref(x_ref, eps=eps)
+ mean_ref_val = mean_ref(x_ref)
+
+ # shapes & dtypes
+ assert out.shape == x.shape
+ assert out.dtype == input_dtype
+ assert rstd.shape == (M,) and rstd.dtype == torch.float32
+ assert mean.shape == (M,) and mean.dtype == torch.float32
+
+ # numeric check
+ torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
+ torch.testing.assert_close(rstd, rstd_ref_val, atol=6e-4, rtol=6e-4)
+ torch.testing.assert_close(mean, mean_ref_val, atol=6e-4, rtol=6e-4)
+
+
+ @pytest.mark.parametrize("return_rstd", [True, False])
+ @pytest.mark.parametrize("return_mean", [True, False])
+ def test_layernormnorm_return_rstd_option(return_rstd, return_mean):
+ """Test that return_rstd option works correctly."""
+ device = "cuda"
+ M, N = 32, 1024
+ eps = 1e-6
+
+ x = torch.randn(M, N, device=device, dtype=torch.float16)
+ weight = torch.randn(N, device=device, dtype=torch.float32)
+
+ if return_rstd and return_mean:
+ out, rstd, mean = layernorm(x, weight, eps=eps, return_rstd=True, return_mean=True)
+ assert out.shape == (M, N)
+ assert rstd.shape == (M,)
+ assert rstd.dtype == torch.float32
+ assert mean.shape == (M,)
+ assert mean.dtype == torch.float32
+ elif return_rstd and not return_mean:
+ out, rstd = layernorm(x, weight, eps=eps, return_rstd=True, return_mean=False)
+ assert out.shape == (M, N)
+ assert rstd.shape == (M,)
+ assert rstd.dtype == torch.float32
+ elif not return_rstd and return_mean:
+ out, mean = layernorm(x, weight, eps=eps, return_rstd=False, return_mean=True)
+ assert out.shape == (M, N)
+ assert mean.shape == (M,)
+ assert mean.dtype == torch.float32
+ else:
+ out = layernorm(x, weight, eps=eps, return_rstd=False, return_mean=False)
+ assert out.shape == (M, N)
+ assert isinstance(out, torch.Tensor)
+
+
+ def test_layernorm_input_validation():
+ """Test input validation and error handling."""
+ device = "cuda"
+
+ # Test 3D input (should fail)
+ x_3d = torch.randn(2, 32, 1024, device=device, dtype=torch.float16)
+ weight = torch.randn(1024, device=device, dtype=torch.float32)
+
+ with pytest.raises(AssertionError, match="Input must be 2D"):
+ layernorm(x_3d, weight)
+
+ # Test weight dimension mismatch
+ x = torch.randn(32, 1024, device=device, dtype=torch.float16)
+ weight_wrong = torch.randn(512, device=device, dtype=torch.float32)
+
+ with pytest.raises(AssertionError, match="Last dimension of input must match weight dimension"):
+ layernorm(x, weight_wrong)
+
+ # Test CPU tensors (should fail)
+ x_cpu = torch.randn(32, 1024, dtype=torch.float16)
+ weight_cpu = torch.randn(1024, dtype=torch.float32)
+
+ with pytest.raises(AssertionError, match="Tensors must be on CUDA device"):
+ layernorm(x_cpu, weight_cpu)
+
+ # Test unsupported dtype
+ x = torch.randn(32, 1024, device=device, dtype=torch.float64)
+ weight = torch.randn(1024, device=device, dtype=torch.float32)
+
+ with pytest.raises(AssertionError, match="Unsupported dtype"):
+ layernorm(x, weight)
+
+ # Test wrong weight dtype
+ x = torch.randn(32, 1024, device=device, dtype=torch.float16)
+ weight_wrong_dtype = torch.randn(1024, device=device, dtype=torch.float16)
+
+ with pytest.raises(AssertionError, match="Weight must be float32"):
+ layernorm(x, weight_wrong_dtype)
+
+
+ def test_layernorm_compile_cache():
+ """Test that compile cache works correctly for repeated calls."""
+ device = "cuda"
+ M, N = 32, 1024
+ eps = 1e-6
+
+ # Clear cache
+ layernorm.compile_cache.clear()
+ assert len(layernorm.compile_cache) == 0
+
+ x1 = torch.randn(M, N, device=device, dtype=torch.float16)
+ weight1 = torch.randn(N, device=device, dtype=torch.float32)
+
+ # First call should compile
+ out1 = layernorm(x1, weight1, eps=eps)
+ assert len(layernorm.compile_cache) == 1
+
+ # Same shape should reuse cache
+ x2 = torch.randn(M, N, device=device, dtype=torch.float16)
+ weight2 = torch.randn(N, device=device, dtype=torch.float32)
+ out2 = layernorm(x2, weight2, eps=eps)
+ assert len(layernorm.compile_cache) == 1
+
+ # Different shape should create new cache entry
+ x3 = torch.randn(M, N * 2, device=device, dtype=torch.float16)
+ weight3 = torch.randn(N * 2, device=device, dtype=torch.float32)
+ out3 = layernorm(x3, weight3, eps=eps)
+ assert len(layernorm.compile_cache) == 2
+
+ # Different dtype should create new cache entry
+ x4 = torch.randn(M, N, device=device, dtype=torch.float32)
+ weight4 = torch.randn(N, device=device, dtype=torch.float32)
+ out4 = layernorm(x4, weight4, eps=eps)
+ assert len(layernorm.compile_cache) == 3
@@ -32,19 +32,17 @@ def test_rmsnorm_forward(M, N, input_dtype, eps):
  weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
  x_ref = x.detach().clone().requires_grad_()
  weight_ref = weight.detach().clone().requires_grad_()
- out, rstd = rmsnorm(x, weight, eps=eps, return_rstd=True)
+ out = rmsnorm(x, weight, eps=eps)
  out_ref = rmsnorm_ref(x_ref, weight_ref, eps=eps)
- rstd_ref_val = rstd_ref(x_ref, eps=eps)
+ # rstd_ref_val = rstd_ref(x_ref, eps=eps)

  # Check output shape and dtype
  assert out.shape == x.shape
  assert out.dtype == input_dtype
- assert rstd.shape == (M,)
- assert rstd.dtype == torch.float32

  # Check accuracy
  torch.testing.assert_close(out, out_ref, atol=atol, rtol=1e-3)
- torch.testing.assert_close(rstd, rstd_ref_val, atol=atol, rtol=1e-3)
+ # torch.testing.assert_close(rstd, rstd_ref_val, atol=atol, rtol=1e-3)


  # @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@@ -88,6 +86,39 @@ def test_rmsnorm_forward(M, N, input_dtype, eps):
  # torch.testing.assert_close(weight.grad, weight_ref.grad, atol=atol, rtol=1e-3)


+ @pytest.mark.parametrize("eps", [1e-5])
+ @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
+ @pytest.mark.parametrize(
+ "N",
+ [131072, 262144]
+ # [262144]
+ )
+ @pytest.mark.parametrize("M", [32 * 1024])
+ def test_rmsnorm_large_tensor(M, N, input_dtype, eps):
+ """Test RMSNorm forward pass against reference implementation."""
+ device = "cuda"
+ # Set tolerance based on dtype
+ if input_dtype == torch.bfloat16:
+ atol = 1e-1
+ elif input_dtype == torch.float16:
+ atol = 1e-2
+ else:
+ atol = 1e-4
+ torch.random.manual_seed(0)
+ torch.cuda.empty_cache()
+ x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=False)
+ weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=False)
+ out = rmsnorm(x, weight, eps=eps)
+ # Need to compile, otherwise it OOMs
+ rmsnorm_compiled = torch.compile(rmsnorm_ref)
+ # Run once with smaller input to avoid OOMs
+ rmsnorm_compiled(x[:32], weight, eps=eps)
+ out_ref = rmsnorm_compiled(x, weight, eps=eps)
+ # Need to chunk, otherwise it OOMs
+ assert all((out_c - out_ref_c).abs().max() < atol
+ for out_c, out_ref_c in zip(out.chunk(16), out_ref.chunk(16)))
+
+
  @pytest.mark.parametrize("return_rstd", [True, False])
  def test_rmsnorm_return_rstd_option(return_rstd):
  """Test that return_rstd option works correctly."""
@@ -34,7 +34,7 @@ def test_softmax(M, N, input_dtype):

  torch.random.manual_seed(0)
  # Create input tensors (scale down to avoid overflow in softmax)
- x = 0.1 * torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
+ x = (0.1 * torch.randn(M, N, device=device, dtype=input_dtype)).requires_grad_()
  x_ref = x.detach().clone().requires_grad_(True)

  # Forward pass
@@ -58,13 +58,12 @@ def test_softmax(M, N, input_dtype):

  # Test backward pass
  dy = torch.randn_like(out)
+ torch.cuda.synchronize() # without sync, torch.autograd gets wrong results
  dx_ref, = torch.autograd.grad(out_ref, x_ref, grad_outputs=dy)
  # Call our implementation later, otherwise getting CUDA_ERROR_INVALID_CONTEXT
  dx, = torch.autograd.grad(out, x, grad_outputs=dy)
- # Check output shape and dtype
  assert dx.shape == dy.shape
  assert dx.dtype == input_dtype
- # Check accuracy against reference
  torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol)
