quack-kernels 0.1.2__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. quack_kernels-0.1.4/PKG-INFO +11 -0
  2. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/README.md +14 -1
  3. quack_kernels-0.1.4/pyproject.toml +33 -0
  4. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/quack/__init__.py +7 -1
  5. quack_kernels-0.1.4/quack/cross_entropy.py +255 -0
  6. quack_kernels-0.1.4/quack/reduction_base.py +98 -0
  7. quack_kernels-0.1.4/quack/rmsnorm.py +285 -0
  8. quack_kernels-0.1.4/quack/softmax.py +456 -0
  9. quack_kernels-0.1.4/quack/utils.py +407 -0
  10. quack_kernels-0.1.4/quack_kernels.egg-info/PKG-INFO +11 -0
  11. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/quack_kernels.egg-info/SOURCES.txt +1 -0
  12. quack_kernels-0.1.4/quack_kernels.egg-info/requires.txt +6 -0
  13. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/tests/test_cross_entropy.py +3 -3
  14. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/tests/test_softmax.py +75 -16
  15. quack_kernels-0.1.2/PKG-INFO +0 -8
  16. quack_kernels-0.1.2/pyproject.toml +0 -18
  17. quack_kernels-0.1.2/quack/cross_entropy.py +0 -221
  18. quack_kernels-0.1.2/quack/rmsnorm.py +0 -254
  19. quack_kernels-0.1.2/quack/softmax.py +0 -195
  20. quack_kernels-0.1.2/quack/utils.py +0 -246
  21. quack_kernels-0.1.2/quack_kernels.egg-info/PKG-INFO +0 -8
  22. quack_kernels-0.1.2/quack_kernels.egg-info/requires.txt +0 -2
  23. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/LICENSE +0 -0
  24. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/quack_kernels.egg-info/dependency_links.txt +0 -0
  25. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/quack_kernels.egg-info/top_level.txt +0 -0
  26. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/setup.cfg +0 -0
  27. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/setup.py +0 -0
  28. {quack_kernels-0.1.2 → quack_kernels-0.1.4}/tests/test_rmsnorm.py +0 -0
quack_kernels-0.1.4/PKG-INFO
@@ -0,0 +1,11 @@
+ Metadata-Version: 2.4
+ Name: quack-kernels
+ Version: 0.1.4
+ Requires-Python: >=3.9
+ License-File: LICENSE
+ Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
+ Requires-Dist: torch
+ Provides-Extra: dev
+ Requires-Dist: pre-commit; extra == "dev"
+ Requires-Dist: ruff; extra == "dev"
+ Dynamic: license-file
{quack_kernels-0.1.2 → quack_kernels-0.1.4}/README.md
@@ -17,12 +17,25 @@ pip install quack-kernels
  ## Kernels 🐥
 
  - 🦆 RMSNorm forward
- - 🦆 Softmax forward
+ - 🦆 Softmax forward and backward
  - 🦆 Cross entropy forward
 
+ Upcoming:
+ - 🦆 Cross entropy backward
+ - 🦆 RMSNorm backward
+ - 🦆 Rotary forward + backward
 
  ## Usage
 
  ```
  from quack import rmsnorm, softmax, cross_entropy
  ```
+
+ ## Development
+
+ To set up the development environment:
+
+ ```bash
+ pip install -e '.[dev]'
+ pre-commit install
+ ```
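For context, a minimal usage sketch of the kernels the README lists. Only the `cross_entropy` signature is visible later in this diff; the `rmsnorm` and `softmax` call signatures below are assumptions, since those files' new contents are not reproduced in this excerpt.

```python
# Hedged usage sketch for quack-kernels 0.1.4; requires a CUDA device.
# Assumption: rmsnorm(x, weight) and softmax(x) signatures are inferred, not shown in this diff.
import torch
from quack import rmsnorm, softmax, cross_entropy

x = torch.randn(1024, 8192, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(8192, device="cuda", dtype=torch.bfloat16)  # assumed per-feature scale
target = torch.randint(0, 8192, (1024,), device="cuda")

y = rmsnorm(x, weight)           # assumed: RMSNorm over the last dimension
p = softmax(x)                   # assumed: row-wise softmax
loss = cross_entropy(x, target)  # per-row loss, shape (1024,), float32 (per the diff below)
```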
quack_kernels-0.1.4/pyproject.toml
@@ -0,0 +1,33 @@
+ [build-system]
+ requires = ["setuptools"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "quack-kernels"
+ dynamic = ["version"]
+ requires-python = ">=3.9"
+ dependencies = [
+     "nvidia-cutlass-dsl==4.1.0.dev0",
+     "torch",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pre-commit",
+     "ruff",
+ ]
+
+ [tool.setuptools.packages.find]
+ exclude = ["tests", "benchmarks"]
+
+ [tool.setuptools.dynamic]
+ version = {attr = "quack.__version__"}
+
+ [tool.ruff]
+ line-length = 100
+
+ [tool.ruff.lint]
+ ignore = [
+     "E731",  # do not assign a lambda expression, use a def
+     "F841",  # local variable is assigned to but never used
+ ]
{quack_kernels-0.1.2 → quack_kernels-0.1.4}/quack/__init__.py
@@ -1,5 +1,11 @@
- __version__ = "0.1.2"
+ __version__ = "0.1.4"
 
  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
  from quack.cross_entropy import cross_entropy
+
+ __all__ = [
+     "rmsnorm",
+     "softmax",
+     "cross_entropy",
+ ]
quack_kernels-0.1.4/quack/cross_entropy.py
@@ -0,0 +1,255 @@
+ import math
+ import torch
+ from typing import Optional, Type
+
+ import cuda.bindings.driver as cuda
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.runtime import from_dlpack
+
+ import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+ class CrossEntropy(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
+         # 2 stages: 1 for max, 1 for sum
+         super().__init__(
+             dtype,
+             N,
+             stage=2 if not online_softmax else 1,
+             reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+         )
+         self.online_softmax = online_softmax
+         self.reload_from = None if N <= 16384 or online_softmax else "smem"
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )
+
+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else:  # fp32
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mTarget: cute.Tensor,
+         mLoss: cute.Tensor,
+         mLSE: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,  # (M, N)
+         mTarget: cute.Tensor,  # (M,)
+         mLoss: cute.Tensor,  # (M,)
+         mLSE: Optional[cute.Tensor],  # (M,)
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, _, _ = cute.arch.block_idx()
+         if cutlass.const_expr(self.cluster_n > 1):
+             cluster_y = cute.arch.block_idx()[1]
+         else:
+             cluster_y = cutlass.const_expr(0)
+
+         shape: cute.Shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+         )
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+
+         #### Thread View
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+         tXrX = cute.make_fragment_like(tXgX)
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         row = tXcX[0][0]
+         target = cute.Int32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target = cute.Int32(mTarget[row])
+
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+             if cutlass.const_expr(not is_even_N)
+             else None
+         )
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         # Fill OOB values with -inf
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+
+         target_logit = cute.Float32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target_logit = cute.Float32(mX[row, target])
+
+         threads_per_row = tv_layout.shape[0][0]
+         if cutlass.const_expr(not self.online_softmax):
+             max_x = utils.row_reduce(
+                 x,
+                 cute.ReductionOp.MAX,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                 init_val=-cutlass.Float32.inf,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+             if cutlass.const_expr(self.reload_from == "smem"):
+                 cute.autovec_copy(tXsX, tXrX)
+                 x = tXrX.load().to(cute.Float32)
+             log2_e = math.log2(math.e)
+             # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+             # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
+             # exp_x = utils.exp2f((x - max_x) * log2_e)
+             # This would use ffma instead of fadd then fmul
+             exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+             denom = utils.row_reduce(
+                 exp_x,
+                 cute.ReductionOp.ADD,
+                 threads_per_row,
+                 reduction_buffer[None, None, 1],
+                 mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+                 init_val=0.0,
+             )
+         else:
+             max_x, denom, _ = utils.online_softmax_reduce(
+                 x,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+
+         if (
+             tXcX[0][1] == 0
+             and row < shape[0]
+             and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+         ):
+             ln_2 = math.log(2.0)
+             lse = max_x + utils.log2f(denom) * ln_2
+             loss_val = lse - target_logit
+             mLoss[row] = loss_val.to(mLoss.element_type)
+             if cutlass.const_expr(mLSE is not None):
+                 mLSE[row] = lse
+
+
+ def cross_entropy(
+     x: torch.Tensor,
+     target: torch.Tensor,
+     return_lse: bool = False,
+ ) -> torch.Tensor:
+     """Cross entropy forward pass.
+
+     Args:
+         x: Input logits tensor of shape (M, N)
+         target: Target class indices tensor of shape (M,)
+
+     Returns:
+         Cross entropy loss tensor of shape (M,)
+     """
+     assert x.dim() == 2, "Input must be 2D"
+     assert target.dim() == 1, "Target must be 1D"
+     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+     assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
+     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+     M, N = x.shape
+     device = x.device
+     loss = torch.empty(M, device=device, dtype=torch.float32)
+     lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
+     dtype = torch2cute_dtype_map[x.dtype]
+     convert_from_dlpack = lambda tensor: (
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
+     )
+     x_tensor = convert_from_dlpack(x)
+     loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+     lse_tensor = (
+         from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+         if lse is not None
+         else None
+     )
+     target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
+     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+     compile_key = (dtype, N, lse is not None)
+     if compile_key not in cross_entropy.compile_cache:
+         cross_entropy_op = CrossEntropy(dtype, N)
+         cross_entropy.compile_cache[compile_key] = cute.compile(
+             cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+         )
+     cross_entropy.compile_cache[compile_key](
+         x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+     )
+     return loss if not return_lse else (loss, lse)
+
+
+ cross_entropy.compile_cache = {}
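A hedged usage sketch of the `cross_entropy` wrapper above: the first call for a given `(dtype, N, return_lse)` key compiles the kernel via `cute.compile` and stores it in `cross_entropy.compile_cache`; later calls with the same key reuse the compiled object. The comparison against `torch.nn.functional.cross_entropy(..., reduction="none")` is an assumed sanity check, not part of the package.

```python
# Sketch: calling cross_entropy() and observing its compile cache.
import torch
import torch.nn.functional as F
from quack import cross_entropy

M, N = 2048, 32768
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
target = torch.randint(0, N, (M,), device="cuda")

loss, lse = cross_entropy(x, target, return_lse=True)  # compiles and caches (dtype, N, True)
loss_only = cross_entropy(x, target)                    # compiles a second variant (dtype, N, False)
print(len(cross_entropy.compile_cache))                 # 2

# Assumed check: per-row fp32 reference should be close to the kernel's output.
ref = F.cross_entropy(x.float(), target, reduction="none")
print(torch.allclose(loss, ref, rtol=1e-3, atol=1e-3))
```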
quack_kernels-0.1.4/quack/reduction_base.py
@@ -0,0 +1,98 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ import torch
+ from typing import Type, Tuple, Optional
+
+ import cutlass
+ import cutlass.cute as cute
+
+ import quack.utils as utils
+
+
+ torch2cute_dtype_map = {
+     torch.float16: cutlass.Float16,
+     torch.bfloat16: cutlass.BFloat16,
+     torch.float32: cutlass.Float32,
+ }
+
+
+ class ReductionBase:
+     def __init__(
+         self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
+     ):
+         self.dtype = dtype
+         self.N = N
+         self.stage = stage
+         self.reduction_dtype = reduction_dtype
+
+     def _calculate_threads_per_row(self):
+         raise NotImplementedError()
+
+     def _set_cluster_n(self):
+         self.cluster_n = 1
+
+     def _get_num_threads(self):
+         return 128 if self.N <= 16384 else 256
+
+     def _get_tv_layout(self):
+         copy_bits = 128
+         vecsize = copy_bits // self.dtype.width
+         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+         num_threads = self._get_num_threads()
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         assert num_threads % cute.arch.WARP_SIZE == 0
+
+         threads_per_row = self._calculate_threads_per_row()
+         num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
+         cols_per_block = num_threads // threads_per_row
+         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+         tv_layout = cute.make_layout(
+             ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+             stride=(
+                 (vecsize * cols_per_block, 1),
+                 (cols_per_block, cols_per_block * vecsize * threads_per_row),
+             ),
+         )
+         return tiler_mn, tv_layout
+
+     def _smem_size_in_bytes(self, tiler_mn, num_warps):
+         return (
+             cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+             + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+             + self.stage * (cutlass.Int64.width // 8)
+         )
+
+     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+         return cute.make_ordered_layout(
+             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+             order=(1, 0, 2),
+         )
+
+     def _allocate_reduction_buffer_and_mbar(
+         self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+     ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
+         reduction_buffer = smem.allocate_tensor(
+             self.reduction_dtype,
+             self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
+             byte_alignment=4,
+         )
+         if cutlass.const_expr(self.cluster_n > 1):
+             mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+         else:
+             mbar_ptr = None
+         return reduction_buffer, mbar_ptr
+
+     @cute.jit
+     def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
+         if cutlass.const_expr(self.cluster_n > 1):
+             if tidx < self.stage:
+                 cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+             cute.arch.mbarrier_init_fence()
+             if tidx < self.stage:
+                 cute.arch.mbarrier_arrive_and_expect_tx(
+                     mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
+                 )
+             # Cluster arrive after barrier init
+             cute.arch.cluster_arrive_relaxed()
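To make the arithmetic in `_get_tv_layout` concrete, here is a worked sketch in plain Python for an assumed configuration (bf16 input, N = 8192, with the thresholds `CrossEntropy` picks above); it recomputes the same expressions with ordinary integers rather than calling into `cute`.

```python
# Worked sketch of ReductionBase._get_tv_layout numbers for bf16, N = 8192 (assumed example).
N = 8192
dtype_width = 16                        # bf16
copy_bits = 128
vecsize = copy_bits // dtype_width      # 8 elements per 128-bit cp.async copy
num_threads = 128                       # _get_num_threads: N <= 16384
threads_per_row = 128                   # CrossEntropy._calculate_threads_per_row: 6144 < N <= 16384
cluster_n = 1                           # CrossEntropy._set_cluster_n: 16-bit dtype, N <= 16 * 1024

num_blocks_N = -(-(N // vecsize) // (threads_per_row * cluster_n))      # ceil_div -> 8
cols_per_block = num_threads // threads_per_row                         # 1 row per CTA
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)   # (1, 8192)

# TV layout shape  ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)) = ((128, 1), (8, 8))
# TV layout stride ((vecsize * cols_per_block, 1), (cols_per_block, ...))       = ((8, 1), (1, 1024))
# i.e. each thread issues 8 vectorized loads of 8 contiguous elements, spaced 1024 elements apart.
print(vecsize, num_blocks_N, cols_per_block, tiler_mn)
```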