quack-kernels 0.2.2-py3-none-any.whl → 0.2.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
- quack_kernels-0.2.4.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/topk.py
CHANGED
```diff
@@ -1,55 +1,57 @@
 # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
 
 import math
+from functools import partial
+from typing import Type, Optional
+
 import torch
-from typing import Type
 
 import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
-from cutlass
-from cutlass import const_expr
+from cutlass import Int32, Float32, const_expr
 
 import quack.utils as utils
+import quack.copy_utils as copy_utils
+from quack.compile_utils import make_fake_tensor as fake_tensor
+from quack.reduction_base import ReductionBase
+from quack.reduce import row_reduce
 from quack.cute_dsl_utils import torch2cute_dtype_map
 from quack.sort.bitonic_sort import bitonic_topk
 
 
 class TopK:
-    def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int):
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int, softmax: bool = False):
         self.dtype = dtype
         self.N = N
         self.vecsize = 128 // dtype.width
         self.k = k
+        self.softmax = softmax
         assert N == 2 ** int(math.log2(N)), "N must be a power of 2"
         assert k == 2 ** int(math.log2(k)), "N must be a power of 2"
         assert k <= 128
         assert N <= 4096
 
-    def
+    def _threads_per_row(self):
         # we want num_elems_per_thread >= self.k
         # and each thread can handle at most 64 elements
         N = self.N
         num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
         return num_threads_per_row
 
-    def
+    def _get_tiled_copy(self):
         N = self.N
         vecsize = self.vecsize
         num_threads = 128 if N <= 16384 else 256
-        threads_per_row = self.
+        threads_per_row = self._threads_per_row()
         cols_per_block = num_threads // threads_per_row
         num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
-
-
-            stride=(
-                (vecsize * cols_per_block, 1),
-                (cols_per_block, cols_per_block * vecsize * threads_per_row),
-            ),
+        tiled_copy = copy_utils.tiled_copy_2d(
+            self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize
         )
-        return tiler_mn,
+        return tiled_copy, tiler_mn, threads_per_row
 
     @cute.jit
     def __call__(
@@ -61,10 +63,10 @@ class TopK:
     ):
         assert mX.element_type == self.dtype
         assert mValues.element_type == self.dtype
-        assert mIndices.element_type ==
-        tiler_mn,
-        num_threads =
-        self.kernel(mX, mValues, mIndices,
+        assert mIndices.element_type == Int32
+        tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy()
+        num_threads = tiled_copy.size
+        self.kernel(mX, mValues, mIndices, tiler_mn, tiled_copy, threads_per_row).launch(
            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1],
             block=[num_threads, 1, 1],
             stream=stream,
@@ -76,103 +78,151 @@ class TopK:
         mX: cute.Tensor,
         mValues: cute.Tensor,
         mIndices: cute.Tensor,
-        tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
+        tiled_copy: cute.TiledCopy,
+        threads_per_row: cutlass.Constexpr[int],
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
+        tv_layout = tiled_copy.layout_tv_tiled
 
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
-
-
-        cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
-
-        # declare the atoms which will be used later for memory copy
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
-        )
-        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-        tXgX = thr_copy_X.partition_S(gX)
-        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+        gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, 0)) for mT in (mX, idX)]
+
+        thr_copy = tiled_copy.get_slice(tidx)
 
-
+        tXgX = thr_copy.partition_S(gX)
+        tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
         tXrX = cute.make_fragment_like(tXgX)
 
         is_even_N = const_expr(shape[1] == tiler_mn[1])
         tXpX = (
-
+            None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
         )
+        copy = partial(copy_utils.copy, pred=tXpX)
+
         if tXcX[0][0] < shape[0]:
-
-            tXrX_f32 = cute.make_fragment(tXrX.shape,
-            tXrX_f32.store(tXrX.load().to(
+            copy(tXgX, tXrX)
+        tXrX_f32 = cute.make_fragment(tXrX.shape, Float32)
+        tXrX_f32.store(tXrX.load().to(Float32))
 
         # Encode the indices into the bottom bits of values.
         log_N = int(math.log2(self.N))
         idx_mask = (1 << log_N) - 1
-        vecsize =
-
-        # Encode indices into the last log_N bits of
-        for i in cutlass.range(cute.size(
+        vecsize = const_expr(cute.size(tv_layout.shape[1]))
+        tXrX_i32 = cute.recast_tensor(tXrX_f32, Int32)
+        # Encode indices into the last log_N bits of tXrX_i32
+        for i in cutlass.range(cute.size(tXrX_i32), unroll_full=True):
             # tXcX only keeps track of the indices for every @vecsize elements
-            col_idx =
+            col_idx = Int32(tXcX[i // vecsize][1] + i % vecsize)
             # If positive, invert the bits of the index, so that if there's a tie,
             # indices coming from a earlier column will win.
             encoded_idx = ~col_idx if tXrX_f32[i] >= 0 else col_idx
             # Mask to keep only the last log_N bits of the encoded index
             encoded_idx = encoded_idx & idx_mask
             # Clear the last log_N bits and set them to our encoded index
-
+            tXrX_i32[i] = (tXrX_i32[i] & ~idx_mask) | encoded_idx
 
         # Fill OOB values with -inf for top-k
         if const_expr(not is_even_N):
             utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf)
 
-        threads_per_row = tv_layout.shape[0][0]
         topk_vals = bitonic_topk(tXrX_f32, self.k, warp_width=threads_per_row)
 
+        # Thread 0 in each row contains all the top-k values, so we split those into multiple threads
+        vecsize_out = const_expr(min(self.k, vecsize, 128 // mIndices.element_type.width))
+        assert self.k % vecsize_out == 0
+        nvec_per_thread = const_expr(cute.ceil_div(self.k, vecsize_out * threads_per_row))
+        # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+        mask = cute.arch.WARP_SIZE - threads_per_row
+        mask_and_clamp = mask << 8 | (cute.arch.WARP_SIZE - 1)
+        topk_vals_split = cute.make_fragment((vecsize_out, nvec_per_thread), Float32)
+        for i in cutlass.range(cute.ceil_div(self.k, vecsize_out), unroll_full=True):
+            should_receive = tidx % threads_per_row == i % threads_per_row
+            for v in cutlass.range(vecsize_out, unroll_full=True):
+                if const_expr(threads_per_row > 1):
+                    if i * vecsize_out + v < self.k:
+                        val = cute.arch.shuffle_sync(
+                            topk_vals[i * vecsize_out + v], offset=0, mask_and_clamp=mask_and_clamp
+                        )
+                        if should_receive:
+                            topk_vals_split[v, i // threads_per_row] = val
+                else:
+                    topk_vals_split[v, i // threads_per_row] = topk_vals[i * vecsize_out + v]
+
         # Extract indices and clean values
-
-        topk_indices = cute.make_fragment(
-        for i in cutlass.range(
+        topk_vals_i32 = cute.recast_tensor(topk_vals_split, Int32)
+        topk_indices = cute.make_fragment(topk_vals_i32.shape, Int32)
+        for i in cutlass.range(cute.size(topk_vals_i32), unroll_full=True):
            # Extract the encoded index from the last log_N bits
-            encoded_idx =
+            encoded_idx = topk_vals_i32[i] & idx_mask
            # Check if original value was positive by looking at the cleaned value
-
+            topk_vals_i32[i] = topk_vals_i32[i] & ~idx_mask  # Clear last log_N bits
            # If positive, we need to invert the bits back to get original index
             col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
-            topk_indices[i] =
+            topk_indices[i] = Int32(col_idx & idx_mask)
+
+        # Compute softmax if requested
+        if const_expr(self.softmax):
+            # Need masking as some elements may be OOB
+            for i in cutlass.range(cute.size(topk_vals_split, mode=[1]), unroll_full=True):
+                col = i * threads_per_row + tidx % threads_per_row
+                if col >= self.k // vecsize_out:
+                    for v in cutlass.range(vecsize_out, unroll_full=True):
+                        topk_vals_split[v, i] = -Float32.inf
+            # Get max from thread 0 (topk_vals[0] is the max since sorted descending)
+            max_val = cute.arch.shuffle_sync(topk_vals[0], offset=0, mask_and_clamp=mask_and_clamp)
+            log2_e = math.log2(math.e)
+            exp_x = cute.math.exp2(
+                topk_vals_split.load() * log2_e - (max_val * log2_e), fastmath=True
+            )
+            denom = cute.arch.warp_reduction_sum(
+                exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
+                threads_in_group=threads_per_row,
+            )
+            topk_vals_split.store(exp_x * cute.arch.rcp_approx(denom))
 
         # Convert cleaned values to output type
-        topk_vals_out = cute.make_fragment_like(
-        topk_vals_out.store(
+        topk_vals_out = cute.make_fragment_like(topk_vals_split, mValues.element_type)
+        topk_vals_out.store(topk_vals_split.load().to(mValues.element_type))
 
         row = tXcX[0][0]
-        # Only the 1st thread in this row writes the top-k values and indices
-        if row < shape[0] and tXcX[0][1] == 0:
-
-
-
+        # # Only the 1st thread in this row writes the top-k values and indices
+        # if row < shape[0] and tXcX[0][1] == 0:
+        # #     for i in cutlass.range(self.k):
+        # #         mValues[row, i] = topk_vals_out[i]
+        # #         mIndices[row, i] = topk_indices[i]
+        # #     Vectorized write
+        #     elems_per_store = const_expr(math.gcd(vecsize, self.k))
+        #     mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,))
+        #     mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,))
+        #     topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,))
+        #     topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,))
+        #     for i in cutlass.range(cute.size(topk_vals_out_store.shape, [1]), unroll_full=True):
+        #         cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
+        #         cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
+        if tiler_mn[0] == 0 or row < shape[0]:
             # Vectorized write
-
-
-
-
-
-
-
-                cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
+            mValues_store = cute.tiled_divide(mValues[row, None], (vecsize_out,))
+            mIndices_store = cute.tiled_divide(mIndices[row, None], (vecsize_out,))
+            for i in cutlass.range(cute.size(topk_vals_out.shape, [1]), unroll_full=True):
+                col = i * threads_per_row + tidx % threads_per_row
+                if col < self.k // vecsize_out:
+                    cute.autovec_copy(topk_vals_out[None, i], mValues_store[None, col])
+                    cute.autovec_copy(topk_indices[None, i], mIndices_store[None, col])
 
 
 @torch.library.custom_op("quack::_topk_fwd", mutates_args={"values", "indices"})
-def _topk_fwd(
+def _topk_fwd(
+    x: torch.Tensor, k: int, softmax: bool, values: torch.Tensor, indices: torch.Tensor
+) -> None:
     """Top-k forward pass.
     Args:
         x: Input tensor of shape (M, N)
         k: Number of top elements to return
+        softmax: Whether to apply softmax to the top-k values
     Returns:
         Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
     """
@@ -182,46 +232,320 @@ def _topk_fwd(x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tens
     assert k > 0 and k <= x.shape[1], "k must be positive and <= N"
 
     N = x.size(1)
-
     dtype = torch2cute_dtype_map[x.dtype]
-
-        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
-            mode=0, stride_order=(0, 1)
-        )
-    )
-
-    x_tensor, values_tensor, indices_tensor = [
-        convert_from_dlpack(tensor) for tensor in (x, values, indices)
-    ]
-    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-    compile_key = (dtype, N, k)
+    compile_key = (dtype, N, k, softmax)
     if compile_key not in _topk_fwd.compile_cache:
-
+        batch_sym = cute.sym_int()
+        div = math.gcd(128 // dtype.width, N)
+        x_cute = fake_tensor(dtype, (batch_sym, N), div)
+        values_cute = fake_tensor(dtype, (batch_sym, k), div)
+        indices_cute = fake_tensor(Int32, (batch_sym, k), div)
+        topk_op = TopK(dtype, N, k, softmax=softmax)
         _topk_fwd.compile_cache[compile_key] = cute.compile(
-            topk_op,
+            topk_op,
+            x_cute,
+            values_cute,
+            indices_cute,
+            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+            options="--enable-tvm-ffi",
         )
-    _topk_fwd.compile_cache[compile_key](
+    _topk_fwd.compile_cache[compile_key](x, values, indices)
 
 
 _topk_fwd.compile_cache = {}
 
 
-def
+def topk_fwd(x: torch.Tensor, k: int, softmax: bool = False):
     """Top-k operation.
 
     Args:
         x: Input tensor of shape (M, N)
         k: Number of top elements to return
+        softmax: Whether to apply softmax to the top-k values
 
     Returns:
         Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
     """
-
     M = x.size(0)
-
     values = torch.empty((M, k), dtype=x.dtype, device=x.device)
     indices = torch.empty((M, k), dtype=torch.int32, device=x.device)
+    _topk_fwd(x, k, softmax, values, indices)
+    return values, indices
 
-    _topk_fwd(x, k, values, indices)
 
-
+class TopKBackward(ReductionBase):
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int, softmax: bool = False):
+        super().__init__(dtype, N, stage=1, reduction_dtype=Float32)
+        self.dtype = dtype
+        self.N = N
+        self.k = k
+        self.softmax = softmax
+        assert k <= N
+        assert k <= 32768
+
+    def _num_threads(self):
+        return 128 if self.N <= 16384 else 256
+
+    def _get_tiled_copy(self, N: int, vecsize: Optional[int] = None):
+        if vecsize is None:
+            vecsize = min(N, 128 // self.dtype.width)
+        assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
+        num_threads = self._num_threads()
+        threads_per_row = min(N // vecsize, num_threads)
+        cols_per_block = num_threads // threads_per_row
+        num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
+        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+        tiled_copy = copy_utils.tiled_copy_2d(
+            self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize
+        )
+        return tiled_copy, tiler_mn, threads_per_row
+
+    @cute.jit
+    def __call__(
+        self,
+        mdValues: cute.Tensor,  # (M, k)
+        mValues: Optional[cute.Tensor],  # (M, k)
+        mIndices: cute.Tensor,  # (M, k)
+        mdX: cute.Tensor,  # (M, N)
+        stream: cuda.CUstream,
+    ):
+        assert mdValues.element_type == self.dtype
+        if const_expr(mValues is not None):
+            assert mValues.element_type == self.dtype
+        assert mIndices.element_type == Int32
+        self._set_cluster_n()
+        largest_dtype_width = const_expr(
+            max(
+                *(t.element_type.width for t in [mdValues, mValues, mIndices, mdX] if t is not None)
+            )
+        )
+        vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+        tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(self.N, vecsize=vecsize)
+        num_threads = tiled_copy.size
+        self.kernel(
+            mdValues,
+            mValues,
+            mIndices,
+            mdX,
+            tiler_mn,
+            tiled_copy,
+            threads_per_row,
+        ).launch(
+            grid=[cute.ceil_div(mdX.shape[0], tiler_mn[0]), 1, 1],
+            block=[num_threads, 1, 1],
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mdValues: cute.Tensor,  # (M, k)
+        mValues: Optional[cute.Tensor],  # (M, k)
+        mIndices: cute.Tensor,  # (M, k)
+        mdX: cute.Tensor,  # (M, N)
+        tiler_mn: cute.Shape,
+        tiled_copy: cute.TiledCopy,
+        threads_per_row: cutlass.Constexpr[int],
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+
+        tv_layout = tiled_copy.layout_tv_tiled
+        shape = mdX.shape
+        idX = cute.make_identity_tensor(shape)
+        idTopK = cute.make_identity_tensor(mdValues.shape)
+        # slice for CTAs
+        gdX, cX = [cute.local_tile(mT, tiler_mn, (bidx, 0)) for mT in (mdX, idX)]
+        gdVals, gVals, gIdx, cTopK = [
+            cute.local_tile(mT, tiler_mn, (bidx, 0)) if mT is not None else None
+            for mT in (mdValues, mValues, mIndices, idTopK)
+        ]
+
+        # Allocate smem for output gradients
+        smem = cutlass.utils.SmemAllocator()
+        sdX = smem.allocate_tensor(
+            mdX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        thr_copy = tiled_copy.get_slice(tidx)
+
+        tXgdV = thr_copy.partition_S(gdVals)
+        tXgV = thr_copy.partition_S(gVals) if const_expr(gVals is not None) else None
+        tXgI = thr_copy.partition_S(gIdx)
+        tXrdV = cute.make_fragment_like(tXgdV)
+        tXrV = cute.make_fragment_like(tXgV) if const_expr(tXgV is not None) else None
+        tXrI = cute.make_fragment_like(tXgI)
+        tXrdV.fill(tXrdV.element_type.zero)
+        if const_expr(mValues is not None):
+            tXrV.fill(tXrV.element_type.zero)
+        tXrI.fill(0)
+
+        tXsdX = thr_copy.partition_D(sdX)
+        tXgdX = thr_copy.partition_D(gdX)
+        tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
+        tXrdX = cute.make_fragment_like(tXgdX)
+
+        is_even_N = const_expr(shape[1] == tiler_mn[1])
+        tXpV = copy_utils.predicate_k(thr_copy.partition_S(cTopK), limit=mdValues.shape[1])
+        tXpX = (
+            None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
+        )
+        copy_k = partial(copy_utils.copy, pred=tXpV)
+        copy_dx = partial(copy_utils.copy, pred=tXpX)
+
+        row = tXcX[0][0]
+        tile_row_start = Int32(cute.arch.block_idx()[0] * tiler_mn[0])
+
+        # Zero out smem
+        utils.fill_oob(tXsdX, None, fill_value=mdX.element_type.zero)
+
+        if row < shape[0]:
+            copy_k(tXgdV, tXrdV)
+            if const_expr(mValues is not None):
+                copy_k(tXgV, tXrV)
+            copy_k(tXgI, tXrI)
+
+        cute.arch.barrier()
+
+        dvals_f32 = tXrdV.load().to(Float32)
+        if const_expr(self.softmax):
+            vals_f32 = tXrV.load().to(Float32)
+            dot = row_reduce(
+                dvals_f32 * vals_f32,
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+            )
+            grads = vals_f32 * (dvals_f32 - dot)
+        else:
+            grads = dvals_f32
+        grad_cvt = cute.make_fragment(tXrdV.shape, mdX.element_type)
+        grad_cvt.store(grads.to(mdX.element_type))
+
+        # Scatter values to smem
+        if row < shape[0]:
+            for rest_v in cutlass.range(tXrdV.shape[0][1], unroll_full=True):
+                for n in cutlass.range(tXrdV.shape[2], unroll_full=True):
+                    if tXpV[rest_v, 0, n]:
+                        for v in cutlass.range(tXrdV.shape[0][0], unroll_full=True):
+                            sdX[row - tile_row_start, tXrI[(v, rest_v), 0, n]] = grad_cvt[
+                                (v, rest_v), 0, n
+                            ]
+        cute.arch.barrier()
+
+        # Read from smem to rmem, then write to gmem
+        cute.autovec_copy(tXsdX, tXrdX)
+        if row < shape[0]:
+            copy_dx(tXrdX, tXgdX)
+
+
+@torch.library.custom_op("quack::_topk_bwd", mutates_args={"dx"})
+def _topk_bwd(
+    dvalues: torch.Tensor,
+    values: Optional[torch.Tensor],
+    indices: torch.Tensor,
+    k: int,
+    softmax: bool,
+    dx: torch.Tensor,
+) -> None:
+    """Top-k backward pass.
+    Args:
+        dvalues: Upstream gradients tensor of shape (M, k)
+        values: Forward top-k values tensor of shape (M, k)
+        indices: Indices tensor of shape (M, k) from forward pass
+        k: Number of top elements
+        softmax: Whether softmax was applied in forward
+        dx: Output gradient tensor of shape (M, N)
+    """
+    assert dvalues.dim() == 2, "dvalues must be 2D"
+    if values is not None:
+        assert values.dim() == 2, "values must be 2D"
+    assert indices.dim() == 2, "indices must be 2D"
+    assert dvalues.is_cuda and indices.is_cuda, "Tensors must be on CUDA device"
+    assert dvalues.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+
+    N = dx.size(1)
+    dtype = torch2cute_dtype_map[dvalues.dtype]
+    val_dtype = torch2cute_dtype_map[values.dtype] if values is not None else dtype
+    dx_dtype = torch2cute_dtype_map[dx.dtype]
+    compile_key = (dtype, val_dtype, dx_dtype, N, k, softmax)
+    if compile_key not in _topk_bwd.compile_cache:
+        batch_sym = cute.sym_int()
+        div = math.gcd(128 // dtype.width, N)
+        dvalues_cute = fake_tensor(dtype, (batch_sym, k), div)
+        values_cute = fake_tensor(val_dtype, (batch_sym, k), div) if values is not None else None
+        indices_cute = fake_tensor(Int32, (batch_sym, k), div)
+        dx_cute = fake_tensor(dx_dtype, (batch_sym, N), div)
+        topk_bwd_op = TopKBackward(dtype, N, k, softmax=softmax)
+        _topk_bwd.compile_cache[compile_key] = cute.compile(
+            topk_bwd_op,
+            dvalues_cute,
+            values_cute,
+            indices_cute,
+            dx_cute,
+            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+            options="--enable-tvm-ffi",
+        )
+    _topk_bwd.compile_cache[compile_key](dvalues, values, indices, dx)
+
+
+_topk_bwd.compile_cache = {}
+
+
+def topk_bwd(
+    dvalues: torch.Tensor,
+    values: Optional[torch.Tensor],
+    indices: torch.Tensor,
+    N: int,
+    softmax: bool = False,
+) -> torch.Tensor:
+    """Top-k backward pass.
+
+    Args:
+        dvalues: Upstream gradients tensor of shape (M, k)
+        values: Forward top-k values tensor of shape (M, k), required if softmax=True
+        indices: Indices tensor of shape (M, k) from forward pass
+        N: Size of the original input dimension
+        softmax: Whether softmax was applied in forward
+
+    Returns:
+        Input gradients tensor of shape (M, N)
+    """
+    M, k = dvalues.shape
+    dx = torch.zeros((M, N), dtype=dvalues.dtype, device=dvalues.device)
+    _topk_bwd(dvalues, values, indices, k, softmax, dx)
+    return dx
+
+
+class TopKFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, k: int, softmax: bool = False):
+        values, indices = topk_fwd(x, k, softmax=softmax)
+        ctx.save_for_backward(values if softmax else None, indices)
+        ctx.k = k
+        ctx.N = x.shape[1]
+        ctx.softmax = softmax
+        ctx.mark_non_differentiable(indices)
+        ctx.set_materialize_grads(False)
+        return values, indices
+
+    @staticmethod
+    def backward(ctx, dvalues: torch.Tensor, dindices_: Optional[torch.Tensor] = None):
+        values, indices = ctx.saved_tensors
+        dx = topk_bwd(dvalues, values, indices, N=ctx.N, softmax=ctx.softmax)
+        return dx, None, None
+
+
+def topk(x: torch.Tensor, k: int, softmax: bool = False):
+    """Top-k operation.
+
+    Args:
+        x: Input tensor of shape (M, N)
+        k: Number of top elements to return
+        softmax: Whether to apply softmax to the top-k values
+
+    Returns:
+        Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
+    """
+    return TopKFunction.apply(x, k, softmax)
```
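
The heart of the forward kernel is the step commented as "Encode the indices into the bottom bits of values": each fp32 value is reinterpreted as an int32 (via `cute.recast_tensor`) and its low log2(N) bits are overwritten with an encoded column index, bit-inverted for non-negative values so that ties resolve to the earlier column; a single sort of the values then carries the indices along for free. A minimal CPU-side sketch of that encode/decode round trip, written in plain Python purely for illustration (not part of the package):

```python
import math
import struct

N = 4096                     # row length; the kernel requires a power of two
log_N = int(math.log2(N))
idx_mask = (1 << log_N) - 1  # the low log2(N) bits hold the encoded column index

def f32_bits(x: float) -> int:
    # Reinterpret a float32 as a signed 32-bit integer (what recast_tensor does in registers).
    return struct.unpack("<i", struct.pack("<f", x))[0]

def bits_f32(b: int) -> float:
    return struct.unpack("<f", struct.pack("<i", b))[0]

def encode(value: float, col: int) -> float:
    # Invert the index for non-negative values so that, on a tie, the earlier column wins.
    enc = (~col if value >= 0 else col) & idx_mask
    return bits_f32((f32_bits(value) & ~idx_mask) | enc)

def decode(value: float) -> int:
    enc = f32_bits(value) & idx_mask
    return (~enc if value >= 0 else enc) & idx_mask

v = encode(3.14159, col=37)
print(v)          # the value is only perturbed in its low mantissa bits
print(decode(v))  # 37, recovered exactly after sorting on the values alone
```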
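
After this change the module's public surface is `topk_fwd`, `topk_bwd`, and the autograd-integrated `topk`, which optionally fuses a softmax over the selected values. A minimal usage sketch, assuming quack-kernels 0.2.4 on a CUDA device and inputs that satisfy the kernel's asserted constraints (N a power of two and at most 4096, k a power of two and at most 128):

```python
import torch
from quack.topk import topk

# Row-wise top-8 with a fused softmax over the selected values.
x = torch.randn(1024, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
values, indices = topk(x, k=8, softmax=True)  # values: (1024, 8) bf16, indices: (1024, 8) int32

# Gradients flow back through the values; indices are marked non-differentiable.
values.sum().backward()
assert x.grad.shape == (1024, 256)
```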