quack-kernels 0.2.1-py3-none-any.whl → 0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/autotuner.py +64 -5
- quack/cute_dsl_utils.py +6 -7
- quack/dense_gemm_sm90.py +582 -287
- quack/gemm_act_sm90.py +70 -29
- quack/gemm_dact_sm90.py +43 -10
- quack/gemm_interface.py +453 -130
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +443 -419
- quack/gemm_wrapper_utils.py +179 -22
- quack/rmsnorm.py +83 -149
- quack/tile_scheduler.py +34 -47
- quack/utils.py +61 -8
- quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/METADATA +2 -2
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/RECORD +18 -18
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt +0 -0
quack/utils.py
CHANGED

@@ -6,7 +6,7 @@ from typing import Optional, Tuple, Type, Union
 import cutlass
 import cutlass.cute as cute
 
-from cutlass import Float32, Int32
+from cutlass import Float32, Int32, const_expr
 from cutlass.cutlass_dsl import T, dsl_user_op
 from cutlass._mlir.dialects import llvm, nvvm, vector
 from cutlass.cute.runtime import from_dlpack
@@ -22,6 +22,59 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
     )
 
 
+def transpose_view(a: cute.Tensor) -> cute.Tensor:
+    """Transpose the first two dimensions of a tensor on smem."""
+    shape = (a.shape[1], a.shape[0], *a.shape[2:])
+    order = (1, 0, *range(2, cute.rank(a)))
+    return cute.composition(a, cute.make_ordered_layout(shape, order=order))
+
+
+def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
+    return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
+
+
+@dsl_user_op
+def get_copy_atom(
+    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None
+) -> cute.CopyAtom:
+    num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
+    copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+
+
+@dsl_user_op
+def copy(
+    src: cute.Tensor,
+    dst: cute.Tensor,
+    *,
+    pred: Optional[cute.Tensor] = None,
+    num_copy_elems: int = 1,
+    is_async: bool = False,
+    loc=None,
+    ip=None,
+    **kwargs,
+) -> None:
+    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)
+    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+
+
+def tiled_copy_2d(
+    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = True
+) -> cute.TiledCopy:
+    num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width
+    copy_elems = num_copy_bits // dtype.width
+    copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+    gmem_threads_per_row = major_mode_size // copy_elems
+    assert num_threads % gmem_threads_per_row == 0
+    thr_layout = cute.make_ordered_layout(
+        (num_threads // gmem_threads_per_row, gmem_threads_per_row),
+        order=(1, 0),
+    )
+    val_layout = cute.make_layout((1, copy_elems))
+    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
+
+
 @dsl_user_op
 def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
     return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
@@ -29,7 +82,7 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut
 
 @cute.jit
 def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
-    if isinstance(x, cute.Pointer):
+    if const_expr(isinstance(x, cute.Pointer)):
         return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
     else:
         assert isinstance(x, Float32)
@@ -71,7 +124,7 @@ def store_shared_remote(
     remote_mbar_ptr_i32 = set_block_rank(
         mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
     ).ir_value()
-    if isinstance(val, float):
+    if const_expr(isinstance(val, float)):
         val = Float32(val)
     assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
     suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
@@ -196,7 +249,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Nu
     tXrX_fill.fill(fill_value)
     for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
         for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
-            if tXpX is not None:
+            if const_expr(tXpX is not None):
                 if not tXpX[rest_v, 0, rest_k]:
                     cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
             else:
@@ -232,9 +285,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
 def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
     flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
     flat_stride = cute.flatten_to_tuple(tensor.stride)
-    assert len(flat_coord_i64) == len(
-        flat_stride
-    )
+    assert len(flat_coord_i64) == len(flat_stride), (
+        "Coordinate and stride must have the same length"
+    )
     offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
     assert isinstance(tensor.iterator, cute.Pointer)
     # HACK: we assume that applying the offset does not change the pointer alignment
@@ -265,7 +318,7 @@ def coord_offset_i64(
 
 @cute.jit
 def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
-    if lane is None:
+    if const_expr(lane is None):
         lane = cute.arch.lane_idx()
     for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
         offset = 1 << i
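Note: two patterns recur in this utils.py diff. First, Python-level branches (isinstance checks, is-None tests) are now wrapped in const_expr(...), which tells the CuTe DSL to resolve them while tracing rather than treating them as device-side control flow. Second, the new copy helpers cap each copy at 128 bits, and tiled_copy_2d uses a gcd to pick the widest vector width that evenly divides the contiguous (major-mode) extent. A minimal sketch of that width computation in plain Python (the dtype widths below are illustrative assumptions, not values from this package):

    import math

    def copy_width_bits(major_mode_size: int, dtype_width_bits: int) -> int:
        # Widest legal copy is 128 bits; the gcd guarantees the per-thread
        # vector length evenly divides the contiguous row extent.
        max_elems = 128 // dtype_width_bits
        return math.gcd(major_mode_size, max_elems) * dtype_width_bits

    # Rows of 64 fp16 (16-bit) elements vectorize to full 128-bit copies:
    assert copy_width_bits(64, 16) == 128  # 8 elements per copy
    # Rows of 12 fp16 elements fall back to 64-bit copies:
    assert copy_width_bits(12, 16) == 64   # gcd(12, 8) = 4 elements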
quack/varlen_utils.py
CHANGED

@@ -14,9 +14,4 @@ class VarlenArguments(ArgumentsBase):
     mCuSeqlensM: Optional[cute.Tensor] = None
     mCuSeqlensK: Optional[cute.Tensor] = None
     mTensormaps: Optional[cute.Tensor] = None
-
-    def __post_init__(self):
-        if self.mCuSeqlensM is not None or self.mCuSeqlensK is not None:
-            assert (
-                self.mTensormaps is not None
-            ), "mTensormaps must be provided if mCuSeqlensM or mCuSeqlensK is provided"
+    mAIdx: Optional[cute.Tensor] = None
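Net effect: the __post_init__ consistency check (mTensormaps required whenever mCuSeqlensM or mCuSeqlensK is set) is dropped, and an optional mAIdx field is appended. A sketch of the resulting class, assembled from the hunk (the @dataclass decorator and the stub ArgumentsBase are assumptions here so the snippet stands alone; the real base class lives elsewhere in quack):

    from dataclasses import dataclass
    from typing import Optional

    import cutlass.cute as cute

    class ArgumentsBase:  # stub for illustration; not quack's actual base class
        pass

    @dataclass
    class VarlenArguments(ArgumentsBase):
        mCuSeqlensM: Optional[cute.Tensor] = None
        mCuSeqlensK: Optional[cute.Tensor] = None
        mTensormaps: Optional[cute.Tensor] = None
        mAIdx: Optional[cute.Tensor] = None  # new in 0.2.2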
{quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/METADATA
CHANGED

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.2.1
+Version: 0.2.2
 Requires-Python: >=3.10
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl==4.2.
+Requires-Dist: nvidia-cutlass-dsl==4.2.1
 Requires-Dist: torch
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
{quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/RECORD
CHANGED

@@ -1,16 +1,16 @@
-quack/__init__.py,sha256=
+quack/__init__.py,sha256=sJum67V7jEQPUDWz4FKJ5Sk7MqmBtbMXjZPVboQnDdE,364
 quack/activation.py,sha256=SzQDUCB-kccqsy1aYUrHYJ2cGxKMXxxqpjJaJoqBYaE,10017
-quack/autotuner.py,sha256=
+quack/autotuner.py,sha256=atw0ntedi22RPwSdjWOoge4S56S8VFvRocJQcYhpAlo,13454
 quack/cross_entropy.py,sha256=TE8j21c-7E4cInKtFjcKsgKXNhKCRFkNfhCJpgpasj8,28409
-quack/cute_dsl_utils.py,sha256=
-quack/
-quack/dense_gemm_sm90.py,sha256=TjnjHnjhAwWH5YQWsFlADq07xSxtsprkw_p2Cy0yw7I,100407
+quack/cute_dsl_utils.py,sha256=d8xLD17a9EsSQgmgWDO8rUWWCTRM8e1kDq1wzilaYC8,4563
+quack/dense_gemm_sm90.py,sha256=LvcR178zzzWClkEerhIx940Sg-AF_BpQdnjqC8s9W1o,113832
 quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
-quack/gemm_act_sm90.py,sha256=
+quack/gemm_act_sm90.py,sha256=yJEkwCtKjldxzJYq78CpCV6fxoqoZJSpd7KvnglHqfo,16206
 quack/gemm_config.py,sha256=gbYjPFeyT5wAhVwFQroRHlHoMKEJqAWX9P8wWy04l8Q,2258
-quack/gemm_dact_sm90.py,sha256=
-quack/gemm_interface.py,sha256=
-quack/
+quack/gemm_dact_sm90.py,sha256=QOACq-v9XHfY6p5frKzYCvkCbqGDq69beYcfCfl-5Kc,6458
+quack/gemm_interface.py,sha256=qEbQRsvTrwKdLLlGVCMH76diMCKOsA6GqsC0PaepLow,39636
+quack/gemm_sm100.py,sha256=T-2BUrUBXROxQ9Iz-6pB5T8j9go29Vlw4ZCJQ_oM7yg,110396
+quack/gemm_wrapper_utils.py,sha256=oDCXngJuH-qbDI9DJuXkDHUogXleWZrF1mRpI1DAcI8,12687
 quack/layernorm.py,sha256=AOe95-YqhFPw96x8pJq7FfBe26ROX9ZTvH025lM1ILs,13579
 quack/linear.py,sha256=SrhRiAFjC7ONIMVmiNu-kSPLHNUyaCXt59a1f_5nNXo,9383
 quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
@@ -18,20 +18,20 @@ quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
 quack/pipeline.py,sha256=DyCwZX8WvoUBFcMBz7CeYm9VUM31haEGgBhAzmxu8cE,5519
 quack/reduce.py,sha256=0hRFMFfn6xC5QLk32Qmgc17XVkQ1yKC-3TfksccSBaU,10341
 quack/reduction_base.py,sha256=CT-t_j7z8H1ByD9FkQYDRik_-THMDFv9QoXHmr9Xx9E,3636
-quack/rmsnorm.py,sha256=
+quack/rmsnorm.py,sha256=Ak3EL-qzwgaKGZl7O2upiR3FC93776Cgse_B5PZhTu0,45643
 quack/softmax.py,sha256=WFWtgc40iLPFBpdStBBTC9803Npnv9rZjOzb_nK-RDs,17110
 quack/symmetric_dense_gemm_sm90.py,sha256=2UXooIpClT2izdyGis1XaIgYYlLj-7MrcOMg2yR7YCk,88694
 quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
-quack/tile_scheduler.py,sha256=
+quack/tile_scheduler.py,sha256=5lcprf3VIXWCNusWHBCveHpCWRzQ0nzcIMhaQbXher8,41727
 quack/topk.py,sha256=RQl-23lIicQ9ry9Njur8i0JGem_WbO_Gchr6jy8EtVM,9185
-quack/utils.py,sha256=
-quack/varlen_utils.py,sha256=
+quack/utils.py,sha256=DVMSbMngPBnIRrHuGDXKqVueiNv9DFCfGv076hxzJms,14747
+quack/varlen_utils.py,sha256=GwXc8tO6BrYoYszhSeJ0u_KmreJAEodP1EAizLS-jaA,464
 quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
 quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
 quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
 quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
-quack_kernels-0.2.
-quack_kernels-0.2.
-quack_kernels-0.2.
-quack_kernels-0.2.
-quack_kernels-0.2.
+quack_kernels-0.2.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.2.2.dist-info/METADATA,sha256=ZZofR2edTztufmX_0ExiJ7CpFsT80koJf-pRRUm3ssg,285
+quack_kernels-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.2.2.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.2.2.dist-info/RECORD,,
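Each RECORD line has the wheel-standard form path,sha256=<digest>,<size-in-bytes>, where the digest is an unpadded URL-safe base64 encoding of the file's SHA-256. A standalone illustration of how that hash field is produced (the byte string is a placeholder, not real package contents):

    import base64
    import hashlib

    # Hash the file bytes, then base64-encode URL-safely and strip '=' padding,
    # matching the sha256=... fields in the RECORD diff above.
    digest = hashlib.sha256(b"<file contents>").digest()
    print("sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode())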
{quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL
File without changes

{quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE
File without changes

{quack_kernels-0.2.1.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt
File without changes