quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,120 @@
1
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
2
+ """
3
+ Best-known sorting networks (optimality is proven only for some small sizes), generated from: https://bertdobbelaere.github.io/sorting_networks.html
4
+
5
+ This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
6
+ """
7
+
8
+ # fmt: off
9
+ # ruff: noqa
10
+ # isort: skip_file
11
+
12
+ import cutlass
13
+ import cutlass.cute as cute
14
+
15
+ from quack.sort.utils import compare_and_swap
16
+
17
+
18
# Sorting-network tables keyed by input size n.
#
# networks[n] is a list of layers; the number of layers equals the depth given
# in each size comment. Each layer is a list of (i, j) index pairs, always
# written with i < j. optimal_sort walks the layers in order and calls
# compare_and_swap(arr, start + i, start + j, ascending) for every pair, so with
# ascending=True the smaller value ends up at the lower index.
# "CEs" is the number of compare-exchange pairs in the network.
networks = {
    # Size 2: 1 CEs, depth 1
    2: [[(0, 1)]],

    # Size 4: 5 CEs, depth 3
    4: [
        [(0, 2), (1, 3)],
        [(0, 1), (2, 3)],
        [(1, 2)]
    ],

    # Size 8: 19 CEs, depth 6
    8: [
        [(0, 2), (1, 3), (4, 6), (5, 7)],
        [(0, 4), (1, 5), (2, 6), (3, 7)],
        [(0, 1), (2, 3), (4, 5), (6, 7)],
        [(2, 4), (3, 5)],
        [(1, 4), (3, 6)],
        [(1, 2), (3, 4), (5, 6)]
    ],

    # Size 16: 60 CEs, depth 10
    16: [
        [(0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10)],
        [(0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12)],
        [(0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15)],
        [(0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15)],
        [(1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14)],
        [(1, 4), (2, 6), (5, 8), (7, 10), (9, 13), (11, 14)],
        [(2, 4), (3, 6), (9, 12), (11, 13)],
        [(3, 5), (6, 8), (7, 9), (10, 12)],
        [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)],
        [(6, 7), (8, 9)]
    ],

    # Size 32: 185 CEs, depth 14
    32: [
        [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31)],
        [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31)],
        [(0, 4), (1, 5), (2, 6), (3, 7), (8, 12), (9, 13), (10, 14), (11, 15), (16, 20), (17, 21), (18, 22), (19, 23), (24, 28), (25, 29), (26, 30), (27, 31)],
        [(0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (16, 24), (17, 25), (18, 26), (19, 27), (20, 28), (21, 29), (22, 30), (23, 31)],
        [(0, 16), (1, 8), (2, 4), (3, 12), (5, 10), (6, 9), (7, 14), (11, 13), (15, 31), (17, 24), (18, 20), (19, 28), (21, 26), (22, 25), (23, 30), (27, 29)],
        [(1, 2), (3, 5), (4, 8), (6, 22), (7, 11), (9, 25), (10, 12), (13, 14), (17, 18), (19, 21), (20, 24), (23, 27), (26, 28), (29, 30)],
        [(1, 17), (2, 18), (3, 19), (4, 20), (5, 10), (7, 23), (8, 24), (11, 27), (12, 28), (13, 29), (14, 30), (21, 26)],
        [(3, 17), (4, 16), (5, 21), (6, 18), (7, 9), (8, 20), (10, 26), (11, 23), (13, 25), (14, 28), (15, 27), (22, 24)],
        [(1, 4), (3, 8), (5, 16), (7, 17), (9, 21), (10, 22), (11, 19), (12, 20), (14, 24), (15, 26), (23, 28), (27, 30)],
        [(2, 5), (7, 8), (9, 18), (11, 17), (12, 16), (13, 22), (14, 20), (15, 19), (23, 24), (26, 29)],
        [(2, 4), (6, 12), (9, 16), (10, 11), (13, 17), (14, 18), (15, 22), (19, 25), (20, 21), (27, 29)],
        [(5, 6), (8, 12), (9, 10), (11, 13), (14, 16), (15, 17), (18, 20), (19, 23), (21, 22), (25, 26)],
        [(3, 5), (6, 7), (8, 9), (10, 12), (11, 14), (13, 16), (15, 18), (17, 20), (19, 21), (22, 23), (24, 25), (26, 28)],
        [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28)]
    ],

    # Size 64: 521 CEs, depth 21
    64: [
        [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31), (32, 34), (33, 35), (36, 38), (37, 39), (40, 42), (41, 43), (44, 46), (45, 47), (48, 50), (49, 51), (52, 54), (53, 55), (56, 58), (57, 59), (60, 62), (61, 63)],
        [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31), (32, 33), (34, 35), (36, 37), (38, 39), (40, 41), (42, 43), (44, 45), (46, 47), (48, 49), (50, 51), (52, 53), (54, 55), (56, 57), (58, 59), (60, 61), (62, 63)],
        [(0, 52), (1, 2), (3, 55), (4, 48), (5, 6), (7, 51), (8, 60), (9, 10), (11, 63), (12, 56), (13, 14), (15, 59), (16, 32), (17, 18), (19, 35), (20, 24), (21, 22), (23, 27), (25, 26), (28, 44), (29, 30), (31, 47), (33, 34), (36, 40), (37, 38), (39, 43), (41, 42), (45, 46), (49, 50), (53, 54), (57, 58), (61, 62)],
        [(0, 20), (1, 53), (2, 54), (3, 23), (4, 28), (5, 49), (6, 50), (7, 31), (8, 36), (9, 61), (10, 62), (11, 39), (12, 16), (13, 57), (14, 58), (15, 19), (17, 33), (18, 34), (21, 25), (22, 26), (24, 52), (27, 55), (29, 45), (30, 46), (32, 56), (35, 59), (37, 41), (38, 42), (40, 60), (43, 63), (44, 48), (47, 51)],
        [(0, 4), (1, 21), (2, 22), (3, 7), (5, 29), (6, 30), (8, 12), (9, 37), (10, 38), (11, 15), (13, 17), (14, 18), (16, 20), (19, 23), (24, 32), (25, 53), (26, 54), (27, 35), (28, 36), (31, 39), (33, 57), (34, 58), (40, 44), (41, 61), (42, 62), (43, 47), (45, 49), (46, 50), (48, 52), (51, 55), (56, 60), (59, 63)],
        [(0, 8), (1, 5), (2, 6), (3, 11), (4, 12), (7, 15), (9, 13), (10, 14), (16, 40), (17, 21), (18, 22), (19, 43), (20, 44), (23, 47), (24, 28), (25, 33), (26, 34), (27, 31), (29, 37), (30, 38), (32, 36), (35, 39), (41, 45), (42, 46), (48, 56), (49, 53), (50, 54), (51, 59), (52, 60), (55, 63), (57, 61), (58, 62)],
        [(1, 9), (2, 10), (4, 8), (5, 13), (6, 14), (7, 11), (12, 48), (15, 51), (16, 24), (17, 41), (18, 42), (19, 27), (20, 28), (21, 45), (22, 46), (23, 31), (25, 29), (26, 30), (32, 40), (33, 37), (34, 38), (35, 43), (36, 44), (39, 47), (49, 57), (50, 58), (52, 56), (53, 61), (54, 62), (55, 59)],
        [(4, 16), (5, 9), (6, 10), (7, 19), (8, 24), (11, 27), (13, 49), (14, 50), (17, 25), (18, 26), (20, 32), (21, 29), (22, 30), (23, 35), (28, 40), (31, 43), (33, 41), (34, 42), (36, 52), (37, 45), (38, 46), (39, 55), (44, 56), (47, 59), (53, 57), (54, 58)],
        [(1, 4), (5, 17), (6, 18), (8, 16), (9, 25), (10, 26), (11, 19), (12, 24), (15, 27), (21, 33), (22, 34), (29, 41), (30, 42), (36, 48), (37, 53), (38, 54), (39, 51), (44, 52), (45, 57), (46, 58), (47, 55), (59, 62)],
        [(2, 8), (9, 17), (10, 18), (12, 20), (13, 25), (14, 26), (15, 23), (24, 32), (27, 35), (28, 36), (31, 39), (37, 49), (38, 50), (40, 48), (43, 51), (45, 53), (46, 54), (55, 61)],
        [(2, 4), (12, 16), (13, 21), (14, 22), (15, 19), (20, 24), (23, 27), (25, 33), (26, 34), (28, 32), (29, 37), (30, 38), (31, 35), (36, 40), (39, 43), (41, 49), (42, 50), (44, 48), (47, 51), (59, 61)],
        [(4, 16), (5, 20), (10, 40), (13, 17), (14, 18), (21, 25), (22, 26), (23, 53), (24, 28), (27, 31), (29, 33), (30, 34), (32, 36), (35, 39), (37, 41), (38, 42), (43, 58), (45, 49), (46, 50), (47, 59)],
        [(3, 17), (6, 36), (7, 21), (8, 32), (9, 24), (11, 41), (13, 28), (14, 44), (15, 45), (18, 48), (19, 49), (22, 52), (25, 29), (26, 30), (27, 57), (31, 55), (33, 37), (34, 38), (35, 50), (39, 54), (42, 56), (46, 60)],
        [(6, 20), (8, 16), (10, 24), (11, 25), (14, 28), (15, 29), (17, 33), (18, 32), (21, 37), (22, 36), (26, 42), (27, 41), (30, 46), (31, 45), (34, 48), (35, 49), (38, 52), (39, 53), (43, 57), (47, 55)],
        [(3, 18), (5, 8), (6, 12), (7, 22), (15, 21), (17, 32), (19, 33), (23, 37), (26, 40), (30, 44), (31, 46), (41, 56), (42, 48), (45, 60), (51, 57), (55, 58)],
        [(3, 16), (7, 20), (11, 26), (18, 24), (19, 25), (22, 28), (23, 29), (27, 33), (30, 36), (34, 40), (35, 41), (37, 52), (38, 44), (39, 45), (43, 56), (47, 60)],
        [(3, 9), (7, 13), (10, 16), (11, 17), (14, 20), (15, 30), (19, 34), (21, 36), (23, 38), (25, 40), (26, 32), (27, 42), (29, 44), (31, 37), (33, 48), (43, 49), (46, 52), (47, 53), (50, 56), (54, 60)],
        [(3, 8), (7, 10), (9, 12), (11, 18), (13, 14), (15, 24), (17, 22), (19, 28), (21, 26), (23, 25), (27, 34), (29, 36), (30, 32), (31, 33), (35, 44), (37, 42), (38, 40), (39, 48), (41, 46), (45, 52), (49, 50), (51, 54), (53, 56), (55, 60)],
        [(3, 6), (7, 12), (11, 16), (15, 17), (18, 20), (19, 24), (21, 22), (23, 30), (25, 32), (26, 28), (27, 29), (31, 38), (33, 40), (34, 36), (35, 37), (39, 44), (41, 42), (43, 45), (46, 48), (47, 52), (51, 56), (57, 60)],
        [(3, 5), (6, 8), (7, 9), (10, 12), (11, 13), (14, 16), (15, 18), (17, 20), (19, 21), (22, 24), (23, 26), (25, 28), (27, 30), (29, 32), (31, 34), (33, 36), (35, 38), (37, 40), (39, 41), (42, 44), (43, 46), (45, 48), (47, 49), (50, 52), (51, 53), (54, 56), (55, 57), (58, 60)],
        [(3, 4), (7, 8), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28), (29, 30), (31, 32), (33, 34), (35, 36), (37, 38), (39, 40), (41, 42), (43, 44), (45, 46), (47, 48), (49, 50), (51, 52), (55, 56), (59, 60)]
    ],

}
97
+
98
+
99
@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    """
    Sort ``arr[start : start + n]`` in place using a precomputed sorting network.

    The network for size ``n`` is looked up in ``networks``. Its layers are applied
    in order, with one ``compare_and_swap`` per (i, j) pair. Because ``n`` is a
    Constexpr, the network and therefore the sequence of compare-and-swap calls
    are fixed at compile time.

    Args:
        arr: Tensor holding the elements to sort. Only indices ``start`` through
            ``start + n - 1`` are read or written.
        n: Number of elements to sort. Must be a key of ``networks``
            (currently 2, 4, 8, 16, 32 or 64).
        start: Index of the first element to sort (default 0).
        ascending: Sort in ascending order if True, descending if False
            (default True).

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    """
    # n is a Constexpr, so this membership test runs on a plain Python int.
    # The message names the supported sizes instead of failing with a bare
    # AssertionError (or, with asserts stripped, a KeyError on the lookup below).
    assert n in networks, (
        f"optimal_sort: no sorting network for n={n}; "
        f"supported sizes are {sorted(networks)}"
    )
    for level in networks[n]:
        for i, j in level:
            compare_and_swap(arr, start + i, start + j, ascending)
