quack-kernels 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,326 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate optimized sorting network code from the optimal sorting network data.
4
+ Based on data from: https://bertdobbelaere.github.io/sorting_networks.html
5
+
6
+ This script generates CUTE DSL functions for optimal sorting networks of various sizes.
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import re
12
+ from typing import List, Tuple, Dict
13
+
14
+ # Network strings from bertdobbelaere.github.io/sorting_networks.html
15
+ # Copy-paste network strings here, then run initialize_networks() to parse them
16
+ NETWORK_STRINGS = {
17
+ # Size 2: 1 CE, depth 1
18
+ 2: """
19
+ [(0,1)]
20
+ """,
21
+ # Size 4: 5 CEs, depth 3
22
+ 4: """
23
+ [(0,2),(1,3)]
24
+ [(0,1),(2,3)]
25
+ [(1,2)]
26
+ """,
27
+ # Size 8: 19 CEs, depth 6
28
+ 8: """
29
+ [(0,2),(1,3),(4,6),(5,7)]
30
+ [(0,4),(1,5),(2,6),(3,7)]
31
+ [(0,1),(2,3),(4,5),(6,7)]
32
+ [(2,4),(3,5)]
33
+ [(1,4),(3,6)]
34
+ [(1,2),(3,4),(5,6)]
35
+ """,
36
+ # Size 16: 60 CEs, depth 10
37
+ 16: """
38
+ [(0,13),(1,12),(2,15),(3,14),(4,8),(5,6),(7,11),(9,10)]
39
+ [(0,5),(1,7),(2,9),(3,4),(6,13),(8,14),(10,15),(11,12)]
40
+ [(0,1),(2,3),(4,5),(6,8),(7,9),(10,11),(12,13),(14,15)]
41
+ [(0,2),(1,3),(4,10),(5,11),(6,7),(8,9),(12,14),(13,15)]
42
+ [(1,2),(3,12),(4,6),(5,7),(8,10),(9,11),(13,14)]
43
+ [(1,4),(2,6),(5,8),(7,10),(9,13),(11,14)]
44
+ [(2,4),(3,6),(9,12),(11,13)]
45
+ [(3,5),(6,8),(7,9),(10,12)]
46
+ [(3,4),(5,6),(7,8),(9,10),(11,12)]
47
+ [(6,7),(8,9)]
48
+ """,
49
+ # Size 32: 185 CEs, depth 14
50
+ 32: """
51
+ [(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31)]
52
+ [(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31)]
53
+ [(0,4),(1,5),(2,6),(3,7),(8,12),(9,13),(10,14),(11,15),(16,20),(17,21),(18,22),(19,23),(24,28),(25,29),(26,30),(27,31)]
54
+ [(0,8),(1,9),(2,10),(3,11),(4,12),(5,13),(6,14),(7,15),(16,24),(17,25),(18,26),(19,27),(20,28),(21,29),(22,30),(23,31)]
55
+ [(0,16),(1,8),(2,4),(3,12),(5,10),(6,9),(7,14),(11,13),(15,31),(17,24),(18,20),(19,28),(21,26),(22,25),(23,30),(27,29)]
56
+ [(1,2),(3,5),(4,8),(6,22),(7,11),(9,25),(10,12),(13,14),(17,18),(19,21),(20,24),(23,27),(26,28),(29,30)]
57
+ [(1,17),(2,18),(3,19),(4,20),(5,10),(7,23),(8,24),(11,27),(12,28),(13,29),(14,30),(21,26)]
58
+ [(3,17),(4,16),(5,21),(6,18),(7,9),(8,20),(10,26),(11,23),(13,25),(14,28),(15,27),(22,24)]
59
+ [(1,4),(3,8),(5,16),(7,17),(9,21),(10,22),(11,19),(12,20),(14,24),(15,26),(23,28),(27,30)]
60
+ [(2,5),(7,8),(9,18),(11,17),(12,16),(13,22),(14,20),(15,19),(23,24),(26,29)]
61
+ [(2,4),(6,12),(9,16),(10,11),(13,17),(14,18),(15,22),(19,25),(20,21),(27,29)]
62
+ [(5,6),(8,12),(9,10),(11,13),(14,16),(15,17),(18,20),(19,23),(21,22),(25,26)]
63
+ [(3,5),(6,7),(8,9),(10,12),(11,14),(13,16),(15,18),(17,20),(19,21),(22,23),(24,25),(26,28)]
64
+ [(3,4),(5,6),(7,8),(9,10),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28)]
65
+ """,
66
+ # Size 64: 521 CEs, depth 21
67
+ 64: """
68
+ [(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31),(32,34),(33,35),(36,38),(37,39),(40,42),(41,43),(44,46),(45,47),(48,50),(49,51),(52,54),(53,55),(56,58),(57,59),(60,62),(61,63)]
69
+ [(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31),(32,33),(34,35),(36,37),(38,39),(40,41),(42,43),(44,45),(46,47),(48,49),(50,51),(52,53),(54,55),(56,57),(58,59),(60,61),(62,63)]
70
+ [(0,52),(1,2),(3,55),(4,48),(5,6),(7,51),(8,60),(9,10),(11,63),(12,56),(13,14),(15,59),(16,32),(17,18),(19,35),(20,24),(21,22),(23,27),(25,26),(28,44),(29,30),(31,47),(33,34),(36,40),(37,38),(39,43),(41,42),(45,46),(49,50),(53,54),(57,58),(61,62)]
71
+ [(0,20),(1,53),(2,54),(3,23),(4,28),(5,49),(6,50),(7,31),(8,36),(9,61),(10,62),(11,39),(12,16),(13,57),(14,58),(15,19),(17,33),(18,34),(21,25),(22,26),(24,52),(27,55),(29,45),(30,46),(32,56),(35,59),(37,41),(38,42),(40,60),(43,63),(44,48),(47,51)]
72
+ [(0,4),(1,21),(2,22),(3,7),(5,29),(6,30),(8,12),(9,37),(10,38),(11,15),(13,17),(14,18),(16,20),(19,23),(24,32),(25,53),(26,54),(27,35),(28,36),(31,39),(33,57),(34,58),(40,44),(41,61),(42,62),(43,47),(45,49),(46,50),(48,52),(51,55),(56,60),(59,63)]
73
+ [(0,8),(1,5),(2,6),(3,11),(4,12),(7,15),(9,13),(10,14),(16,40),(17,21),(18,22),(19,43),(20,44),(23,47),(24,28),(25,33),(26,34),(27,31),(29,37),(30,38),(32,36),(35,39),(41,45),(42,46),(48,56),(49,53),(50,54),(51,59),(52,60),(55,63),(57,61),(58,62)]
74
+ [(1,9),(2,10),(4,8),(5,13),(6,14),(7,11),(12,48),(15,51),(16,24),(17,41),(18,42),(19,27),(20,28),(21,45),(22,46),(23,31),(25,29),(26,30),(32,40),(33,37),(34,38),(35,43),(36,44),(39,47),(49,57),(50,58),(52,56),(53,61),(54,62),(55,59)]
75
+ [(4,16),(5,9),(6,10),(7,19),(8,24),(11,27),(13,49),(14,50),(17,25),(18,26),(20,32),(21,29),(22,30),(23,35),(28,40),(31,43),(33,41),(34,42),(36,52),(37,45),(38,46),(39,55),(44,56),(47,59),(53,57),(54,58)]
76
+ [(1,4),(5,17),(6,18),(8,16),(9,25),(10,26),(11,19),(12,24),(15,27),(21,33),(22,34),(29,41),(30,42),(36,48),(37,53),(38,54),(39,51),(44,52),(45,57),(46,58),(47,55),(59,62)]
77
+ [(2,8),(9,17),(10,18),(12,20),(13,25),(14,26),(15,23),(24,32),(27,35),(28,36),(31,39),(37,49),(38,50),(40,48),(43,51),(45,53),(46,54),(55,61)]
78
+ [(2,4),(12,16),(13,21),(14,22),(15,19),(20,24),(23,27),(25,33),(26,34),(28,32),(29,37),(30,38),(31,35),(36,40),(39,43),(41,49),(42,50),(44,48),(47,51),(59,61)]
79
+ [(4,16),(5,20),(10,40),(13,17),(14,18),(21,25),(22,26),(23,53),(24,28),(27,31),(29,33),(30,34),(32,36),(35,39),(37,41),(38,42),(43,58),(45,49),(46,50),(47,59)]
80
+ [(3,17),(6,36),(7,21),(8,32),(9,24),(11,41),(13,28),(14,44),(15,45),(18,48),(19,49),(22,52),(25,29),(26,30),(27,57),(31,55),(33,37),(34,38),(35,50),(39,54),(42,56),(46,60)]
81
+ [(6,20),(8,16),(10,24),(11,25),(14,28),(15,29),(17,33),(18,32),(21,37),(22,36),(26,42),(27,41),(30,46),(31,45),(34,48),(35,49),(38,52),(39,53),(43,57),(47,55)]
82
+ [(3,18),(5,8),(6,12),(7,22),(15,21),(17,32),(19,33),(23,37),(26,40),(30,44),(31,46),(41,56),(42,48),(45,60),(51,57),(55,58)]
83
+ [(3,16),(7,20),(11,26),(18,24),(19,25),(22,28),(23,29),(27,33),(30,36),(34,40),(35,41),(37,52),(38,44),(39,45),(43,56),(47,60)]
84
+ [(3,9),(7,13),(10,16),(11,17),(14,20),(15,30),(19,34),(21,36),(23,38),(25,40),(26,32),(27,42),(29,44),(31,37),(33,48),(43,49),(46,52),(47,53),(50,56),(54,60)]
85
+ [(3,8),(7,10),(9,12),(11,18),(13,14),(15,24),(17,22),(19,28),(21,26),(23,25),(27,34),(29,36),(30,32),(31,33),(35,44),(37,42),(38,40),(39,48),(41,46),(45,52),(49,50),(51,54),(53,56),(55,60)]
86
+ [(3,6),(7,12),(11,16),(15,17),(18,20),(19,24),(21,22),(23,30),(25,32),(26,28),(27,29),(31,38),(33,40),(34,36),(35,37),(39,44),(41,42),(43,45),(46,48),(47,52),(51,56),(57,60)]
87
+ [(3,5),(6,8),(7,9),(10,12),(11,13),(14,16),(15,18),(17,20),(19,21),(22,24),(23,26),(25,28),(27,30),(29,32),(31,34),(33,36),(35,38),(37,40),(39,41),(42,44),(43,46),(45,48),(47,49),(50,52),(51,53),(54,56),(55,57),(58,60)]
88
+ [(3,4),(7,8),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28),(29,30),(31,32),(33,34),(35,36),(37,38),(39,40),(41,42),(43,44),(45,46),(47,48),(49,50),(51,52),(55,56),(59,60)]
89
+ """,
90
+ }
91
+
92
# Parsed networks keyed by network size. Each value is a
# (depth, total_comparisons, layers) tuple, where layers is a list of layers
# and each layer is a list of (i, j) index pairs.
# Filled by add_network_from_string(), which initialize_networks() calls for
# every entry of NETWORK_STRINGS.
OPTIMAL_NETWORKS: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]] = {}
94
+
95
+
96
def parse_network_string(network_str: str) -> List[List[Tuple[int, int]]]:
    """
    Parse a sorting network string from bertdobbelaere.github.io format.

    A network is a sequence of layers; each layer is a bracketed,
    comma-separated list of ``(i,j)`` comparator pairs. Only the bracketed
    groups are read, so layers may be separated by newlines, commas or
    spaces. Whitespace inside a layer (for example ``(0, 2)``) is accepted,
    so a layer is never silently skipped because of spacing.

    Args:
        network_str: Text containing zero or more bracketed layers.

    Returns:
        One list per layer, in order of appearance, each holding the
        ``(i, j)`` index pairs of that layer. An empty ``[]`` layer yields an
        empty list; blank input yields ``[]``.

    Examples:
        Input: "[(0,2),(1,3)], [(0,1),(2,3)], [(1,2)]"
        Output: [[(0, 2), (1, 3)], [(0, 1), (2, 3)], [(1, 2)]]

        Input: "[(0, 1)], [(1, 2)], [(0, 1)]"
        Output: [[(0, 1)], [(1, 2)], [(0, 1)]]
    """
    network_str = network_str.strip()
    if not network_str:
        return []

    # One comparator "(i,j)", tolerating whitespace around every token.
    pair = r"\(\s*\d+\s*,\s*\d+\s*\)"
    # One layer "[pair,pair,...]"; group 1 is the (possibly empty) pair list.
    layer_pattern = rf"\[\s*((?:{pair}(?:\s*,\s*{pair})*)?)\s*\]"
    # Same comparator shape, but capturing the two indices.
    comp_pattern = r"\(\s*(\d+)\s*,\s*(\d+)\s*\)"

    # An empty layer has an empty group 1, which produces an empty list here.
    return [
        [
            (int(comp.group(1)), int(comp.group(2)))
            for comp in re.finditer(comp_pattern, layer.group(1))
        ]
        for layer in re.finditer(layer_pattern, network_str)
    ]
