quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +4 -1
- quack/autotuner.py +309 -0
- quack/cross_entropy.py +2 -5
- quack/cute_dsl_utils.py +40 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +2474 -0
- quack/fast_math.py +97 -0
- quack/gemm_config.py +61 -0
- quack/gemm_interface.py +321 -0
- quack/linear.py +176 -0
- quack/lse.py +62 -0
- quack/mlp.py +204 -0
- quack/pipeline.py +166 -0
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2088 -0
- quack/tensormap_manager.py +114 -0
- quack/tile_scheduler.py +935 -0
- quack/topk.py +221 -0
- quack/utils.py +237 -19
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/METADATA +3 -3
- quack_kernels-0.1.11.dist-info/RECORD +31 -0
- quack_kernels-0.1.9.dist-info/RECORD +0 -12
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/top_level.txt +0 -0
quack/topk.py
ADDED
@@ -0,0 +1,221 @@
+# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
+
+import math
+import torch
+from typing import Type
+
+import cuda.bindings.driver as cuda
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.runtime import from_dlpack
+from cutlass import const_expr
+
+import quack.utils as utils
+from quack.reduction_base import torch2cute_dtype_map
+from quack.sort.bitonic_sort import bitonic_topk
+
+
+class TopK:
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int):
+        self.dtype = dtype
+        self.N = N
+        self.vecsize = 128 // dtype.width
+        self.k = k
+        assert N == 2 ** int(math.log2(N)), "N must be a power of 2"
+        assert k == 2 ** int(math.log2(k)), "k must be a power of 2"
+        assert k <= 128
+        assert N <= 4096
+
+    def _calculate_threads_per_row(self):
+        # we want num_elems_per_thread >= self.k
+        # and each thread can handle at most 64 elements
+        N = self.N
+        num_threads_per_row = max(min(N // self.k, 32, N // 64), 1)
+        return num_threads_per_row
+
+    def _get_tv_layout(self):
+        N = self.N
+        vecsize = self.vecsize
+        num_threads = 128 if N <= 16384 else 256
+        threads_per_row = self._calculate_threads_per_row()
+        cols_per_block = num_threads // threads_per_row
+        num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+        tv_layout = cute.make_layout(
+            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+            stride=(
+                (vecsize * cols_per_block, 1),
+                (cols_per_block, cols_per_block * vecsize * threads_per_row),
+            ),
+        )
+        return tiler_mn, tv_layout
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mValues: cute.Tensor,
+        mIndices: cute.Tensor,
+        stream: cuda.CUstream,
+    ):
+        assert mX.element_type == self.dtype
+        assert mValues.element_type == self.dtype
+        assert mIndices.element_type == cutlass.Int32
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        self.kernel(mX, mValues, mIndices, tv_layout, tiler_mn).launch(
+            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1],
+            block=[num_threads, 1, 1],
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,
+        mValues: cute.Tensor,
+        mIndices: cute.Tensor,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+
+        shape = mX.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+        gX = cute.local_tile(mX, tiler_mn, (0, 0))
+        cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
+
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+        )
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        tXgX = thr_copy_X.partition_S(gX)
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+        # allocate fragments for gmem->rmem
+        tXrX = cute.make_fragment_like(tXgX)
+
+        is_even_N = const_expr(shape[1] == tiler_mn[1])
+        tXpX = (
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+        )
+        if tXcX[0][0] < shape[0]:
+            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+        tXrX_f32 = cute.make_fragment(tXrX.shape, cutlass.Float32)
+        tXrX_f32.store(tXrX.load().to(cutlass.Float32))
+
+        # Encode the indices into the bottom bits of values.
+        log_N = int(math.log2(self.N))
+        idx_mask = (1 << log_N) - 1
+        vecsize = cutlass.const_expr(tv_layout.shape[1][0])
+        tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32)
+        # Encode indices into the last log_N bits of tXrX_u32
+        for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True):
+            # tXcX only keeps track of the indices for every @vecsize elements
+            col_idx = cutlass.Uint32(tXcX[i // vecsize][1] + i % vecsize)
+            # If positive, invert the bits of the index, so that if there's a tie,
+            # indices coming from an earlier column will win.
+            encoded_idx = ~col_idx if tXrX_f32[i] >= 0 else col_idx
+            # Mask to keep only the last log_N bits of the encoded index
+            encoded_idx = encoded_idx & idx_mask
+            # Clear the last log_N bits and set them to our encoded index
+            tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx
+
+        # Fill OOB values with -inf for top-k
+        if const_expr(not is_even_N):
+            utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf)
+
+        threads_per_row = tv_layout.shape[0][0]
+        topk_vals = bitonic_topk(tXrX_f32, self.k, warp_width=threads_per_row)
+        # Extract indices and clean values
+        topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32)
+        topk_indices = cute.make_fragment(self.k, cutlass.Int32)
+        for i in cutlass.range(self.k):
+            # Extract the encoded index from the last log_N bits
+            encoded_idx = topk_vals_u32[i] & idx_mask
+            # Check if original value was positive by looking at the cleaned value
+            topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask  # Clear last log_N bits
+            # If positive, we need to invert the bits back to get original index
+            col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx
+            topk_indices[i] = cutlass.Int32(col_idx & idx_mask)
+
+        # Convert cleaned values to output type
+        topk_vals_out = cute.make_fragment_like(topk_vals, mValues.element_type)
+        topk_vals_out.store(topk_vals.load().to(mValues.element_type))
+
+        row = tXcX[0][0]
+        # Only the 1st thread in this row writes the top-k values and indices
+        if row < shape[0] and tXcX[0][1] == 0:
+            # for i in cutlass.range(self.k):
+            #     mValues[row, i] = topk_vals_out[i]
+            #     mIndices[row, i] = topk_indices[i]
+            # Vectorized write
+            elems_per_store = const_expr(math.gcd(vecsize, self.k))
+            mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,))
+            mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,))
+            topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,))
+            topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,))
+            for i in cutlass.range(cute.size(topk_vals_out_store.shape, [1]), unroll_full=True):
+                cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i])
+                cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])
+
+
+def _topk_fwd(x: torch.Tensor, k: int):
+    """Top-k forward pass.
+    Args:
+        x: Input tensor of shape (M, N)
+        k: Number of top elements to return
+    Returns:
+        Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert x.is_cuda, "Tensor must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert k > 0 and k <= x.shape[1], "k must be positive and <= N"
+
+    M, N = x.shape
+    values = torch.empty((M, k), dtype=x.dtype, device=x.device)
+    indices = torch.empty((M, k), dtype=torch.int32, device=x.device)
+
+    dtype = torch2cute_dtype_map[x.dtype]
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+
+    x_tensor, values_tensor, indices_tensor = [
+        convert_from_dlpack(tensor) for tensor in (x, values, indices)
+    ]
+    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+    compile_key = (dtype, N, k)
+    if compile_key not in _topk_fwd.compile_cache:
+        topk_op = TopK(dtype, N, k)
+        _topk_fwd.compile_cache[compile_key] = cute.compile(
+            topk_op, x_tensor, values_tensor, indices_tensor, current_stream
+        )
+    _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)
+
+    return values, indices


+_topk_fwd.compile_cache = {}
+
+
+def topk(x: torch.Tensor, k: int):
+    """Top-k operation.
+
+    Args:
+        x: Input tensor of shape (M, N)
+        k: Number of top elements to return
+
+    Returns:
+        Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
+    """
+    return _topk_fwd(x, k)
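A note on the packed-index trick above: because N <= 4096, the low log2(N) bits of each float32 value are expendable, so the kernel packs the column index into them and lets a single bitonic top-k over the packed 32-bit values order by value and break ties by column (the index bits are inverted for non-negative values so that, in both sign regimes, the earlier column wins the tie). A host-side NumPy sketch of the pack/unpack round trip, for illustration only (not part of the package):

import numpy as np

N = 4096
log_N = int(np.log2(N))
idx_mask = np.uint32((1 << log_N) - 1)

def pack(vals_f32, cols_u32):
    bits = vals_f32.view(np.uint32)
    # Invert the index bits for non-negative values so earlier columns win ties
    enc = np.where(vals_f32 >= 0, ~cols_u32, cols_u32) & idx_mask
    return ((bits & ~idx_mask) | enc).view(np.float32)

def unpack(packed_f32):
    bits = packed_f32.view(np.uint32)
    enc = bits & idx_mask
    vals = (bits & ~idx_mask).view(np.float32)  # cleaned values
    cols = np.where(vals >= 0, ~enc, enc) & idx_mask
    return vals, cols.astype(np.int32)

r = np.random.randn(N).astype(np.float32)
vals = np.where(r >= 0, r + 1, r - 1).astype(np.float32)  # keep magnitudes away from 0
cols = np.arange(N, dtype=np.uint32)
_, recovered = unpack(pack(vals, cols))
assert np.array_equal(recovered, cols.astype(np.int32))

Calling convention, per the asserts in _topk_fwd (2D CUDA tensor; fp16/bf16/fp32; N a power of 2 up to 4096; k a power of 2 up to 128):

import torch
from quack.topk import topk

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
values, indices = topk(x, k=32)  # values: (8, 32) bf16, indices: (8, 32) int32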
quack/utils.py
CHANGED
@@ -2,14 +2,14 @@
 
 import operator
 import math
-from typing import Callable, Optional, Tuple
+from typing import Callable, Optional, Tuple, Type, Union
 
 import cutlass
 import cutlass.cute as cute
 
-from cutlass import Float32
+from cutlass import Float32, Int32
 from cutlass.cutlass_dsl import T, dsl_user_op
-from cutlass._mlir.dialects import llvm, vector
+from cutlass._mlir.dialects import llvm, nvvm, vector
 from cutlass.cute.runtime import from_dlpack
 
 
@@ -100,13 +100,14 @@ def store_shared_remote(
     ).ir_value()
     if cutlass.const_expr(isinstance(val, float)):
         val = Float32(val)
-    assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
-    suffix = "f32"
+    assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
+    suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
+    constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
     llvm.inline_asm(
         None,
         [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
         f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
-        f"r,{
+        f"r,{constraint},r",
         has_side_effects=True,
         is_align_stack=False,
         asm_dialect=llvm.AsmDialect.AD_ATT,
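For reference, the new constraint letters follow LLVM's NVPTX inline-asm convention: "f" binds a 32-bit float register, "r" a 32-bit integer register, and "l" a 64-bit integer register, so each supported value type now gets a matching st.async variant. A throwaway Python sketch of the (suffix, constraint) pairs the function can emit, mirroring the two dict lookups above (illustrative only):

# (suffix, constraint) per supported value type; illustrative only
cases = {
    "Float32": ("f32", "f"),
    "Int32": ("s32", "r"),
    "Int64": ("s64", "l"),
}
for ty, (suffix, constraint) in cases.items():
    asm = f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];"
    print(f"{ty}: {asm}  constraints: r,{constraint},r")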
@@ -198,9 +199,9 @@ def row_reduce(
         hook_fn()
     if cutlass.const_expr(reduction_buffer is not None):
         warps_per_row, cluster_n = reduction_buffer.shape[1]
-        assert (
-
-        )
+        assert cluster_n == 1 or mbar_ptr is not None, (
+            "mbar_ptr must be provided for cluster reduction"
+        )
         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
             val = block_or_cluster_reduce(
                 val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val

@@ -237,13 +238,13 @@ def online_softmax_reduce(
         hook_fn()
     if cutlass.const_expr(reduction_buffer is not None):
         rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
-        assert (
-
-        )
+        assert cluster_n == 1 or mbar_ptr is not None, (
+            "mbar_ptr must be provided for cluster reduction"
+        )
         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-            assert (
-                reduction_buffer
-            )
+            assert reduction_buffer.element_type == cutlass.Int64, (
+                "reduction_buffer must be of type cute.Int64"
+            )
         lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
         row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
         if cutlass.const_expr(mbar_ptr is None):
@@ -304,6 +305,19 @@ def online_softmax_reduce(
     return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
 
 
+@dsl_user_op
+def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32:
+    return Float32(
+        nvvm.fmin(
+            T.f32(),
+            Float32(a).ir_value(loc=loc, ip=ip),
+            Float32(b).ir_value(loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
+        )
+    )
+
+
 @cute.jit
 def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
     """exp2f calculation for both vector and scalar.
@@ -315,7 +329,7 @@ def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
     if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
         res = cute.make_fragment(x.shape, Float32)
         res.store(x)
-        for i in cutlass.
+        for i in cutlass.range(cute.size(x.shape), unroll_full=True):
             res[i] = cute.arch.exp2(res[i])
         return res.load()
     else:
@@ -337,6 +351,21 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
     )
 
 
+@dsl_user_op
+def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
+        llvm.inline_asm(
+            T.f32(),
+            [Float32(a).ir_value(loc=loc, ip=ip)],
+            "sqrt.approx.ftz.f32 $0, $1;",
+            "=f,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
 @dsl_user_op
 def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
     return Float32(
@@ -352,6 +381,98 @@ def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
     )
 
 
+@dsl_user_op
+def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
+        llvm.inline_asm(
+            T.f32(),
+            [Float32(a).ir_value(loc=loc, ip=ip)],
+            "tanh.approx.f32 $0, $1;",
+            "=f,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+@dsl_user_op
+def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
+    return Int32(
+        llvm.inline_asm(
+            T.i32(),
+            [Float32(a).ir_value(loc=loc, ip=ip)],
+            "cvt.rpi.ftz.s32.f32 $0, $1;",
+            "=r,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+@dsl_user_op
+def silu(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    """
+    silu(a) = a * sigmoid(a) = a * (1 + tanh(a / 2)) / 2 = (0.5 * a) * tanh(0.5 * a) + (0.5 * a)
+    This compiles down to 3 SASS instructions: FMUL to get 0.5 * a, MUFU.TANH, and FFMA.
+    """
+    a_half = 0.5 * a
+    return a_half * tanh(a_half) + a_half
+
+
+@dsl_user_op
+def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32:
+    return Int32(
+        llvm.inline_asm(
+            T.i32(),
+            [
+                Int32(a).ir_value(loc=loc, ip=ip),
+                Int32(b).ir_value(loc=loc, ip=ip),
+                Int32(c).ir_value(loc=loc, ip=ip),
+            ],
+            "prmt.b32 $0, $1, $2, $3;",
+            "=r,r,r,r",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+@cute.jit
+def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
+    assert t.element_type.width == 16
+    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
+    t_u32 = cute.recast_tensor(t, Int32)
+
+    quad_idx = cute.arch.lane_idx() % 4
+    lane_03 = quad_idx == 0 or quad_idx == 3
+    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
+    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
+    # upper_map = [0, 3, 1, 2]
+    # lower_map = [1, 2, 0, 3]
+    # upper_idx = upper_map[quad_idx]
+    # indexing isn't supported so we have to do arithmetic
+    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
+    lower_idx = upper_idx ^ 1
+
+    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+    width = 4
+    mask = cute.arch.WARP_SIZE - width
+    clamp = cute.arch.WARP_SIZE - 1
+    mask_and_clamp = mask << 8 | clamp
+
+    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
+        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
+        upper0 = upper if lane_03 else lower
+        lower0 = lower if lane_03 else upper
+        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
+        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
+        t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
+        t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
+
+
 @cute.jit
 def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
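The silu docstring leans on the half-angle identity sigmoid(a) = (1 + tanh(a / 2)) / 2; a quick NumPy check of the rewritten form (illustrative, not part of the package):

import numpy as np

a = np.linspace(-8, 8, 1001)
silu_ref = a / (1.0 + np.exp(-a))  # a * sigmoid(a)
a_half = 0.5 * a
silu_tanh = a_half * np.tanh(a_half) + a_half  # the form the kernel computes
assert np.allclose(silu_ref, silu_tanh, atol=1e-12)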
@@ -417,9 +538,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
 def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
     flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
     flat_stride = cute.flatten_to_tuple(tensor.stride)
-    assert len(flat_coord_i64) == len(
-
-    )
+    assert len(flat_coord_i64) == len(flat_stride), (
+        "Coordinate and stride must have the same length"
+    )
     offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
     assert isinstance(tensor.iterator, cute.Pointer)
     # HACK: we assume that applying the offset does not change the pointer alignment
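The Int64 cast in domain_offset_i64 is what lets callers like quack/topk.py address tensors with more than 2^31 elements; a back-of-the-envelope illustration:

# A (1_048_576, 4096) row-major tensor has 2**32 elements, so linear offsets
# into its upper rows overflow a signed 32-bit index.
M, N = 1_048_576, 4096
offset = (M - 1) * N  # offset of the last row's first element
print(offset > 2**31 - 1)  # True: the offset must be computed in 64 bits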
@@ -446,3 +567,100 @@ def coord_offset_i64(
         assumed_align=tensor.iterator.max_alignment,
     )
     return cute.make_tensor(new_ptr, tensor.layout)
+
+
+@cute.jit
+def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
+    if cutlass.const_expr(lane is None):
+        lane = cute.arch.lane_idx()
+    for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
+        offset = 1 << i
+        # Very important that we set mask_and_clamp to 0
+        partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
+        if lane >= offset:
+            val += partial_sum
+    return val
+
+
+def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
+    """
+    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
+    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
+    """
+    acc_layout_col_major = cute.make_layout(acc_layout.shape)
+    acc_layout_mn = cute.make_layout(
+        (
+            (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
+            (
+                acc_layout_col_major.shape[0][0],
+                *acc_layout_col_major.shape[0][2:],
+                acc_layout_col_major.shape[2],
+            ),  # MMA_N
+            *acc_layout_col_major.shape[3:],
+        ),
+        stride=(
+            (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
+            (
+                acc_layout_col_major.stride[0][0],
+                *acc_layout_col_major.stride[0][2:],
+                acc_layout_col_major.stride[2],
+            ),  # MMA_N
+            *acc_layout_col_major.stride[3:],
+        ),
+    )
+    return cute.composition(acc_layout, acc_layout_mn)
+
+
+def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
+    return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
+
+
+@dsl_user_op
+def sm90_get_smem_load_op(
+    layout_c: cutlass.utils.LayoutEnum,
+    elem_ty_c: Type[cutlass.Numeric],
+    *,
+    loc=None,
+    ip=None,
+) -> cute.CopyAtom:
+    """
+    Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
+
+    Parameters:
+    -----------
+    layout_c : LayoutEnum
+        The layout enum of the output tensor D.
+
+    elem_ty_c : Type[Numeric]
+        The element type for output tensor D.
+
+    Returns:
+    --------
+    Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
+    """
+
+    if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
+        raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
+    is_m_major = layout_c.is_m_major_c()
+    if elem_ty_c.width == 16:
+        return cute.make_copy_atom(
+            cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
+        )
+    else:
+        return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
+
+
+@dsl_user_op
+def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
+    return nvvm.atomicrmw(
+        res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
+    )
+
+
+@dsl_user_op
+def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
+    return nvvm.atomicrmw(
+        res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
+    )
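warp_prefix_sum is a textbook Hillis-Steele inclusive scan over warp shuffles: at step i, each lane adds the value held by the lane offset = 1 << i positions below it, and the lane >= offset guard leaves the bottom lanes unchanged. A host-side reference in plain Python (illustrative, not the device code):

import itertools

WARP_SIZE = 32

def warp_prefix_sum_ref(vals):
    # vals: one value per lane; returns the inclusive prefix sum per lane
    vals = list(vals)
    offset = 1
    while offset < WARP_SIZE:
        # every lane reads the pre-update value from `lane - offset`
        shifted = [vals[lane - offset] if lane >= offset else 0 for lane in range(WARP_SIZE)]
        vals = [v + s for v, s in zip(vals, shifted)]
        offset <<= 1
    return vals

data = list(range(1, WARP_SIZE + 1))
assert warp_prefix_sum_ref(data) == list(itertools.accumulate(data))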
{quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/METADATA
CHANGED

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.1.9
-Requires-Python: >=3.
+Version: 0.1.11
+Requires-Python: >=3.12
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl==4.1.0
+Requires-Dist: nvidia-cutlass-dsl==4.1.0
 Requires-Dist: torch
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
quack_kernels-0.1.11.dist-info/RECORD
ADDED

@@ -0,0 +1,31 @@
+quack/__init__.py,sha256=AD0T-rBhSfKXpwZ6E4JIPiugvlFaAePjl-3pUhWOlPE,292
+quack/autotuner.py,sha256=aF9-Cw47gaX7_LZvyVbLsj6Z2AWi4UZ-0Qwjy06Xd5I,10733
+quack/cross_entropy.py,sha256=xsg2bXZ4wNvusBARhN4PwAzm5PbejEcfwj71nR7bzuE,20852
+quack/cute_dsl_utils.py,sha256=LkNyFEKwYrgp-tLt_775EZWuBR3v7G80El3UAObHY2U,1292
+quack/dense_gemm_sm100.py,sha256=W_j8BO-ilb1YUYFuclo7_itfPIRTkjPV_ittWgQy8t4,109937
+quack/dense_gemm_sm90.py,sha256=Dff0GbIv92uTjrtsUE1GjVKCtwSf6_5KZbrqYZm-ZMY,110418
+quack/fast_math.py,sha256=XqXVvKLSxXC3c9tIGLvKVRWdPsmjAa_O4C0plmsfZ0w,3106
+quack/gemm_config.py,sha256=Gz4dkHH1Uwg9IdW-x5W_5tjdaFHBfxq4bn7hJx_xu5s,1789
+quack/gemm_interface.py,sha256=XHgxo08d8LIu6dTlQKBOBJtjCegUB5uLh4k9hC-5mvY,9525
+quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+quack/linear.py,sha256=Wd0KeXWvWjbkKrgW4Av1ud2v_mbhzf1RvubF7BYhcw4,6425
+quack/lse.py,sha256=aANOleIYREyrkUQM9cfJ9Gt63eawMb2KVd7YAGWNoZU,2092
+quack/mlp.py,sha256=D9V7aIfvoBMzhKwN8ZE6GlSOmwFJe_JGqgOvQprU0OQ,8224
+quack/pipeline.py,sha256=SwvRZAR4RqYH60wAFC3OTu5DisN1XDMv5umQF4czJW4,5867
+quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
+quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
+quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
+quack/symmetric_dense_gemm_sm90.py,sha256=t-6eLasZwyu1NW4HpnvVBBPOvfqUzOg8VHe9sJQYdmg,88637
+quack/tensormap_manager.py,sha256=pzBNwLCB8kV_yp8X8_BoDdtbwWeht2jrgRhyyfVIcMI,5261
+quack/tile_scheduler.py,sha256=mImjD2LuIVchM6USJoJY4-CSG54jGuwyLIvFG6LTP9Y,42205
+quack/topk.py,sha256=1pObblNJnxKLaE_T3qGvaMnUua0dqG2en9OU5PSp71s,9020
+quack/utils.py,sha256=4ViEFgHecaX5wcYpO6XzTCzdnuZv2rniUJAJH5Ta0bA,24981
+quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
+quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
+quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
+quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
+quack_kernels-0.1.11.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.1.11.dist-info/METADATA,sha256=WTYlk9lmhr4Jkin71stp3h-NrBdme-8OrBc7lAf4vSw,286
+quack_kernels-0.1.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.1.11.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.1.11.dist-info/RECORD,,
quack_kernels-0.1.9.dist-info/RECORD
DELETED

@@ -1,12 +0,0 @@
-quack/__init__.py,sha256=CT76CeRNh5bzQ9f13yVuRz9Sj7V3MvwzHH4fB1iQIf0,203
-quack/cross_entropy.py,sha256=VYSAd28GmtnMoKQwLrorvySDtJfRhoqVd-aeM52FmsI,20866
-quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
-quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
-quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
-quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
-quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
-quack_kernels-0.1.9.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.1.9.dist-info/METADATA,sha256=vOnpbShNHRiUXKAnOUxzfRM7zkpW3RmjW4hIgvYda08,289
-quack_kernels-0.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.1.9.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.1.9.dist-info/RECORD,,
{quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/WHEEL
File without changes

{quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/licenses/LICENSE
File without changes

{quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/top_level.txt
File without changes