quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/autotuner.py
CHANGED
@@ -11,7 +11,7 @@ import hashlib
 import json
 from pathlib import Path
 from functools import cached_property, partial
-from typing import Dict, Tuple
+from typing import Dict, Tuple, List, Optional, Any
 
 import torch
 from torch import Tensor
@@ -53,7 +53,22 @@ def _base32(key):
 
 
 class Autotuner:
-    def __init__(
+    def __init__(
+        self,
+        fn,
+        key,
+        configs,
+        restore_value=None,
+        prune_configs_by: Optional[Dict] = None,
+        do_bench=None,
+        cache_results=False,
+    ):
+        """
+        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+            'perf_model': performance model used to predicate running time with different configs, returns running time
+            'top_k': number of configs to bench
+            'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
+        """
         if not configs:
             self.configs = [AutotuneConfig()]
         else:
@@ -90,6 +105,16 @@ class Autotuner:
         else:
             self.post_hook = None
 
+        self.perf_model = None
+        self.configs_top_k = 1.0
+        self.early_config_prune = None
+        if prune_configs_by:
+            self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
+            self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
+            self.early_config_prune = prune_configs_by.get(
+                "early_config_prune", self.early_config_prune
+            )
+
         self.fn = fn
         self._do_bench = do_bench
 
@@ -198,13 +223,14 @@ class Autotuner:
         key = tuple(key)
         if key not in self.cache:
             used_cached_result = False
+            pruned_configs = self.prune_configs(kwargs)
 
            @torch.compiler.disable  # Don't want any tracing here
             def benchmark():
                 bench_start = time.time()
                 timings = {
                     config: self._bench(*args, config=config, **kwargs)
-                    for config in
+                    for config in pruned_configs
                 }
                 bench_end = time.time()
                 if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
@@ -215,7 +241,7 @@
                 self.configs_timings = timings
 
             if self.cache_results:
-                self.check_disk_cache(key,
+                self.check_disk_cache(key, pruned_configs, benchmark)
             else:
                 benchmark()
 
@@ -239,6 +265,32 @@
         self.nargs = None
         return ret
 
+    def prune_configs(self, kwargs: Dict) -> List[Any]:
+        pruned_configs = self.configs
+        if self.early_config_prune:
+            pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+        if self.perf_model:
+            top_k = self.configs_top_k
+            if isinstance(top_k, float) and top_k <= 1.0:
+                top_k = int(len(self.configs) * top_k)
+            elif not isinstance(top_k, int):
+                # Slice index must be an integer
+                raise TypeError(
+                    "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
+                )
+
+            if len(pruned_configs) > top_k:
+                est_timing = {
+                    config: self.perf_model(
+                        **self.nargs,
+                        **kwargs,
+                        **config.all_kwargs(),
+                    )
+                    for config in pruned_configs
+                }
+                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+        return pruned_configs
+
 
 class AutotuneConfig:
     """
@@ -272,7 +324,9 @@ class AutotuneConfig:
         return self_tuple == other_tuple
 
 
-def autotune(
+def autotune(
+    configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True
+):
     f"""
     Decorator for auto-tuning a function function.
 
@@ -286,6 +340,10 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
     :type configs: list[AutotuneConfig]
     :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
     :type key: list[str]
+    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+        'perf_model': performance model used to predicate running time with different configs, returns running time
+        'top_k': number of configs to bench
+        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
     :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
    :type restore_value: list[str]
     :param do_bench: a benchmark function to measure the time of each run.
@@ -303,6 +361,7 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
             key,
             configs,
             restore_value=restore_value,
+            prune_configs_by=prune_configs_by,
             do_bench=do_bench,
             cache_results=cache_results,
         )
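Note: the core of this change is a Triton-style `prune_configs_by` hook on the autotuner. A minimal sketch of how a caller might wire it up is below; the decorated function, its argument names, and the cost formula are illustrative assumptions, not code from the package.

from quack.autotuner import autotune, AutotuneConfig

def perf_model(**kw):
    # Called with the tuned function's arguments plus each config's all_kwargs();
    # must return an estimated runtime (lower is better). The formula is made up.
    return kw["M"] * kw["N"]

def early_config_prune(configs, nargs, **kwargs):
    # Runs before any benchmarking; returns the subset of configs worth timing.
    return list(configs)

@autotune(
    configs=[AutotuneConfig()],  # a real caller would pass several candidate configs
    key=["M", "N"],
    prune_configs_by={
        "perf_model": perf_model,
        "top_k": 0.5,  # a float <= 1.0 is treated as a fraction of the config list
        "early_config_prune": early_config_prune,
    },
)
def my_kernel(x, M, N):
    ...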
quack/broadcast_utils.py
ADDED
@@ -0,0 +1,29 @@
+# Copyright (c) 2025, Tri Dao.
+from typing import Callable
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32, const_expr
+
+from quack.layout_utils import make_acc_tensor_mn_view
+
+
+@cute.jit
+def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
+    if const_expr(tCrC.element_type != Float32):  # Convert to f32
+        tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
+        tCrC_f32.store(tCrC.load().to(Float32))
+    else:
+        tCrC_f32 = tCrC
+    # this happens to work for frgA layout too, not just acc layout
+    tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32)
+    if const_expr(is_colvec):
+        assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec)
+        for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True):
+            tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r]))
+    else:
+        assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec)
+        for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True):
+            tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c]))
+    if const_expr(tCrC.element_type != Float32):  # Convert back to original dtype
+        tCrC.store(tCrC_f32.load().to(tCrC.element_type))
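For readers who don't work in the CuTe DSL: `vec_op` broadcasts a per-row or per-column vector into an accumulator fragment, doing the arithmetic in f32 and converting back. A NumPy analogue of the semantics (illustrative only; the real function operates on per-thread register fragments inside a @cute.jit kernel):

import operator
import numpy as np

def vec_op_reference(acc, vec, op, is_colvec):
    # acc plays the role of tCrC viewed as (M, N); vec plays the role of tCrVec.
    out = acc.astype(np.float32)
    if is_colvec:
        out = op(out, vec[:, None].astype(np.float32))  # one value per row of acc
    else:
        out = op(out, vec[None, :].astype(np.float32))  # one value per column of acc
    return out.astype(acc.dtype)

acc = np.ones((4, 8), dtype=np.float16)
bias = np.arange(8, dtype=np.float32)
print(vec_op_reference(acc, bias, operator.add, is_colvec=False))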
quack/compile_utils.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+from typing import Optional
+
+import cutlass.cute as cute
+
+
+def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
+    if leading_dim < 0:
+        leading_dim = len(shape) + leading_dim
+    if dtype is None:
+        return None
+    stride = tuple(
+        cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1
+        for i in range(len(shape))
+    )
+    return cute.runtime.make_fake_tensor(
+        dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8
+    )
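`make_fake_tensor` builds a placeholder tensor with symbolic strides (each assumed divisible by `divisibility`) and a unit stride on `leading_dim`, presumably so kernels can be compiled without real data. A hedged usage sketch; the dtype and shape are arbitrary, and whether `divisibility=8` is appropriate depends on the kernel's alignment assumptions:

import cutlass
from quack.compile_utils import make_fake_tensor

# Row-major (M, N) f32 operand: unit stride on the last dim, symbolic row stride
# assumed to be a multiple of 8 elements.
fake_x = make_fake_tensor(cutlass.Float32, (256, 1024), divisibility=8, leading_dim=-1)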
quack/copy_utils.py
ADDED
@@ -0,0 +1,487 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+import re
+from typing import Optional, Type, Tuple, Callable
+
+import cutlass
+import cutlass.cute as cute
+
+from cutlass import Int32, Boolean, const_expr
+from cutlass.cute.nvgpu import cpasync
+from cutlass.cutlass_dsl import dsl_user_op
+import cutlass.pipeline
+
+
+@dsl_user_op
+def cvt_copy(
+    atom: cute.CopyAtom,
+    src: cute.Tensor,
+    dst: cute.Tensor,
+    *,
+    pred: Optional[cute.Tensor] = None,
+    loc=None,
+    ip=None,
+    **kwargs,
+) -> None:
+    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem
+    if const_expr(src.element_type != dst.element_type):
+        src_cvt = cute.make_fragment_like(src, dst.element_type)
+        src_cvt.store(src.load().to(dst.element_type))
+        src = src_cvt
+    cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+
+
+@dsl_user_op
+def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+    dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)
+    cute.autovec_copy(src, dst, loc=loc, ip=ip)
+    return dst
+
+
+@dsl_user_op
+def load_s2r_retile(
+    tiled_copy: cute.TiledCopy,
+    src: cute.Tensor,
+    dst_shape: cute.Tensor | cute.Shape,
+    *,
+    loc=None,
+    ip=None,
+) -> cute.Tensor:
+    # Will also accept dst_shape being a tensor, in which case we write into that tensor
+    if const_expr(not isinstance(dst_shape, cute.Tensor)):
+        dst = cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip)
+    else:
+        dst = dst_shape
+    cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
+    return dst
+
+
+@dsl_user_op
+def get_copy_atom(
+    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
+) -> cute.CopyAtom:
+    num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
+    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+
+
+@dsl_user_op
+def copy(
+    src: cute.Tensor,
+    dst: cute.Tensor,
+    *,
+    pred: Optional[cute.Tensor] = None,
+    is_async: bool = False,
+    loc=None,
+    ip=None,
+    **kwargs,
+) -> None:
+    num_copy_elems = src.shape[0][0]
+    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
+    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+
+
+def tiled_copy_1d(
+    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False
+) -> cute.TiledCopy:
+    num_copy_bits = num_copy_elems * dtype.width
+    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+    thr_layout = cute.make_layout(num_threads)
+    val_layout = cute.make_layout(num_copy_elems)
+    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+
+
+def tiled_copy_2d(
+    dtype: Type[cutlass.Numeric],
+    threads_per_row: int,
+    num_threads: int,
+    num_copy_elems: int = 1,
+    is_async: bool = False,
+) -> cute.TiledCopy:
+    num_copy_bits = num_copy_elems * dtype.width
+    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+    assert num_threads % threads_per_row == 0
+    thr_layout = cute.make_ordered_layout(
+        (num_threads // threads_per_row, threads_per_row),
+        order=(1, 0),
+    )
+    val_layout = cute.make_layout((1, num_copy_elems))
+    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+
+
+@cute.jit
+def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
+    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
+    tApA = cute.make_fragment(
+        cute.make_layout(
+            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
+            stride=(cute.size(tAcA, mode=[2]), 0, 1),
+        ),
+        Boolean,
+    )
+    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
+        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
+            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
+    return tApA
+
+
+# def tiled_copy_2d(
+#     dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False
+# ) -> cute.TiledCopy:
+#     num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
+#     copy_elems = num_copy_bits // dtype.width
+#     copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+#     copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+#     gmem_threads_per_row = major_mode_size // copy_elems
+#     assert num_threads % gmem_threads_per_row == 0
+#     thr_layout = cute.make_ordered_layout(
+#         (num_threads // gmem_threads_per_row, gmem_threads_per_row),
+#         order=(1, 0),
+#     )
+#     val_layout = cute.make_layout((1, copy_elems))
+#     return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+
+
+def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]:
+    """Extract swizzle parameters from a pointer's swizzle_type.
+
+    The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where
+    b, m, s are the swizzle parameters (bits, base, shift).
+
+    Returns:
+        A cute.Swizzle object constructed from the extracted parameters
+
+    Raises:
+        ValueError: If the swizzle_type string cannot be parsed
+    """
+    # Ideally there should be a better API to get swizzle parameters, but we'll just parse
+    # the string here.
+    swizzle_str = str(ptr.type.swizzle_type)
+    # Extract the inner part "S<b,m,s>"
+    match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
+    if match:
+        b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3))
+        return b, m, s
+    else:
+        raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
+
+
+def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32:
+    bit_msk = (1 << b) - 1
+    yyy_msk = bit_msk << (m + s)
+    return ptr_int ^ ((ptr_int & yyy_msk) >> s)
+
+
+def swizzle_ptr(ptr: cute.Pointer):
+    b, m, s = parse_swizzle_from_pointer(ptr)
+    ptr_int = swizzle_int(ptr.toint(), b, m, s)
+    return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment)
+
+
+def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor:
+    outer = tensor.layout
+    width = tensor.element_type.width
+    inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator))
+    # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. <3, 3, 3> for
+    # for 16 bits and <3, 2, 3> for 32 bits)
+    new_layout = cute.recast_layout(
+        width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer))
+    )
+    # recast_ptr to remove the pointer swizzle
+    return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout)
+
+
+def partition_D_position_independent(
+    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
+) -> cute.Tensor:
+    return cute.make_tensor(
+        swizzle_ptr(thr_copy.partition_D(tensor).iterator),
+        thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout,
+    )
+
+
+def partition_S_position_independent(
+    thr_copy: cute.core.ThrCopy, tensor: cute.Tensor
+) -> cute.Tensor:
+    return cute.make_tensor(
+        swizzle_ptr(thr_copy.partition_S(tensor).iterator),
+        thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout,
+    )
+
+
+@dsl_user_op
+def sm90_get_smem_load_op(
+    layout_c: cutlass.utils.LayoutEnum,
+    elem_ty_c: Type[cutlass.Numeric],
+    *,
+    loc=None,
+    ip=None,
+) -> cute.CopyAtom:
+    """
+    Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
+
+    Parameters:
+    -----------
+    layout_c : LayoutEnum
+        The layout enum of the output tensor D.
+
+    elem_ty_c : Type[Numeric]
+        The element type for output tensor D.
+
+    Returns:
+    --------
+    Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
+    """
+
+    if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
+        raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
+    is_m_major = layout_c.is_m_major_c()
+    if elem_ty_c.width == 16:
+        return cute.make_copy_atom(
+            cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
+        )
+    else:
+        return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
+
+
+def get_smem_store_atom(
+    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
+) -> cute.CopyAtom:
+    if const_expr(arch < 90 or element_type.width != 16):
+        return cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(),
+            element_type,
+            num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
+        )
+    else:
+        return cute.make_copy_atom(
+            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
+            element_type,
+        )
+
+
+def tma_get_copy_fn(
+    atom: cute.CopyAtom,
+    cta_coord: cute.Coord,
+    cta_layout: cute.Layout,
+    src_tensor: cute.Tensor,
+    dst_tensor: cute.Tensor,
+    filter_zeros: bool = False,
+    **kwargs,
+) -> Callable:
+    src_is_smem = const_expr(
+        isinstance(src_tensor.iterator, cute.Pointer)
+        and src_tensor.memspace == cute.AddressSpace.smem
+    )
+    smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
+    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
+    s, g = cpasync.tma_partition(
+        atom,
+        cta_coord,
+        cta_layout,
+        cute.group_modes(smem_tensor, 0, cute.rank(smem_tensor) - 1),
+        cute.group_modes(gmem_tensor, 0, cute.rank(gmem_tensor) - 1),
+    )
+    if const_expr(filter_zeros):
+        s = cute.filter_zeros(s)
+        g = cute.filter_zeros(g)
+    src, dst = (s, g) if src_is_smem else (g, s)
+
+    def copy_tma(src_idx, dst_idx, **new_kwargs):
+        cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)
+
+    return copy_tma, s, g
+
+
+def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
+    def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):
+        copy(
+            src_idx=src_idx,
+            dst_idx=producer_state.index,
+            tma_bar_ptr=pipeline.producer_get_barrier(producer_state),
+            **new_kwargs,
+        )
+
+    return copy_fn
+
+
+@cute.jit
+def gather_m_get_copy_fn(
+    thr_copy_A: cute.ThrCopy,
+    mA: cute.Tensor,  # (whatever, K)
+    sA: cute.Tensor,  # (tile_M, tile_N, STAGE)
+    gsAIdx: cute.Tensor,  # (tile_M), either gmem or smem
+    limit_m: Int32,
+    limit_k: Int32,
+) -> Callable:
+    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
+    tAsA = thr_copy_A.partition_D(sA)
+    # k-major
+    assert tAsA.shape[2] == 1
+    tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2)
+
+    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
+    if const_expr(not is_even_m_smem):
+        limit_m = min(limit_m, tile_shape_mk[0])
+    elems_per_load = cute.size(tAsA.shape[0][0])
+    cA = cute.make_identity_tensor(tile_shape_mk)
+    tAcA = thr_copy_A.partition_S(cA)
+    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
+    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
+    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
+    # This is so that when we do the comparison, t0AcA is known at compile time.
+    limit_m = limit_m - tAcA[0][0]
+    limit_k = limit_k - tAcA[0][1]
+    # Read and cache indices for A
+    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
+    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
+    tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+    for m in cutlass.range(rows_per_thread, unroll_full=True):
+        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
+    m_idx = cute.make_fragment(rows_per_thread, Int32)
+    for m in cutlass.range(rows_per_thread, unroll_full=True):
+        row_idx = tAcA[0, m, 0][0]
+        if tApA_m[m]:
+            m_idx[m] = gsAIdx[row_idx]
+        else:
+            m_idx[m] = 0  # It's ok to load row 0 in the case of OOB
+
+    mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1]))
+
+    def copy_fn(src_idx, dst_idx, pred: bool = False):
+        tApA_k = None
+        if const_expr(pred):
+            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
+            for k in cutlass.range(cols_per_thread, unroll_full=True):
+                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
+        mA_cur = mA_k[None, (None, src_idx)]
+        for m in cutlass.range_constexpr(tAcA.shape[1]):
+            # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape
+            # ((elems_per_load), thread_per_row)
+            # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA
+            # So we append 1s to the last dimension and then do tiled_divide, then slice.
+            mA_row = cute.tiled_divide(
+                cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1)
+            )[None, None, 0]
+            if const_expr(is_even_m_smem) or tApA_m[m]:
+                # There's only 1 load per row
+                assert cute.size(tAcA.shape, mode=[2]) == 1
+                ki = tAcA[0, 0, 0][1] // elems_per_load
+                cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k)
+
+    return copy_fn
+
+
+@cute.jit
+def gather_k_get_copy_fn(
+    thr_copy_A: cute.ThrCopy,
+    mA: cute.Tensor,  # (tile_M, whatever)
+    sA: cute.Tensor,  # (tile_M, tile_N, STAGE)
+    gsAIdx: cute.Tensor,  # (tile_K, RestK), either gmem or smem
+    limit_m: Int32,
+    limit_k: Int32,
+) -> Callable:
+    gAIdx, sAIdx = None, None
+    if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem):
+        gAIdx = gsAIdx
+    else:
+        assert gsAIdx.memspace == cute.AddressSpace.smem
+        sAIdx = gsAIdx
+    tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1]))
+    # (atom_v, CPY_M, 1, STAGE)
+    tAsA = thr_copy_A.partition_D(sA)
+    # m-major
+    tAsA = cute.group_modes(tAsA, 0, 3)
+
+    is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0
+    if const_expr(not is_even_m_smem):
+        limit_m = min(limit_m, tile_shape_mk[0])
+    elems_per_load = cute.size(tAsA.shape[0][0])
+    cA = cute.make_identity_tensor(tile_shape_mk)
+    tAcA = thr_copy_A.partition_S(cA)
+    t0AcA = thr_copy_A.get_slice(0).partition_S(cA)
+    # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0]
+    # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0].
+    # This is so that when we do the comparison, t0AcA is known at compile time.
+    limit_m = limit_m - tAcA[0][0]
+    limit_k = limit_k - tAcA[0][1]
+    # Read and cache indices for A
+    rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
+    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
+    tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+    for m in cutlass.range(rows_per_thread, unroll_full=True):
+        tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
+    threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
+    # This is very convoluted but idk a better way
+    # for tile_M=128, flat_divide gives (8, 16, K),
+    # then logical_divide gives ((8, 1), (8, 2), K).
+    tidx = thr_copy_A.thr_idx
+    tAmA = cute.logical_divide(
+        cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col)
+    )[None, (tidx % threads_per_col, None), None]  # ((8, 1), 2, K)
+
+    def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]:
+        # Prefetch mAIdx early, even before smem is free
+        tApA_k = None
+        if const_expr(pred):
+            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
+            for k in cutlass.range(cols_per_thread, unroll_full=True):
+                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
+        gAIdx_cur = gAIdx[None, src_idx]
+        k_idx = cute.make_fragment(cols_per_thread, Int32)
+        for k in cutlass.range(cols_per_thread):
+            col_idx = tAcA[0, 0, k][1]
+            if const_expr(not pred):
+                k_idx[k] = gAIdx_cur[col_idx]
+            else:
+                if tApA_k[k]:
+                    k_idx[k] = gAIdx_cur[col_idx]
+                else:
+                    k_idx[k] = -1
+        return k_idx, tApA_k
+
+    def prefetch_from_smem_fn(
+        a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False
+    ) -> Tuple[cute.Tensor, cute.Tensor]:
+        tApA_k = None
+        if const_expr(pred):
+            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+            limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
+            for k in cutlass.range(cols_per_thread, unroll_full=True):
+                tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
+        a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
+        sAIdx_cur = sAIdx[None, dst_idx]
+        k_idx = cute.make_fragment(cols_per_thread, Int32)
+        for k in cutlass.range(cols_per_thread):
+            col_idx = tAcA[0, 0, k][1]
+            k_idx[k] = sAIdx_cur[col_idx]
+        cute.arch.sync_warp()
+        with cute.arch.elect_one():
+            a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
+        return k_idx, tApA_k
+
+    def copy_fn(
+        src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False
+    ):
+        k_idx, tApA_k = k_idx_tApA_k
+        tApA_k_pred = None
+        if const_expr(pred):
+            tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2)  # (1, cols_per_thread)
+        for k in cutlass.range_constexpr(tAcA.shape[2]):
+            # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2))
+            for m in cutlass.range_constexpr(tAcA.shape[1]):
+                if tApA_m[m]:
+                    cute.copy(
+                        thr_copy_A,
+                        tAmA[None, m, k_idx[k]],
+                        tAsA[(None, m, k), dst_idx],
+                        pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k],
+                    )
+
+    return copy_fn, prefetch_from_gmem_fn if const_expr(
+        gAIdx is not None
+    ) else prefetch_from_smem_fn
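The swizzle helpers near the top of this file (`parse_swizzle_from_pointer`, `swizzle_int`, `swizzle_ptr`) are plain bit manipulation, so the XOR pattern can be sanity-checked outside the DSL. A standalone check, with the function body mirroring `swizzle_int` above and arbitrary example values:

def swizzle_int(ptr_int, b, m, s):  # mirrors quack.copy_utils.swizzle_int above
    bit_msk = (1 << b) - 1
    yyy_msk = bit_msk << (m + s)
    return ptr_int ^ ((ptr_int & yyy_msk) >> s)

# With (b, m, s) = (3, 4, 3), the three bits at positions 7-9 are XORed into
# the three bits at positions 4-6.
addr = 0b10110010000                            # 1424: bits 4, 7, 8, 10 set
print(bin(swizzle_int(addr, b=3, m=4, s=3)))    # 0b10110100000 (1440)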