133
+
134
+
135
def calculate_network_stats(layers: List[List[Tuple[int, int]]]) -> Tuple[int, int, int]:
    """Return ``(depth, total_comparisons, network_size)`` for a layered network.

    ``depth`` is the number of layers, ``total_comparisons`` the number of
    index pairs across all layers, and ``network_size`` the largest index
    seen plus one (indices are 0-based).
    """
    # Seeded with 0, so a network without comparators reports a size of 1.
    seen_indices = [0]
    comparator_count = 0

    for stage in layers:
        comparator_count += len(stage)
        for first, second in stage:
            seen_indices.extend((first, second))

    return len(layers), comparator_count, max(seen_indices) + 1
148
+
149
+
150
def add_network_from_string(size: int, network_str: str, description: str = ""):
    """
    Parse ``network_str`` and register it in OPTIMAL_NETWORKS under ``size``.

    The network is stored as ``(depth, comparisons, layers)`` only when the
    size derived from its largest index equals ``size``; otherwise a warning
    is printed and nothing is stored.

    Args:
        size: Size of the network (number of elements)
        network_str: Network string in bertdobbelaere.github.io format
        description: Optional description; when non-empty, a short summary is
            printed after the network has been stored

    Returns:
        True if the network was stored, False on a size mismatch or if any
        exception was raised while processing it.
    """
    try:
        parsed_layers = parse_network_string(network_str)
        n_layers, n_comparators, inferred_size = calculate_network_stats(parsed_layers)

        if inferred_size == size:
            OPTIMAL_NETWORKS[size] = (n_layers, n_comparators, parsed_layers)

            if description:
                print(f"Added network for size {size}: {description}")
                print(f" Depth: {n_layers}, Comparisons: {n_comparators}")
            return True

        # The highest index in the data disagrees with the declared size.
        print(f"Warning: Network size mismatch! Expected {size}, detected {inferred_size}")
        print(f"Network string: {network_str[:100]}...")
        return False

    except Exception as exc:
        # Best effort: report the failure and let the caller decide what to do.
        print(f"Error parsing network for size {size}: {exc}")
        print(f"Network string: {network_str[:100]}...")
        return False
