quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/topk.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import torch
|
|
5
|
+
from typing import Type
|
|
6
|
+
|
|
7
|
+
import cuda.bindings.driver as cuda
|
|
8
|
+
|
|
9
|
+
import cutlass
|
|
10
|
+
import cutlass.cute as cute
|
|
11
|
+
from cutlass.cute.runtime import from_dlpack
|
|
12
|
+
from cutlass import const_expr
|
|
13
|
+
|
|
14
|
+
import quack.utils as utils
|
|
15
|
+
from quack.cute_dsl_utils import torch2cute_dtype_map
|
|
16
|
+
from quack.sort.bitonic_sort import bitonic_topk
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TopK:
    """Row-wise top-k over a 2-D tensor of shape (M, N), implemented with CuTe DSL.

    Each row is loaded into registers and converted to float32. The column index of
    every element is then packed into the low ``log2(N)`` bits of its float32 bit
    pattern. ``bitonic_topk`` is called with ``warp_width=threads_per_row`` to select
    the top ``k`` packed values of the row. Finally the indices are unpacked from
    those low bits.

    The packing is arranged so that, among equal values, the smaller column index
    compares larger (ties go to the earlier column).

    Constraints (checked in ``__init__``):
        - ``N`` and ``k`` are powers of two.
        - ``k <= 128``.
        - ``vecsize <= N <= 4096``, where ``vecsize = 128 // dtype.width``.

    Note:
        Stripping the packed index bits zeroes the low ``log2(N)`` mantissa bits of
        the returned values. float16/bfloat16 values converted to float32 already
        have at least 13 zero low mantissa bits, and ``log2(N) <= 12``. Therefore
        only float32 outputs can be affected.
    """

    def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int):
        self.dtype = dtype
        self.N = N
        # Number of elements moved by one 128-bit vectorized load.
        self.vecsize = 128 // dtype.width
        self.k = k
        assert N == 2 ** int(math.log2(N)), "N must be a power of 2"
        assert k == 2 ** int(math.log2(k)), "k must be a power of 2"
        assert k <= 128, "k must be <= 128"
        assert N <= 4096, "N must be <= 4096"
        # A row must hold at least one full 128-bit vector. Otherwise
        # N // vecsize == 0 and _get_tv_layout builds a zero-width tile along N.
        assert N >= self.vecsize, f"N must be >= {self.vecsize} for this dtype"

    def _calculate_threads_per_row(self):
        # Number of threads that cooperate on one row. The value is capped at:
        #   - N // k, so every thread holds at least k elements;
        #   - N // 64, so every thread holds at least 64 elements when N >= 64;
        #   - 32 (a warp).
        # It is floored at 1.
        N = self.N
        return max(min(N // self.k, 32, N // 64), 1)

    def _get_tv_layout(self):
        # Build the CTA tiler and the thread-value layout used by the tiled copy.
        # Returns (tiler_mn, tv_layout):
        #   - tiler_mn = (rows per CTA, columns per row tile);
        #   - tv_layout maps (thread, value) coordinates into that tile.
        N = self.N
        vecsize = self.vecsize
        num_threads = 128 if N <= 16384 else 256
        threads_per_row = self._calculate_threads_per_row()
        # Rows processed by one CTA; this is tiler_mn[0].
        rows_per_block = num_threads // threads_per_row
        # Number of vectors each thread holds along N.
        num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
        tiler_mn = (rows_per_block, vecsize * num_blocks_N * threads_per_row)
        tv_layout = cute.make_layout(
            ((threads_per_row, rows_per_block), (vecsize, num_blocks_N)),
            stride=(
                (vecsize * rows_per_block, 1),
                (rows_per_block, rows_per_block * vecsize * threads_per_row),
            ),
        )
        return tiler_mn, tv_layout

    @cute.jit
    def __call__(
        self,
        mX: cute.Tensor,
        mValues: cute.Tensor,
        mIndices: cute.Tensor,
        stream: cuda.CUstream,
    ):
        # Host-side launcher.
        # The element-type asserts are evaluated when this function is traced.
        # One CTA covers tiler_mn[0] rows, so the grid is ceil_div(M, tiler_mn[0]).
        assert mX.element_type == self.dtype
        assert mValues.element_type == self.dtype
        assert mIndices.element_type == cutlass.Int32
        tiler_mn, tv_layout = self._get_tv_layout()
        num_threads = cute.size(tv_layout, mode=[0])
        self.kernel(mX, mValues, mIndices, tv_layout, tiler_mn).launch(
            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1],
            block=[num_threads, 1, 1],
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mX: cute.Tensor,
        mValues: cute.Tensor,
        mIndices: cute.Tensor,
        tv_layout: cute.Layout,
        tiler_mn: cute.Shape,
    ):
        # Device kernel. Each CTA handles tiler_mn[0] consecutive rows of mX and
        # writes k values and k indices per row into mValues and mIndices.
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()

        shape = mX.shape
        idX = cute.make_identity_tensor(shape)
        # Slice out this CTA's tile. domain_offset_i64 is used so that tensors with
        # more than 2^31 elements are handled.
        mX = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
        gX = cute.local_tile(mX, tiler_mn, (0, 0))
        cX = cute.local_tile(idX, tiler_mn, (bidx, 0))

        # 128-bit vectorized gmem -> rmem copy.
        copy_atom_load_X = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
        )
        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
        tXgX = thr_copy_X.partition_S(gX)
        # Coordinates of the first element of each vector owned by this thread.
        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]

        tXrX = cute.make_fragment_like(tXgX)

        is_even_N = const_expr(shape[1] == tiler_mn[1])
        tXpX = (
            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
        )
        # Rows past M (in the last CTA) are not loaded. Their results are never
        # stored below either.
        if tXcX[0][0] < shape[0]:
            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
        tXrX_f32 = cute.make_fragment(tXrX.shape, cutlass.Float32)
        tXrX_f32.store(tXrX.load().to(cutlass.Float32))

        # Pack each element's column index into the low log_N bits of its fp32 bits.
        log_N = int(math.log2(self.N))
        idx_mask = (1 << log_N) - 1
        vecsize = const_expr(tv_layout.shape[1][0])
        tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32)
        for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True):
            # tXcX only tracks the column of the first element of every vector.
            col_idx = cutlass.Uint32(tXcX[i // vecsize][1] + i % vecsize)
            # Non-negative floats: a larger low field means a larger value. Store the
            # inverted index so that ties go to the earlier column.
            # Negative floats: a larger low field means a more negative value. Store
            # the index unchanged.
            # The sign is taken from the sign bit rather than `>= 0`. The decode loop
            # below also reads the sign bit, so the two always agree. This matters
            # for -0.0 and for tiny negative values that turn into -0.0 once the
            # index bits are cleared.
            encoded_idx = ~col_idx if (tXrX_u32[i] >> 31) == 0 else col_idx
            encoded_idx = encoded_idx & idx_mask
            # Replace the low log_N bits with the encoded index.
            tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx

        # Out-of-bounds columns are set to -inf so they sort to the bottom.
        if const_expr(not is_even_N):
            utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf)

        threads_per_row = tv_layout.shape[0][0]
        topk_vals = bitonic_topk(tXrX_f32, self.k, warp_width=threads_per_row)

        # Unpack the indices and strip the index bits from the selected values.
        topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32)
        topk_indices = cute.make_fragment(self.k, cutlass.Int32)
        for i in cutlass.range(self.k, unroll_full=True):
            encoded_idx = topk_vals_u32[i] & idx_mask
            # Clearing the low log_N bits never touches the sign bit.
            topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask
            # Undo the inversion applied to non-negative values during encoding.
            col_idx = ~encoded_idx if (topk_vals_u32[i] >> 31) == 0 else encoded_idx
            topk_indices[i] = cutlass.Int32(col_idx & idx_mask)

        # Convert the cleaned float32 values to the output element type.
        topk_vals_out = cute.make_fragment_like(topk_vals, mValues.element_type)
        topk_vals_out.store(topk_vals.load().to(mValues.element_type))

        row = tXcX[0][0]
        # Only the thread that owns column 0 of the row writes the k results.
        if row < shape[0] and tXcX[0][1] == 0:
            # Vectorized equivalent of, for i in range(k):
            #   mValues[row, i] = topk_vals_out[i]
            #   mIndices[row, i] = topk_indices[i]
            elems_per_store = const_expr(math.gcd(vecsize, self.k))
            mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,))
            mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,))
            topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,))
            topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,))
            for i in cutlass.range(cute.size(topk_vals_out_store.shape, [1]), unroll_full=True):
                cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
                cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
