quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/topk.py ADDED
@@ -0,0 +1,227 @@
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
+
+ import math
+ import torch
+ from typing import Type
+
+ import cuda.bindings.driver as cuda
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.runtime import from_dlpack
+ from cutlass import const_expr
+
+ import quack.utils as utils
+ from quack.cute_dsl_utils import torch2cute_dtype_map
+ from quack.sort.bitonic_sort import bitonic_topk
+
+
+ class TopK:
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int):
+         self.dtype = dtype
+         self.N = N
+         self.vecsize = 128 // dtype.width
+         self.k = k
+         assert N == 2 ** int(math.log2(N)), "N must be a power of 2"
+         assert k == 2 ** int(math.log2(k)), "k must be a power of 2"
+         assert k <= 128
+         assert N <= 4096
+
+     def _calculate_threads_per_row(self):
+         # we want num_elems_per_thread >= self.k
+         # and each thread can handle at most 64 elements
+         N = self.N
+         num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
+         return num_threads_per_row
+
+     def _get_tv_layout(self):
+         N = self.N
+         vecsize = self.vecsize
+         num_threads = 128 if N <= 16384 else 256
+         threads_per_row = self._calculate_threads_per_row()
+         cols_per_block = num_threads // threads_per_row
+         num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+         tv_layout = cute.make_layout(
+             ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+             stride=(
+                 (vecsize * cols_per_block, 1),
+                 (cols_per_block, cols_per_block * vecsize * threads_per_row),
+             ),
+         )
+         return tiler_mn, tv_layout
+
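For a concrete sense of the tile shapes this arithmetic produces, here is a hypothetical by-hand evaluation of _calculate_threads_per_row and _get_tv_layout for N=1024, k=8 and float16 inputs, using plain Python in place of the CUTE helpers (the names below simply mirror the methods above):

    # Hypothetical check of the tiling arithmetic for N=1024, k=8, float16.
    N, k, dtype_width = 1024, 8, 16
    vecsize = 128 // dtype_width                        # 8 elements per 128-bit copy
    threads_per_row = max(min(N // k, 32, N // 64), 1)  # 16 threads cooperate on one row
    num_threads = 128 if N <= 16384 else 256            # 128 threads per CTA
    cols_per_block = num_threads // threads_per_row     # 8 rows handled per CTA
    num_blocks_N = (min(N, 16384) // vecsize + threads_per_row - 1) // threads_per_row  # ceil_div -> 8
    tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
    print(tiler_mn)  # (8, 1024): each CTA covers 8 rows and the full 1024 columns

So each of the 16 threads in a row owns 64 elements, comfortably above k=8, which is what the comment in _calculate_threads_per_row is aiming for.
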
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mValues: cute.Tensor,
+         mIndices: cute.Tensor,
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         assert mValues.element_type == self.dtype
+         assert mIndices.element_type == cutlass.Int32
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         self.kernel(mX, mValues, mIndices, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1],
+             block=[num_threads, 1, 1],
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,
+         mValues: cute.Tensor,
+         mIndices: cute.Tensor,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, _, _ = cute.arch.block_idx()
+
+         shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+         mX = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+         gX = cute.local_tile(mX, tiler_mn, (0, 0))
+         cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+         )
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+         tXgX = thr_copy_X.partition_S(gX)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+         # allocate fragments for gmem->rmem
+         tXrX = cute.make_fragment_like(tXgX)
+
+         is_even_N = const_expr(shape[1] == tiler_mn[1])
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+         )
+         if tXcX[0][0] < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+         tXrX_f32 = cute.make_fragment(tXrX.shape, cutlass.Float32)
+         tXrX_f32.store(tXrX.load().to(cutlass.Float32))
+
+         # Encode the indices into the bottom bits of values.
+         log_N = int(math.log2(self.N))
+         idx_mask = (1 << log_N) - 1
+         vecsize = cutlass.const_expr(tv_layout.shape[1][0])
+         tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32)
+         # Encode indices into the last log_N bits of tXrX_u32
+         for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True):
+             # tXcX only keeps track of the indices for every @vecsize elements
+             col_idx = cutlass.Uint32(tXcX[i // vecsize][1] + i % vecsize)
+             # If positive, invert the bits of the index, so that if there's a tie,
+             # indices coming from an earlier column will win.
+             encoded_idx = ~col_idx if tXrX_f32[i] >= 0 else col_idx
+             # Mask to keep only the last log_N bits of the encoded index
+             encoded_idx = encoded_idx & idx_mask
+             # Clear the last log_N bits and set them to our encoded index
+             tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx
+
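The loop above packs each element's column index into the low log2(N) mantissa bits of its float32 copy, so the value-only bitonic sort below carries the indices along for free; inverting the index bits for non-negative values means that, on ties, the earlier column wins. A standalone, hypothetical round-trip of that packing in plain NumPy (no CUTE involved; the array names are illustrative only):

    import math
    import numpy as np

    N = 1024
    log_N = int(math.log2(N))
    idx_mask = (1 << log_N) - 1                    # low log2(N) bits hold the index
    vals = np.random.randn(N).astype(np.float32)
    bits = vals.view(np.uint32).copy()
    idx = np.arange(N, dtype=np.uint32)
    # Encode: non-negative values get bit-inverted indices so earlier columns win ties.
    enc = np.where(vals >= 0, ~idx, idx) & np.uint32(idx_mask)
    packed = ((bits & ~np.uint32(idx_mask)) | enc).view(np.float32)
    # Decode: clear the low bits, then undo the inversion for non-negative values.
    pbits = packed.view(np.uint32)
    clean = (pbits & ~np.uint32(idx_mask)).view(np.float32)
    low = pbits & np.uint32(idx_mask)
    dec_idx = np.where(clean >= 0, ~low, low) & np.uint32(idx_mask)
    assert (dec_idx == idx).all()  # indices survive; values lose only their low mantissa bits
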
+         # Fill OOB values with -inf for top-k
+         if const_expr(not is_even_N):
+             utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf)
+
+         threads_per_row = tv_layout.shape[0][0]
+         topk_vals = bitonic_topk(tXrX_f32, self.k, warp_width=threads_per_row)
+
+         # Extract indices and clean values
+         topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32)
+         topk_indices = cute.make_fragment(self.k, cutlass.Int32)
+         for i in cutlass.range(self.k):
+             # Extract the encoded index from the last log_N bits
+             encoded_idx = topk_vals_u32[i] & idx_mask
+             # Check if the original value was positive by looking at the cleaned value
+             topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask  # Clear last log_N bits
+             # If positive, we need to invert the bits back to get the original index
+             col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
+             topk_indices[i] = cutlass.Int32(col_idx & idx_mask)
+
+         # Convert cleaned values to output type
+         topk_vals_out = cute.make_fragment_like(topk_vals, mValues.element_type)
+         topk_vals_out.store(topk_vals.load().to(mValues.element_type))
+
+         row = tXcX[0][0]
+         # Only the 1st thread in this row writes the top-k values and indices
+         if row < shape[0] and tXcX[0][1] == 0:
+             # for i in cutlass.range(self.k):
+             #     mValues[row, i] = topk_vals_out[i]
+             #     mIndices[row, i] = topk_indices[i]
+             # Vectorized write
+             elems_per_store = const_expr(math.gcd(vecsize, self.k))
+             mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,))
+             mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,))
+             topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,))
+             topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,))
+             for i in cutlass.range(cute.size(topk_vals_out_store.shape, [1]), unroll_full=True):
+                 cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
+                 cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
+
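The store width above, gcd(vecsize, self.k), is the widest chunk that divides both the vector width (in elements) and k, so the unrolled autovec_copy loop always writes full, equally sized chunks regardless of how k compares to the vector width. A quick hypothetical check of that arithmetic in plain Python:

    from math import gcd

    vecsize, k = 8, 32                 # e.g. fp16 values: 8 elements per 128-bit store
    elems_per_store = gcd(vecsize, k)  # 8 -> four full-width stores per row
    print(k // elems_per_store, "stores of", elems_per_store, "elements")

    vecsize, k = 8, 4                  # k narrower than the vector width
    print(gcd(vecsize, k))             # 4 -> a single 4-element store
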
+
+
+ @torch.library.custom_op("quack::_topk_fwd", mutates_args={"values", "indices"})
+ def _topk_fwd(x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor) -> None:
+     """Top-k forward pass.
+     Args:
+         x: Input tensor of shape (M, N)
+         k: Number of top elements to return
+     Returns:
+         None. The top-k values and indices (each of shape (M, k)) are written into the preallocated ``values`` and ``indices`` tensors.
+     """
+     assert x.dim() == 2, "Input must be 2D"
+     assert x.is_cuda, "Tensor must be on CUDA device"
+     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+     assert k > 0 and k <= x.shape[1], "k must be positive and <= N"
+
+     N = x.size(1)
+
+     dtype = torch2cute_dtype_map[x.dtype]
+     convert_from_dlpack = lambda tensor: (
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
+     )
+
+     x_tensor, values_tensor, indices_tensor = [
+         convert_from_dlpack(tensor) for tensor in (x, values, indices)
+     ]
+     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+     compile_key = (dtype, N, k)
+     if compile_key not in _topk_fwd.compile_cache:
+         topk_op = TopK(dtype, N, k)
+         _topk_fwd.compile_cache[compile_key] = cute.compile(
+             topk_op, x_tensor, values_tensor, indices_tensor, current_stream
+         )
+     _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)
+
+
+ _topk_fwd.compile_cache = {}
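The module-level compile_cache means the CUTE kernel is compiled once per (dtype, N, k) combination and reused afterwards; only the batch dimension M is left dynamic (mark_compact_shape_dynamic(mode=0)). A hypothetical illustration of that behaviour, using the topk wrapper defined just below:

    x1 = torch.randn(4, 512, device="cuda", dtype=torch.float16)
    x2 = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
    topk(x1, 8)          # first (float16, 512, 8) call: compiles the kernel
    topk(x2, 8)          # same key, different M: reuses the cached compilation
    topk(x1.float(), 8)  # different dtype: triggers a new compilation
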
+
+
+ def topk(x: torch.Tensor, k: int):
+     """Top-k operation.
+
+     Args:
+         x: Input tensor of shape (M, N)
+         k: Number of top elements to return
+
+     Returns:
+         Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
+     """
+
+     M = x.size(0)
+
+     values = torch.empty((M, k), dtype=x.dtype, device=x.device)
+     indices = torch.empty((M, k), dtype=torch.int32, device=x.device)
+
+     _topk_fwd(x, k, values, indices)
+
+     return values, indices
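
A minimal usage sketch of the new entry point, assuming a CUDA device, that the module is importable as quack.topk (the file added above), and inputs satisfying the constraints asserted earlier (N and k powers of two, k <= 128, N <= 4096); the comparison sorts both results because tie-breaking may order indices differently from torch.topk:

    import torch
    from quack.topk import topk

    x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
    vals, idx = topk(x, k=32)     # vals: (8, 32) bfloat16, idx: (8, 32) int32
    ref_vals, _ = torch.topk(x, k=32, dim=-1)
    # Compare the selected values irrespective of ordering conventions.
    print(torch.equal(vals.sort(dim=-1, descending=True).values,
                      ref_vals.sort(dim=-1, descending=True).values))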