179
+
180
+
181
def generate_networks_dict(
    networks_data: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]]
) -> str:
    """Render ``networks_data`` as the source text of a ``networks = {...}`` dict.

    Entries are emitted in ascending size order. Each entry is preceded by a
    comment giving its comparator count and depth and followed by a blank
    line. A single-layer network is written on one line; any other network
    gets one layer per line.
    """
    rendered = ["networks = {"]

    for size in sorted(networks_data):
        depth, num_comparisons, layers = networks_data[size]

        # The first layer has no leading separator; later ones start a new line.
        pieces = [
            f" {layer}" if position == 0 else f",\n {layer}"
            for position, layer in enumerate(layers)
        ]

        if len(layers) == 1:
            body = "[" + pieces[0].strip() + "]"
        else:
            body = "[\n" + "".join(pieces) + "\n ]"

        rendered.extend(
            [
                f" # Size {size}: {num_comparisons} CEs, depth {depth}",
                f" {size}: {body},",
                "",
            ]
        )

    rendered.append("}")
    return "\n".join(rendered)
207
+
208
+
209
def generate_optimal_sort_function() -> str:
    """Return the source text of the generated ``optimal_sort`` function.

    The returned text is a fixed template: a ``@cute.jit`` function that
    asserts ``n`` is a key of the generated ``networks`` dict and then calls
    ``compare_and_swap`` for every ``(i, j)`` pair of every layer, offset by
    ``start``. The template takes no parameters and ends with a newline.
    """
    return """@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    \"\"\"
    Optimal sorting network dispatcher.

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and available in networks)
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    \"\"\"
    assert n in networks
    for level in networks[n]:
        for i, j in level:
            compare_and_swap(arr, start + i, start + j, ascending)
"""
234
+
235
+
236
def generate_sorting_networks_file(max_size: int = 64):
    """Write ``sorting_networks.py`` next to this script and return the sizes emitted.

    The file consists of a fixed header, a ``networks`` dict and the
    ``optimal_sort`` function. Only networks from OPTIMAL_NETWORKS whose size
    lies in ``2..max_size`` are written, so the file content always matches
    the returned list. OPTIMAL_NETWORKS must already be populated (see
    initialize_networks()).

    Args:
        max_size: Largest network size to include (inclusive).

    Returns:
        The list of network sizes written, in ascending order.
    """
    output_file = os.path.join(os.path.dirname(__file__), "sorting_networks.py")

    # Header of the generated module; it ends with two blank lines so the
    # ``networks`` dict follows the imports with standard spacing.
    header = (
        "# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.\n"
        '"""\n'
        "Optimal sorting networks generated from: "
        "https://bertdobbelaere.github.io/sorting_networks.html\n"
        "\n"
        "This file was auto-generated by quack/sort/generate_sorting_networks.py. "
        "Do not edit it directly.\n"
        '"""\n'
        "\n"
        "# fmt: off\n"
        "# ruff: noqa\n"
        "# isort: skip_file\n"
        "\n"
        "import cutlass\n"
        "import cutlass.cute as cute\n"
        "\n"
        "from quack.sort.utils import compare_and_swap\n"
        "\n"
        "\n"
    )

    # Restrict the emitted dict to the requested sizes; passing the whole
    # OPTIMAL_NETWORKS here would ignore max_size in the written file.
    sizes = [n for n in range(2, max_size + 1) if n in OPTIMAL_NETWORKS]
    selected_networks = {n: OPTIMAL_NETWORKS[n] for n in sizes}
    networks_dict = generate_networks_dict(selected_networks)
    optimal_sort_func = generate_optimal_sort_function()

    content = header + networks_dict + "\n\n\n" + optimal_sort_func

    # Explicit encoding keeps the output independent of the platform default.
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(content)

    print(f"Generated optimal sorting networks for sizes {sizes}")
    print(f"Output written to: {output_file}")
    return sizes