@torch.library.custom_op("quack::_topk_fwd", mutates_args={"values", "indices"})
def _topk_fwd(x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor) -> None:
    """Top-k forward pass that writes its results into preallocated buffers.

    Args:
        x: Input tensor of shape (M, N) on a CUDA device, with dtype float16,
            bfloat16 or float32.
        k: Number of top elements to compute per row; must satisfy 0 < k <= N.
        values: Output buffer of shape (M, k) with the same dtype as ``x``.
            Mutated in place.
        indices: Output buffer of shape (M, k) with dtype int32. Mutated in place.

    Returns:
        None. The results are written into ``values`` and ``indices``.

    Raises:
        AssertionError: If any of the requirements above is violated.
    """
    assert x.dim() == 2, "Input must be 2D"
    assert x.is_cuda, "Tensor must be on CUDA device"
    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
    assert k > 0 and k <= x.shape[1], "k must be positive and <= N"

    M, N = x.shape
    # The compile cache is keyed only on (dtype, N, k). The output-dtype asserts in
    # TopK.__call__ run only when the kernel is first compiled. Validate the output
    # buffers on every call, so a mismatched buffer is never handed to a kernel that
    # was compiled for different outputs.
    assert values.shape == (M, k), "values must have shape (M, k)"
    assert values.dtype == x.dtype, "values must have the same dtype as x"
    assert indices.shape == (M, k), "indices must have shape (M, k)"
    assert indices.dtype == torch.int32, "indices must be int32"
    assert values.device == x.device and indices.device == x.device, (
        "values and indices must be on the same device as x"
    )

    dtype = torch2cute_dtype_map[x.dtype]

    def convert_from_dlpack(tensor: torch.Tensor):
        # Wrap the tensor as a CuTe tensor. Mode 0 (rows) is marked dynamic so the
        # cached kernel, whose key does not include M, works for any row count.
        return from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
            mode=0, stride_order=(0, 1)
        )

    x_tensor, values_tensor, indices_tensor = [
        convert_from_dlpack(tensor) for tensor in (x, values, indices)
    ]
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    compile_key = (dtype, N, k)
    if compile_key not in _topk_fwd.compile_cache:
        topk_op = TopK(dtype, N, k)
        _topk_fwd.compile_cache[compile_key] = cute.compile(
            topk_op, x_tensor, values_tensor, indices_tensor, current_stream
        )
    _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)


# Compiled kernels keyed by (cute dtype, N, k); filled lazily by _topk_fwd.
_topk_fwd.compile_cache = {}
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def topk(x: torch.Tensor, k: int):
    """Compute the ``k`` largest entries of every row of ``x``.

    Args:
        x: Input tensor of shape (M, N). Device and dtype requirements are
            enforced by ``_topk_fwd``.
        k: How many of the largest elements to keep per row.

    Returns:
        A ``(values, indices)`` pair:
            - ``values`` has shape (M, k) and the dtype of ``x``;
            - ``indices`` has shape (M, k) and dtype int32.
    """
    out_shape = (x.size(0), k)
    device = x.device
    topk_values = torch.empty(out_shape, dtype=x.dtype, device=device)
    topk_indices = torch.empty(out_shape, dtype=torch.int32, device=device)
    # The custom op fills both freshly allocated buffers in place.
    _topk_fwd(x, k, topk_values, topk_indices)
    return topk_values, topk_indices
|