quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
|
|
2
|
+
"""
|
|
3
|
+
Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html
|
|
4
|
+
|
|
5
|
+
This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# fmt: off
|
|
9
|
+
# ruff: noqa
|
|
10
|
+
# isort: skip_file
|
|
11
|
+
|
|
12
|
+
import cutlass
|
|
13
|
+
import cutlass.cute as cute
|
|
14
|
+
|
|
15
|
+
from quack.sort.utils import compare_and_swap
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
networks = {
|
|
19
|
+
# Size 2: 1 CEs, depth 1
|
|
20
|
+
2: [[(0, 1)]],
|
|
21
|
+
|
|
22
|
+
# Size 4: 5 CEs, depth 3
|
|
23
|
+
4: [
|
|
24
|
+
[(0, 2), (1, 3)],
|
|
25
|
+
[(0, 1), (2, 3)],
|
|
26
|
+
[(1, 2)]
|
|
27
|
+
],
|
|
28
|
+
|
|
29
|
+
# Size 8: 19 CEs, depth 6
|
|
30
|
+
8: [
|
|
31
|
+
[(0, 2), (1, 3), (4, 6), (5, 7)],
|
|
32
|
+
[(0, 4), (1, 5), (2, 6), (3, 7)],
|
|
33
|
+
[(0, 1), (2, 3), (4, 5), (6, 7)],
|
|
34
|
+
[(2, 4), (3, 5)],
|
|
35
|
+
[(1, 4), (3, 6)],
|
|
36
|
+
[(1, 2), (3, 4), (5, 6)]
|
|
37
|
+
],
|
|
38
|
+
|
|
39
|
+
# Size 16: 60 CEs, depth 10
|
|
40
|
+
16: [
|
|
41
|
+
[(0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10)],
|
|
42
|
+
[(0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12)],
|
|
43
|
+
[(0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15)],
|
|
44
|
+
[(0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15)],
|
|
45
|
+
[(1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14)],
|
|
46
|
+
[(1, 4), (2, 6), (5, 8), (7, 10), (9, 13), (11, 14)],
|
|
47
|
+
[(2, 4), (3, 6), (9, 12), (11, 13)],
|
|
48
|
+
[(3, 5), (6, 8), (7, 9), (10, 12)],
|
|
49
|
+
[(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)],
|
|
50
|
+
[(6, 7), (8, 9)]
|
|
51
|
+
],
|
|
52
|
+
|
|
53
|
+
# Size 32: 185 CEs, depth 14
|
|
54
|
+
32: [
|
|
55
|
+
[(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31)],
|
|
56
|
+
[(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31)],
|
|
57
|
+
[(0, 4), (1, 5), (2, 6), (3, 7), (8, 12), (9, 13), (10, 14), (11, 15), (16, 20), (17, 21), (18, 22), (19, 23), (24, 28), (25, 29), (26, 30), (27, 31)],
|
|
58
|
+
[(0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (16, 24), (17, 25), (18, 26), (19, 27), (20, 28), (21, 29), (22, 30), (23, 31)],
|
|
59
|
+
[(0, 16), (1, 8), (2, 4), (3, 12), (5, 10), (6, 9), (7, 14), (11, 13), (15, 31), (17, 24), (18, 20), (19, 28), (21, 26), (22, 25), (23, 30), (27, 29)],
|
|
60
|
+
[(1, 2), (3, 5), (4, 8), (6, 22), (7, 11), (9, 25), (10, 12), (13, 14), (17, 18), (19, 21), (20, 24), (23, 27), (26, 28), (29, 30)],
|
|
61
|
+
[(1, 17), (2, 18), (3, 19), (4, 20), (5, 10), (7, 23), (8, 24), (11, 27), (12, 28), (13, 29), (14, 30), (21, 26)],
|
|
62
|
+
[(3, 17), (4, 16), (5, 21), (6, 18), (7, 9), (8, 20), (10, 26), (11, 23), (13, 25), (14, 28), (15, 27), (22, 24)],
|
|
63
|
+
[(1, 4), (3, 8), (5, 16), (7, 17), (9, 21), (10, 22), (11, 19), (12, 20), (14, 24), (15, 26), (23, 28), (27, 30)],
|
|
64
|
+
[(2, 5), (7, 8), (9, 18), (11, 17), (12, 16), (13, 22), (14, 20), (15, 19), (23, 24), (26, 29)],
|
|
65
|
+
[(2, 4), (6, 12), (9, 16), (10, 11), (13, 17), (14, 18), (15, 22), (19, 25), (20, 21), (27, 29)],
|
|
66
|
+
[(5, 6), (8, 12), (9, 10), (11, 13), (14, 16), (15, 17), (18, 20), (19, 23), (21, 22), (25, 26)],
|
|
67
|
+
[(3, 5), (6, 7), (8, 9), (10, 12), (11, 14), (13, 16), (15, 18), (17, 20), (19, 21), (22, 23), (24, 25), (26, 28)],
|
|
68
|
+
[(3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28)]
|
|
69
|
+
],
|
|
70
|
+
|
|
71
|
+
# Size 64: 521 CEs, depth 21
|
|
72
|
+
64: [
|
|
73
|
+
[(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31), (32, 34), (33, 35), (36, 38), (37, 39), (40, 42), (41, 43), (44, 46), (45, 47), (48, 50), (49, 51), (52, 54), (53, 55), (56, 58), (57, 59), (60, 62), (61, 63)],
|
|
74
|
+
[(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31), (32, 33), (34, 35), (36, 37), (38, 39), (40, 41), (42, 43), (44, 45), (46, 47), (48, 49), (50, 51), (52, 53), (54, 55), (56, 57), (58, 59), (60, 61), (62, 63)],
|
|
75
|
+
[(0, 52), (1, 2), (3, 55), (4, 48), (5, 6), (7, 51), (8, 60), (9, 10), (11, 63), (12, 56), (13, 14), (15, 59), (16, 32), (17, 18), (19, 35), (20, 24), (21, 22), (23, 27), (25, 26), (28, 44), (29, 30), (31, 47), (33, 34), (36, 40), (37, 38), (39, 43), (41, 42), (45, 46), (49, 50), (53, 54), (57, 58), (61, 62)],
|
|
76
|
+
[(0, 20), (1, 53), (2, 54), (3, 23), (4, 28), (5, 49), (6, 50), (7, 31), (8, 36), (9, 61), (10, 62), (11, 39), (12, 16), (13, 57), (14, 58), (15, 19), (17, 33), (18, 34), (21, 25), (22, 26), (24, 52), (27, 55), (29, 45), (30, 46), (32, 56), (35, 59), (37, 41), (38, 42), (40, 60), (43, 63), (44, 48), (47, 51)],
|
|
77
|
+
[(0, 4), (1, 21), (2, 22), (3, 7), (5, 29), (6, 30), (8, 12), (9, 37), (10, 38), (11, 15), (13, 17), (14, 18), (16, 20), (19, 23), (24, 32), (25, 53), (26, 54), (27, 35), (28, 36), (31, 39), (33, 57), (34, 58), (40, 44), (41, 61), (42, 62), (43, 47), (45, 49), (46, 50), (48, 52), (51, 55), (56, 60), (59, 63)],
|
|
78
|
+
[(0, 8), (1, 5), (2, 6), (3, 11), (4, 12), (7, 15), (9, 13), (10, 14), (16, 40), (17, 21), (18, 22), (19, 43), (20, 44), (23, 47), (24, 28), (25, 33), (26, 34), (27, 31), (29, 37), (30, 38), (32, 36), (35, 39), (41, 45), (42, 46), (48, 56), (49, 53), (50, 54), (51, 59), (52, 60), (55, 63), (57, 61), (58, 62)],
|
|
79
|
+
[(1, 9), (2, 10), (4, 8), (5, 13), (6, 14), (7, 11), (12, 48), (15, 51), (16, 24), (17, 41), (18, 42), (19, 27), (20, 28), (21, 45), (22, 46), (23, 31), (25, 29), (26, 30), (32, 40), (33, 37), (34, 38), (35, 43), (36, 44), (39, 47), (49, 57), (50, 58), (52, 56), (53, 61), (54, 62), (55, 59)],
|
|
80
|
+
[(4, 16), (5, 9), (6, 10), (7, 19), (8, 24), (11, 27), (13, 49), (14, 50), (17, 25), (18, 26), (20, 32), (21, 29), (22, 30), (23, 35), (28, 40), (31, 43), (33, 41), (34, 42), (36, 52), (37, 45), (38, 46), (39, 55), (44, 56), (47, 59), (53, 57), (54, 58)],
|
|
81
|
+
[(1, 4), (5, 17), (6, 18), (8, 16), (9, 25), (10, 26), (11, 19), (12, 24), (15, 27), (21, 33), (22, 34), (29, 41), (30, 42), (36, 48), (37, 53), (38, 54), (39, 51), (44, 52), (45, 57), (46, 58), (47, 55), (59, 62)],
|
|
82
|
+
[(2, 8), (9, 17), (10, 18), (12, 20), (13, 25), (14, 26), (15, 23), (24, 32), (27, 35), (28, 36), (31, 39), (37, 49), (38, 50), (40, 48), (43, 51), (45, 53), (46, 54), (55, 61)],
|
|
83
|
+
[(2, 4), (12, 16), (13, 21), (14, 22), (15, 19), (20, 24), (23, 27), (25, 33), (26, 34), (28, 32), (29, 37), (30, 38), (31, 35), (36, 40), (39, 43), (41, 49), (42, 50), (44, 48), (47, 51), (59, 61)],
|
|
84
|
+
[(4, 16), (5, 20), (10, 40), (13, 17), (14, 18), (21, 25), (22, 26), (23, 53), (24, 28), (27, 31), (29, 33), (30, 34), (32, 36), (35, 39), (37, 41), (38, 42), (43, 58), (45, 49), (46, 50), (47, 59)],
|
|
85
|
+
[(3, 17), (6, 36), (7, 21), (8, 32), (9, 24), (11, 41), (13, 28), (14, 44), (15, 45), (18, 48), (19, 49), (22, 52), (25, 29), (26, 30), (27, 57), (31, 55), (33, 37), (34, 38), (35, 50), (39, 54), (42, 56), (46, 60)],
|
|
86
|
+
[(6, 20), (8, 16), (10, 24), (11, 25), (14, 28), (15, 29), (17, 33), (18, 32), (21, 37), (22, 36), (26, 42), (27, 41), (30, 46), (31, 45), (34, 48), (35, 49), (38, 52), (39, 53), (43, 57), (47, 55)],
|
|
87
|
+
[(3, 18), (5, 8), (6, 12), (7, 22), (15, 21), (17, 32), (19, 33), (23, 37), (26, 40), (30, 44), (31, 46), (41, 56), (42, 48), (45, 60), (51, 57), (55, 58)],
|
|
88
|
+
[(3, 16), (7, 20), (11, 26), (18, 24), (19, 25), (22, 28), (23, 29), (27, 33), (30, 36), (34, 40), (35, 41), (37, 52), (38, 44), (39, 45), (43, 56), (47, 60)],
|
|
89
|
+
[(3, 9), (7, 13), (10, 16), (11, 17), (14, 20), (15, 30), (19, 34), (21, 36), (23, 38), (25, 40), (26, 32), (27, 42), (29, 44), (31, 37), (33, 48), (43, 49), (46, 52), (47, 53), (50, 56), (54, 60)],
|
|
90
|
+
[(3, 8), (7, 10), (9, 12), (11, 18), (13, 14), (15, 24), (17, 22), (19, 28), (21, 26), (23, 25), (27, 34), (29, 36), (30, 32), (31, 33), (35, 44), (37, 42), (38, 40), (39, 48), (41, 46), (45, 52), (49, 50), (51, 54), (53, 56), (55, 60)],
|
|
91
|
+
[(3, 6), (7, 12), (11, 16), (15, 17), (18, 20), (19, 24), (21, 22), (23, 30), (25, 32), (26, 28), (27, 29), (31, 38), (33, 40), (34, 36), (35, 37), (39, 44), (41, 42), (43, 45), (46, 48), (47, 52), (51, 56), (57, 60)],
|
|
92
|
+
[(3, 5), (6, 8), (7, 9), (10, 12), (11, 13), (14, 16), (15, 18), (17, 20), (19, 21), (22, 24), (23, 26), (25, 28), (27, 30), (29, 32), (31, 34), (33, 36), (35, 38), (37, 40), (39, 41), (42, 44), (43, 46), (45, 48), (47, 49), (50, 52), (51, 53), (54, 56), (55, 57), (58, 60)],
|
|
93
|
+
[(3, 4), (7, 8), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28), (29, 30), (31, 32), (33, 34), (35, 36), (37, 38), (39, 40), (41, 42), (43, 44), (45, 46), (47, 48), (49, 50), (51, 52), (55, 56), (59, 60)]
|
|
94
|
+
],
|
|
95
|
+
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    """
    Sort a window of ``arr`` in place with a precomputed sorting network.

    The network for size ``n`` is looked up in the module-level ``networks``
    table and replayed layer by layer. Each comparator ``(lo, hi)`` of a layer
    is applied through ``compare_and_swap`` at offsets relative to ``start``.

    Args:
        arr: Tensor whose elements are reordered in place
        n: Number of elements to sort; must be a key of ``networks``
        start: Offset added to every comparator index (default 0)
        ascending: Sort direction, ascending when True (default True)

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    """
    # Only sizes that have a table entry can be sorted.
    assert n in networks
    layers = networks[n]
    for layer in layers:
        for lo, hi in layer:
            compare_and_swap(arr, start + lo, start + hi, ascending)
|
quack/sort/utils.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import cutlass
|
|
2
|
+
import cutlass.cute as cute
|
|
3
|
+
|
|
4
|
+
import quack.utils as utils
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@cute.jit
def compare_and_swap(
    arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False
) -> None:
    """Put the pair ``arr[i]``, ``arr[j]`` in order, in place.

    With ``ascending`` true the smaller element ends up at index ``i`` and the
    larger one at index ``j``; with ``ascending`` false the roles are reversed.

    Args:
        arr: Tensor that is modified in place.
        i: Index of the first element of the pair.
        j: Index of the second element of the pair.
        ascending: Direction of the ordering (default True).
        use_selection: When true, compare the two values and exchange them
            conditionally; otherwise write a min/max pair back (default False).
    """
    if cutlass.const_expr(use_selection):
        # Comparison variant: read both values once, exchange them if out of order.
        lhs, rhs = arr[i], arr[j]
        # XOR with `not ascending` inverts the swap condition for descending order.
        if (lhs > rhs) ^ (not ascending):
            arr[i] = rhs
            arr[j] = lhs
    else:
        # Min/max variant: Float32 tensors go through utils.fmin / cute.arch.fmax,
        # every other element type uses the builtin min / max.
        if cutlass.const_expr(arr.element_type != cutlass.Float32):
            lo_fn, hi_fn = min, max
        else:
            lo_fn, hi_fn = utils.fmin, cute.arch.fmax
        # Both right-hand sides are evaluated before either element is overwritten.
        if cutlass.const_expr(ascending):
            arr[i], arr[j] = lo_fn(arr[i], arr[j]), hi_fn(arr[i], arr[j])
        else:
            arr[i], arr[j] = hi_fn(arr[i], arr[j]), lo_fn(arr[i], arr[j])
|