275
+
276
+
277
def initialize_networks():
    """Rebuild OPTIMAL_NETWORKS from scratch using every entry of NETWORK_STRINGS.

    Existing entries are discarded first. A warning is printed for each size
    whose network string could not be added.
    """
    # The dict is mutated in place, so the module-level name is never rebound.
    OPTIMAL_NETWORKS.clear()

    for size in NETWORK_STRINGS:
        added = add_network_from_string(
            size, NETWORK_STRINGS[size], f"Size {size} optimal network"
        )
        if added:
            continue
        print(f"Warning: Failed to parse network for size {size}")
286
+
287
+
288
def main():
    """Command-line entry point: parse options, build the networks, write the file.

    Options:
        --max-size / -m: largest network size to generate (default 64).
        --stats / -s: print a table of depth and comparison counts for the
            networks up to the maximum size before generating the file.
    """
    parser = argparse.ArgumentParser(
        description="Generate optimal sorting network code from bertdobbelaere.github.io data"
    )
    parser.add_argument(
        "--max-size",
        "-m",
        type=int,
        default=64,
        # %(default)s keeps the help text in sync with the actual default.
        help="Maximum sorting network size to generate (default: %(default)s)",
    )
    parser.add_argument(
        "--stats", "-s", action="store_true", help="Print statistics about the optimal networks"
    )

    args = parser.parse_args()

    # Initialize networks from strings
    initialize_networks()

    if args.stats:
        print("Optimal Sorting Network Statistics:")
        print("Size\tDepth\tComparisons\tLayers")
        print("-" * 35)
        for n in sorted(OPTIMAL_NETWORKS):
            if n <= args.max_size:
                depth, comparisons, layers = OPTIMAL_NETWORKS[n]
                print(f"{n}\t{depth}\t{comparisons}\t\t{len(layers)}")

    # Generate the sorting networks file
    sizes = generate_sorting_networks_file(args.max_size)

    print(f"\nGenerated optimal sorting networks for {len(sizes)} sizes")
    print(f"Total networks: {len(sizes)}")
    if sizes:
        print(f"Max network size: {max(sizes)}")
    else:
        # max() of an empty list raises ValueError; report the situation instead.
        print(f"No networks available for max size {args.max_size}")