quack/sort/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ import cutlass
2
+ import cutlass.cute as cute
3
+
4
+ import quack.utils as utils
5
+
6
+
7
@cute.jit
def compare_and_swap(
    arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False
) -> None:
    """Compare-and-swap ``arr[i]`` and ``arr[j]`` in place.

    When ``ascending`` is True, ``arr[i]`` ends up holding the smaller of the two
    values and ``arr[j]`` the larger. When ``ascending`` is False, the roles are
    reversed.

    Args:
        arr: Tensor whose elements at ``i`` and ``j`` are reordered in place.
        i: Index of the first element.
        j: Index of the second element.
        ascending: Sort direction (default True).
        use_selection: If True, compare the pair and conditionally swap the two
            loaded values. If False (default), write back min/max of the pair.
    """
    if cutlass.const_expr(use_selection):
        a, b = arr[i], arr[j]
        # Swap when the pair is out of order for the requested direction.
        # Ascending: swap if a > b.
        # Descending: swap if not (a > b), i.e. also when a == b (a no-op for equal
        # values) or when either value is NaN (a > b is then False).
        # `not ascending` is applied directly, so ascending is expected to be a
        # plain Python bool here.
        if (a > b) ^ (not ascending):
            arr[i] = b
            arr[j] = a
    else:
        # Float32 uses the dedicated fmin/fmax helpers (quack.utils.fmin and
        # cute.arch.fmax). Other element types fall back to the builtin min/max.
        # NOTE(review): confirm builtin min/max is supported for non-Float32 DSL values.
        min_fn = min if cutlass.const_expr(arr.element_type != cutlass.Float32) else utils.fmin
        max_fn = max if cutlass.const_expr(arr.element_type != cutlass.Float32) else cute.arch.fmax
        if cutlass.const_expr(ascending):
            # The right-hand side is fully evaluated before either store, so both
            # min_fn and max_fn see the original pair.
            arr[i], arr[j] = min_fn(arr[i], arr[j]), max_fn(arr[i], arr[j])
        else:
            arr[i], arr[j] = max_fn(arr[i], arr[j]), min_fn(arr[i], arr[j])