323
+
324
+
325
+ if __name__ == "__main__":
326
+ main()
@@ -0,0 +1,120 @@
1
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
2
+ """
3
+ Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html
4
+
5
+ This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
6
+ """
7
+
8
+ # fmt: off
9
+ # ruff: noqa
10
+ # isort: skip_file
11
+
12
+ import cutlass
13
+ import cutlass.cute as cute
14
+
15
+ from quack.sort.utils import compare_and_swap
16
+
17
+
18
+ networks = {
19
+ # Size 2: 1 CEs, depth 1
20
+ 2: [[(0, 1)]],
21
+
22
+ # Size 4: 5 CEs, depth 3
23
+ 4: [
24
+ [(0, 2), (1, 3)],
25
+ [(0, 1), (2, 3)],
26
+ [(1, 2)]
27
+ ],
28
+
29
+ # Size 8: 19 CEs, depth 6
30
+ 8: [
31
+ [(0, 2), (1, 3), (4, 6), (5, 7)],
32
+ [(0, 4), (1, 5), (2, 6), (3, 7)],
33
+ [(0, 1), (2, 3), (4, 5), (6, 7)],
34
+ [(2, 4), (3, 5)],
35
+ [(1, 4), (3, 6)],
36
+ [(1, 2), (3, 4), (5, 6)]
37
+ ],
38
+
39
+ # Size 16: 60 CEs, depth 10
40
+ 16: [
41
+ [(0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10)],
42
+ [(0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12)],
43
+ [(0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15)],
44
+ [(0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15)],
45
+ [(1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14)],
46
+ [(1, 4), (2, 6), (5, 8), (7, 10), (9, 13), (11, 14)],
47
+ [(2, 4), (3, 6), (9, 12), (11, 13)],
48
+ [(3, 5), (6, 8), (7, 9), (10, 12)],
49
+ [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)],
50
+ [(6, 7), (8, 9)]
51
+ ],
52
+
53
+ # Size 32: 185 CEs, depth 14
54
+ 32: [
55
+ [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31)],
56
+ [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31)],
57
+ [(0, 4), (1, 5), (2, 6), (3, 7), (8, 12), (9, 13), (10, 14), (11, 15), (16, 20), (17, 21), (18, 22), (19, 23), (24, 28), (25, 29), (26, 30), (27, 31)],
58
+ [(0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (16, 24), (17, 25), (18, 26), (19, 27), (20, 28), (21, 29), (22, 30), (23, 31)],
59
+ [(0, 16), (1, 8), (2, 4), (3, 12), (5, 10), (6, 9), (7, 14), (11, 13), (15, 31), (17, 24), (18, 20), (19, 28), (21, 26), (22, 25), (23, 30), (27, 29)],
60
+ [(1, 2), (3, 5), (4, 8), (6, 22), (7, 11), (9, 25), (10, 12), (13, 14), (17, 18), (19, 21), (20, 24), (23, 27), (26, 28), (29, 30)],
61
+ [(1, 17), (2, 18), (3, 19), (4, 20), (5, 10), (7, 23), (8, 24), (11, 27), (12, 28), (13, 29), (14, 30), (21, 26)],
62
+ [(3, 17), (4, 16), (5, 21), (6, 18), (7, 9), (8, 20), (10, 26), (11, 23), (13, 25), (14, 28), (15, 27), (22, 24)],
63
+ [(1, 4), (3, 8), (5, 16), (7, 17), (9, 21), (10, 22), (11, 19), (12, 20), (14, 24), (15, 26), (23, 28), (27, 30)],
64
+ [(2, 5), (7, 8), (9, 18), (11, 17), (12, 16), (13, 22), (14, 20), (15, 19), (23, 24), (26, 29)],
65
+ [(2, 4), (6, 12), (9, 16), (10, 11), (13, 17), (14, 18), (15, 22), (19, 25), (20, 21), (27, 29)],
66
+ [(5, 6), (8, 12), (9, 10), (11, 13), (14, 16), (15, 17), (18, 20), (19, 23), (21, 22), (25, 26)],
67
+ [(3, 5), (6, 7), (8, 9), (10, 12), (11, 14), (13, 16), (15, 18), (17, 20), (19, 21), (22, 23), (24, 25), (26, 28)],
68
+ [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28)]
69
+ ],
70
+
71
+ # Size 64: 521 CEs, depth 21
72
+ 64: [
73
+ [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31), (32, 34), (33, 35), (36, 38), (37, 39), (40, 42), (41, 43), (44, 46), (45, 47), (48, 50), (49, 51), (52, 54), (53, 55), (56, 58), (57, 59), (60, 62), (61, 63)],
74
+ [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31), (32, 33), (34, 35), (36, 37), (38, 39), (40, 41), (42, 43), (44, 45), (46, 47), (48, 49), (50, 51), (52, 53), (54, 55), (56, 57), (58, 59), (60, 61), (62, 63)],
75
+ [(0, 52), (1, 2), (3, 55), (4, 48), (5, 6), (7, 51), (8, 60), (9, 10), (11, 63), (12, 56), (13, 14), (15, 59), (16, 32), (17, 18), (19, 35), (20, 24), (21, 22), (23, 27), (25, 26), (28, 44), (29, 30), (31, 47), (33, 34), (36, 40), (37, 38), (39, 43), (41, 42), (45, 46), (49, 50), (53, 54), (57, 58), (61, 62)],
76
+ [(0, 20), (1, 53), (2, 54), (3, 23), (4, 28), (5, 49), (6, 50), (7, 31), (8, 36), (9, 61), (10, 62), (11, 39), (12, 16), (13, 57), (14, 58), (15, 19), (17, 33), (18, 34), (21, 25), (22, 26), (24, 52), (27, 55), (29, 45), (30, 46), (32, 56), (35, 59), (37, 41), (38, 42), (40, 60), (43, 63), (44, 48), (47, 51)],
77
+ [(0, 4), (1, 21), (2, 22), (3, 7), (5, 29), (6, 30), (8, 12), (9, 37), (10, 38), (11, 15), (13, 17), (14, 18), (16, 20), (19, 23), (24, 32), (25, 53), (26, 54), (27, 35), (28, 36), (31, 39), (33, 57), (34, 58), (40, 44), (41, 61), (42, 62), (43, 47), (45, 49), (46, 50), (48, 52), (51, 55), (56, 60), (59, 63)],
78
+ [(0, 8), (1, 5), (2, 6), (3, 11), (4, 12), (7, 15), (9, 13), (10, 14), (16, 40), (17, 21), (18, 22), (19, 43), (20, 44), (23, 47), (24, 28), (25, 33), (26, 34), (27, 31), (29, 37), (30, 38), (32, 36), (35, 39), (41, 45), (42, 46), (48, 56), (49, 53), (50, 54), (51, 59), (52, 60), (55, 63), (57, 61), (58, 62)],
79
+ [(1, 9), (2, 10), (4, 8), (5, 13), (6, 14), (7, 11), (12, 48), (15, 51), (16, 24), (17, 41), (18, 42), (19, 27), (20, 28), (21, 45), (22, 46), (23, 31), (25, 29), (26, 30), (32, 40), (33, 37), (34, 38), (35, 43), (36, 44), (39, 47), (49, 57), (50, 58), (52, 56), (53, 61), (54, 62), (55, 59)],
80
+ [(4, 16), (5, 9), (6, 10), (7, 19), (8, 24), (11, 27), (13, 49), (14, 50), (17, 25), (18, 26), (20, 32), (21, 29), (22, 30), (23, 35), (28, 40), (31, 43), (33, 41), (34, 42), (36, 52), (37, 45), (38, 46), (39, 55), (44, 56), (47, 59), (53, 57), (54, 58)],
81
+ [(1, 4), (5, 17), (6, 18), (8, 16), (9, 25), (10, 26), (11, 19), (12, 24), (15, 27), (21, 33), (22, 34), (29, 41), (30, 42), (36, 48), (37, 53), (38, 54), (39, 51), (44, 52), (45, 57), (46, 58), (47, 55), (59, 62)],
82
+ [(2, 8), (9, 17), (10, 18), (12, 20), (13, 25), (14, 26), (15, 23), (24, 32), (27, 35), (28, 36), (31, 39), (37, 49), (38, 50), (40, 48), (43, 51), (45, 53), (46, 54), (55, 61)],
83
+ [(2, 4), (12, 16), (13, 21), (14, 22), (15, 19), (20, 24), (23, 27), (25, 33), (26, 34), (28, 32), (29, 37), (30, 38), (31, 35), (36, 40), (39, 43), (41, 49), (42, 50), (44, 48), (47, 51), (59, 61)],
84
+ [(4, 16), (5, 20), (10, 40), (13, 17), (14, 18), (21, 25), (22, 26), (23, 53), (24, 28), (27, 31), (29, 33), (30, 34), (32, 36), (35, 39), (37, 41), (38, 42), (43, 58), (45, 49), (46, 50), (47, 59)],
85
+ [(3, 17), (6, 36), (7, 21), (8, 32), (9, 24), (11, 41), (13, 28), (14, 44), (15, 45), (18, 48), (19, 49), (22, 52), (25, 29), (26, 30), (27, 57), (31, 55), (33, 37), (34, 38), (35, 50), (39, 54), (42, 56), (46, 60)],
86
+ [(6, 20), (8, 16), (10, 24), (11, 25), (14, 28), (15, 29), (17, 33), (18, 32), (21, 37), (22, 36), (26, 42), (27, 41), (30, 46), (31, 45), (34, 48), (35, 49), (38, 52), (39, 53), (43, 57), (47, 55)],
87
+ [(3, 18), (5, 8), (6, 12), (7, 22), (15, 21), (17, 32), (19, 33), (23, 37), (26, 40), (30, 44), (31, 46), (41, 56), (42, 48), (45, 60), (51, 57), (55, 58)],
88
+ [(3, 16), (7, 20), (11, 26), (18, 24), (19, 25), (22, 28), (23, 29), (27, 33), (30, 36), (34, 40), (35, 41), (37, 52), (38, 44), (39, 45), (43, 56), (47, 60)],
89
+ [(3, 9), (7, 13), (10, 16), (11, 17), (14, 20), (15, 30), (19, 34), (21, 36), (23, 38), (25, 40), (26, 32), (27, 42), (29, 44), (31, 37), (33, 48), (43, 49), (46, 52), (47, 53), (50, 56), (54, 60)],
90
+ [(3, 8), (7, 10), (9, 12), (11, 18), (13, 14), (15, 24), (17, 22), (19, 28), (21, 26), (23, 25), (27, 34), (29, 36), (30, 32), (31, 33), (35, 44), (37, 42), (38, 40), (39, 48), (41, 46), (45, 52), (49, 50), (51, 54), (53, 56), (55, 60)],
91
+ [(3, 6), (7, 12), (11, 16), (15, 17), (18, 20), (19, 24), (21, 22), (23, 30), (25, 32), (26, 28), (27, 29), (31, 38), (33, 40), (34, 36), (35, 37), (39, 44), (41, 42), (43, 45), (46, 48), (47, 52), (51, 56), (57, 60)],
92
+ [(3, 5), (6, 8), (7, 9), (10, 12), (11, 13), (14, 16), (15, 18), (17, 20), (19, 21), (22, 24), (23, 26), (25, 28), (27, 30), (29, 32), (31, 34), (33, 36), (35, 38), (37, 40), (39, 41), (42, 44), (43, 46), (45, 48), (47, 49), (50, 52), (51, 53), (54, 56), (55, 57), (58, 60)],
93
+ [(3, 4), (7, 8), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28), (29, 30), (31, 32), (33, 34), (35, 36), (37, 38), (39, 40), (41, 42), (43, 44), (45, 46), (47, 48), (49, 50), (51, 52), (55, 56), (59, 60)]
94
+ ],
95
+
96
+ }
97
+
98
+
99
@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    """
    Optimal sorting network dispatcher.

    Looks up the network for ``n`` in ``networks`` and applies each of its
    comparators to ``arr`` in layer order, with every index offset by
    ``start``.

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and available in networks)
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    """
    # Only sizes that have an entry in `networks` are supported.
    assert n in networks
    # Layers are applied in the order listed; each (i, j) pair is one
    # compare-and-swap on elements start + i and start + j.
    for level in networks[n]:
        for i, j in level:
            compare_and_swap(arr, start + i, start + j, ascending)
quack/sort/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ import cutlass
2
+ import cutlass.cute as cute
3
+
4
+ import quack.utils as utils
5
+
6
+
7
@cute.jit
def compare_and_swap(
    arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False
) -> None:
    """Compare and swap elements at indices i and j in ascending or descending order.

    Args:
        arr: Tensor whose elements at ``i`` and ``j`` are compared and possibly
            exchanged in place.
        i: Index that receives the smaller value when ``ascending`` is true,
            the larger one otherwise.
        j: Index that receives the other value.
        ascending: Ordering direction (default True).
        use_selection: Selects between the two implementations below; it is
            evaluated through ``cutlass.const_expr``. When true, the elements
            are exchanged under a comparison; when false (default), the
            min/max results are written back unconditionally.
    """
    if cutlass.const_expr(use_selection):
        a, b = arr[i], arr[j]
        # Ascending: swap when a > b. Descending: the XOR inverts the test,
        # so the swap happens when a > b is false.
        if (a > b) ^ (not ascending):
            arr[i] = b
            arr[j] = a
        # Alternative formulation, left disabled:
        # if cutlass.const_expr(ascending):
        #     if a > b:
        #         arr[i] = b
        #         arr[j] = a
        # else:
        #     if a < b:
        #         arr[i] = b
        #         arr[j] = a
    else:
        # Float32 elements use utils.fmin / cute.arch.fmax; every other
        # element type uses the builtin min / max.
        # NOTE(review): presumably the Float32 variants map to dedicated
        # floating-point min/max operations - confirm in quack.utils and cute.arch.
        min_fn = min if cutlass.const_expr(arr.element_type != cutlass.Float32) else utils.fmin
        max_fn = max if cutlass.const_expr(arr.element_type != cutlass.Float32) else cute.arch.fmax
        # Both right-hand values are computed from the old elements before
        # either slot is overwritten (tuple assignment).
        if cutlass.const_expr(ascending):
            arr[i], arr[j] = min_fn(arr[i], arr[j]), max_fn(arr[i], arr[j])
        else:
            arr[i], arr[j] = max_fn(arr[i], arr[j]), min_fn(arr[i], arr[j])