quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +4 -1
- quack/autotuner.py +309 -0
- quack/cross_entropy.py +2 -5
- quack/cute_dsl_utils.py +40 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +2474 -0
- quack/fast_math.py +97 -0
- quack/gemm_config.py +61 -0
- quack/gemm_interface.py +321 -0
- quack/linear.py +176 -0
- quack/lse.py +62 -0
- quack/mlp.py +204 -0
- quack/pipeline.py +166 -0
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2088 -0
- quack/tensormap_manager.py +114 -0
- quack/tile_scheduler.py +935 -0
- quack/topk.py +221 -0
- quack/utils.py +237 -19
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/METADATA +3 -3
- quack_kernels-0.1.11.dist-info/RECORD +31 -0
- quack_kernels-0.1.9.dist-info/RECORD +0 -12
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Generate optimized sorting network code from the optimal sorting network data.
|
|
4
|
+
Based on data from: https://bertdobbelaere.github.io/sorting_networks.html
|
|
5
|
+
|
|
6
|
+
This script generates CUTE DSL functions for optimal sorting networks of various sizes.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import argparse
|
|
10
|
+
import os
|
|
11
|
+
import re
|
|
12
|
+
from typing import List, Tuple, Dict
|
|
13
|
+
|
|
14
|
+
# Network strings from bertdobbelaere.github.io/sorting_networks.html
|
|
15
|
+
# Copy-paste network strings here, then run initialize_networks() to parse them
|
|
16
|
+
NETWORK_STRINGS = {
|
|
17
|
+
# Size 2: 1 CE, depth 1
|
|
18
|
+
2: """
|
|
19
|
+
[(0,1)]
|
|
20
|
+
""",
|
|
21
|
+
# Size 4: 5 CEs, depth 3
|
|
22
|
+
4: """
|
|
23
|
+
[(0,2),(1,3)]
|
|
24
|
+
[(0,1),(2,3)]
|
|
25
|
+
[(1,2)]
|
|
26
|
+
""",
|
|
27
|
+
# Size 8: 19 CEs, depth 6
|
|
28
|
+
8: """
|
|
29
|
+
[(0,2),(1,3),(4,6),(5,7)]
|
|
30
|
+
[(0,4),(1,5),(2,6),(3,7)]
|
|
31
|
+
[(0,1),(2,3),(4,5),(6,7)]
|
|
32
|
+
[(2,4),(3,5)]
|
|
33
|
+
[(1,4),(3,6)]
|
|
34
|
+
[(1,2),(3,4),(5,6)]
|
|
35
|
+
""",
|
|
36
|
+
# Size 16: 60 CEs, depth 10
|
|
37
|
+
16: """
|
|
38
|
+
[(0,13),(1,12),(2,15),(3,14),(4,8),(5,6),(7,11),(9,10)]
|
|
39
|
+
[(0,5),(1,7),(2,9),(3,4),(6,13),(8,14),(10,15),(11,12)]
|
|
40
|
+
[(0,1),(2,3),(4,5),(6,8),(7,9),(10,11),(12,13),(14,15)]
|
|
41
|
+
[(0,2),(1,3),(4,10),(5,11),(6,7),(8,9),(12,14),(13,15)]
|
|
42
|
+
[(1,2),(3,12),(4,6),(5,7),(8,10),(9,11),(13,14)]
|
|
43
|
+
[(1,4),(2,6),(5,8),(7,10),(9,13),(11,14)]
|
|
44
|
+
[(2,4),(3,6),(9,12),(11,13)]
|
|
45
|
+
[(3,5),(6,8),(7,9),(10,12)]
|
|
46
|
+
[(3,4),(5,6),(7,8),(9,10),(11,12)]
|
|
47
|
+
[(6,7),(8,9)]
|
|
48
|
+
""",
|
|
49
|
+
# Size 32: 185 CEs, depth 14
|
|
50
|
+
32: """
|
|
51
|
+
[(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31)]
|
|
52
|
+
[(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31)]
|
|
53
|
+
[(0,4),(1,5),(2,6),(3,7),(8,12),(9,13),(10,14),(11,15),(16,20),(17,21),(18,22),(19,23),(24,28),(25,29),(26,30),(27,31)]
|
|
54
|
+
[(0,8),(1,9),(2,10),(3,11),(4,12),(5,13),(6,14),(7,15),(16,24),(17,25),(18,26),(19,27),(20,28),(21,29),(22,30),(23,31)]
|
|
55
|
+
[(0,16),(1,8),(2,4),(3,12),(5,10),(6,9),(7,14),(11,13),(15,31),(17,24),(18,20),(19,28),(21,26),(22,25),(23,30),(27,29)]
|
|
56
|
+
[(1,2),(3,5),(4,8),(6,22),(7,11),(9,25),(10,12),(13,14),(17,18),(19,21),(20,24),(23,27),(26,28),(29,30)]
|
|
57
|
+
[(1,17),(2,18),(3,19),(4,20),(5,10),(7,23),(8,24),(11,27),(12,28),(13,29),(14,30),(21,26)]
|
|
58
|
+
[(3,17),(4,16),(5,21),(6,18),(7,9),(8,20),(10,26),(11,23),(13,25),(14,28),(15,27),(22,24)]
|
|
59
|
+
[(1,4),(3,8),(5,16),(7,17),(9,21),(10,22),(11,19),(12,20),(14,24),(15,26),(23,28),(27,30)]
|
|
60
|
+
[(2,5),(7,8),(9,18),(11,17),(12,16),(13,22),(14,20),(15,19),(23,24),(26,29)]
|
|
61
|
+
[(2,4),(6,12),(9,16),(10,11),(13,17),(14,18),(15,22),(19,25),(20,21),(27,29)]
|
|
62
|
+
[(5,6),(8,12),(9,10),(11,13),(14,16),(15,17),(18,20),(19,23),(21,22),(25,26)]
|
|
63
|
+
[(3,5),(6,7),(8,9),(10,12),(11,14),(13,16),(15,18),(17,20),(19,21),(22,23),(24,25),(26,28)]
|
|
64
|
+
[(3,4),(5,6),(7,8),(9,10),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28)]
|
|
65
|
+
""",
|
|
66
|
+
# Size 64: 521 CEs, depth 21
|
|
67
|
+
64: """
|
|
68
|
+
[(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31),(32,34),(33,35),(36,38),(37,39),(40,42),(41,43),(44,46),(45,47),(48,50),(49,51),(52,54),(53,55),(56,58),(57,59),(60,62),(61,63)]
|
|
69
|
+
[(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31),(32,33),(34,35),(36,37),(38,39),(40,41),(42,43),(44,45),(46,47),(48,49),(50,51),(52,53),(54,55),(56,57),(58,59),(60,61),(62,63)]
|
|
70
|
+
[(0,52),(1,2),(3,55),(4,48),(5,6),(7,51),(8,60),(9,10),(11,63),(12,56),(13,14),(15,59),(16,32),(17,18),(19,35),(20,24),(21,22),(23,27),(25,26),(28,44),(29,30),(31,47),(33,34),(36,40),(37,38),(39,43),(41,42),(45,46),(49,50),(53,54),(57,58),(61,62)]
|
|
71
|
+
[(0,20),(1,53),(2,54),(3,23),(4,28),(5,49),(6,50),(7,31),(8,36),(9,61),(10,62),(11,39),(12,16),(13,57),(14,58),(15,19),(17,33),(18,34),(21,25),(22,26),(24,52),(27,55),(29,45),(30,46),(32,56),(35,59),(37,41),(38,42),(40,60),(43,63),(44,48),(47,51)]
|
|
72
|
+
[(0,4),(1,21),(2,22),(3,7),(5,29),(6,30),(8,12),(9,37),(10,38),(11,15),(13,17),(14,18),(16,20),(19,23),(24,32),(25,53),(26,54),(27,35),(28,36),(31,39),(33,57),(34,58),(40,44),(41,61),(42,62),(43,47),(45,49),(46,50),(48,52),(51,55),(56,60),(59,63)]
|
|
73
|
+
[(0,8),(1,5),(2,6),(3,11),(4,12),(7,15),(9,13),(10,14),(16,40),(17,21),(18,22),(19,43),(20,44),(23,47),(24,28),(25,33),(26,34),(27,31),(29,37),(30,38),(32,36),(35,39),(41,45),(42,46),(48,56),(49,53),(50,54),(51,59),(52,60),(55,63),(57,61),(58,62)]
|
|
74
|
+
[(1,9),(2,10),(4,8),(5,13),(6,14),(7,11),(12,48),(15,51),(16,24),(17,41),(18,42),(19,27),(20,28),(21,45),(22,46),(23,31),(25,29),(26,30),(32,40),(33,37),(34,38),(35,43),(36,44),(39,47),(49,57),(50,58),(52,56),(53,61),(54,62),(55,59)]
|
|
75
|
+
[(4,16),(5,9),(6,10),(7,19),(8,24),(11,27),(13,49),(14,50),(17,25),(18,26),(20,32),(21,29),(22,30),(23,35),(28,40),(31,43),(33,41),(34,42),(36,52),(37,45),(38,46),(39,55),(44,56),(47,59),(53,57),(54,58)]
|
|
76
|
+
[(1,4),(5,17),(6,18),(8,16),(9,25),(10,26),(11,19),(12,24),(15,27),(21,33),(22,34),(29,41),(30,42),(36,48),(37,53),(38,54),(39,51),(44,52),(45,57),(46,58),(47,55),(59,62)]
|
|
77
|
+
[(2,8),(9,17),(10,18),(12,20),(13,25),(14,26),(15,23),(24,32),(27,35),(28,36),(31,39),(37,49),(38,50),(40,48),(43,51),(45,53),(46,54),(55,61)]
|
|
78
|
+
[(2,4),(12,16),(13,21),(14,22),(15,19),(20,24),(23,27),(25,33),(26,34),(28,32),(29,37),(30,38),(31,35),(36,40),(39,43),(41,49),(42,50),(44,48),(47,51),(59,61)]
|
|
79
|
+
[(4,16),(5,20),(10,40),(13,17),(14,18),(21,25),(22,26),(23,53),(24,28),(27,31),(29,33),(30,34),(32,36),(35,39),(37,41),(38,42),(43,58),(45,49),(46,50),(47,59)]
|
|
80
|
+
[(3,17),(6,36),(7,21),(8,32),(9,24),(11,41),(13,28),(14,44),(15,45),(18,48),(19,49),(22,52),(25,29),(26,30),(27,57),(31,55),(33,37),(34,38),(35,50),(39,54),(42,56),(46,60)]
|
|
81
|
+
[(6,20),(8,16),(10,24),(11,25),(14,28),(15,29),(17,33),(18,32),(21,37),(22,36),(26,42),(27,41),(30,46),(31,45),(34,48),(35,49),(38,52),(39,53),(43,57),(47,55)]
|
|
82
|
+
[(3,18),(5,8),(6,12),(7,22),(15,21),(17,32),(19,33),(23,37),(26,40),(30,44),(31,46),(41,56),(42,48),(45,60),(51,57),(55,58)]
|
|
83
|
+
[(3,16),(7,20),(11,26),(18,24),(19,25),(22,28),(23,29),(27,33),(30,36),(34,40),(35,41),(37,52),(38,44),(39,45),(43,56),(47,60)]
|
|
84
|
+
[(3,9),(7,13),(10,16),(11,17),(14,20),(15,30),(19,34),(21,36),(23,38),(25,40),(26,32),(27,42),(29,44),(31,37),(33,48),(43,49),(46,52),(47,53),(50,56),(54,60)]
|
|
85
|
+
[(3,8),(7,10),(9,12),(11,18),(13,14),(15,24),(17,22),(19,28),(21,26),(23,25),(27,34),(29,36),(30,32),(31,33),(35,44),(37,42),(38,40),(39,48),(41,46),(45,52),(49,50),(51,54),(53,56),(55,60)]
|
|
86
|
+
[(3,6),(7,12),(11,16),(15,17),(18,20),(19,24),(21,22),(23,30),(25,32),(26,28),(27,29),(31,38),(33,40),(34,36),(35,37),(39,44),(41,42),(43,45),(46,48),(47,52),(51,56),(57,60)]
|
|
87
|
+
[(3,5),(6,8),(7,9),(10,12),(11,13),(14,16),(15,18),(17,20),(19,21),(22,24),(23,26),(25,28),(27,30),(29,32),(31,34),(33,36),(35,38),(37,40),(39,41),(42,44),(43,46),(45,48),(47,49),(50,52),(51,53),(54,56),(55,57),(58,60)]
|
|
88
|
+
[(3,4),(7,8),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28),(29,30),(31,32),(33,34),(35,36),(37,38),(39,40),(41,42),(43,44),(45,46),(47,48),(49,50),(51,52),(55,56),(59,60)]
|
|
89
|
+
""",
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
# This will be populated by initialize_networks()
|
|
93
|
+
OPTIMAL_NETWORKS: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]] = {}
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def parse_network_string(network_str: str) -> List[List[Tuple[int, int]]]:
    """
    Parse a sorting network string from bertdobbelaere.github.io format.

    Each bracketed group ``[...]`` is one layer; each ``(i,j)`` inside it is
    one compare-exchange between wires ``i`` and ``j``. Layers may be
    separated by commas, spaces or newlines, and whitespace inside a layer
    (e.g. ``(0, 1)``) is accepted. Text that does not form a well-formed
    layer is ignored.

    Examples:
        Input: "[(0,2),(1,3)], [(0,1),(2,3)], [(1,2)]"
        Output: [[(0, 2), (1, 3)], [(0, 1), (2, 3)], [(1, 2)]]

        Input: "[(0,1)], [(1,2)], [(0,1)]"
        Output: [[(0, 1)], [(1, 2)], [(0, 1)]]

    Args:
        network_str: Network description, one or more ``[...]`` layers.

    Returns:
        One list per layer, in order of appearance, each holding the
        ``(i, j)`` pairs of that layer. An empty or all-whitespace input
        yields ``[]``; an empty layer ``[]`` yields an empty inner list.
    """
    network_str = network_str.strip()
    if not network_str:
        return []

    # A single "(i,j)" comparator; optional whitespace is allowed around the
    # numbers so hand-edited or pretty-printed input is not silently dropped.
    pair = r"\(\s*\d+\s*,\s*\d+\s*\)"
    # A layer is a bracketed, comma-separated (possibly empty) run of pairs.
    layer_pattern = re.compile(rf"\[\s*((?:{pair}(?:\s*,\s*{pair})*)?)\s*\]")
    comp_pattern = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")

    return [
        [(int(i), int(j)) for i, j in comp_pattern.findall(layer_match.group(1))]
        for layer_match in layer_pattern.finditer(network_str)
    ]
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def calculate_network_stats(layers: List[List[Tuple[int, int]]]) -> Tuple[int, int, int]:
    """Return ``(depth, total_comparisons, network_size)`` for a layered network.

    ``depth`` is the number of layers, ``total_comparisons`` the number of
    pairs over all layers, and ``network_size`` the largest wire index seen
    plus one (indices are 0-based). With no comparators the largest index
    defaults to 0, so the reported size is 1.
    """
    # Seed with 0 so the result matches a running maximum that starts at 0.
    wire_indices = [0] + [idx for layer in layers for i, j in layer for idx in (i, j)]
    network_size = max(wire_indices) + 1

    total_comparisons = sum(map(len, layers))
    depth = len(layers)
    return depth, total_comparisons, network_size
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def add_network_from_string(size: int, network_str: str, description: str = ""):
    """
    Parse ``network_str`` and register it in OPTIMAL_NETWORKS under ``size``.

    The network is stored as ``(depth, comparisons, layers)`` only when the
    size detected from its wire indices equals ``size``.

    Args:
        size: Size of the network (number of elements)
        network_str: Network string in bertdobbelaere.github.io format
        description: Optional description; when non-empty, a short summary
            is printed after a successful registration

    Returns:
        True if the network was stored, False on a size mismatch or if
        parsing raised (a diagnostic is printed in both failure cases).
    """
    try:
        parsed_layers = parse_network_string(network_str)
        depth, comparisons, detected_size = calculate_network_stats(parsed_layers)

        if detected_size == size:
            OPTIMAL_NETWORKS[size] = (depth, comparisons, parsed_layers)
            if description:
                print(f"Added network for size {size}: {description}")
                print(f" Depth: {depth}, Comparisons: {comparisons}")
            return True

        # The declared size disagrees with the wire indices actually used.
        print(f"Warning: Network size mismatch! Expected {size}, detected {detected_size}")
        print(f"Network string: {network_str[:100]}...")
        return False

    except Exception as e:
        # Best-effort: report the failure and let the caller carry on.
        print(f"Error parsing network for size {size}: {e}")
        print(f"Network string: {network_str[:100]}...")
        return False
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def generate_networks_dict(
    networks_data: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]]
) -> str:
    """Generate the source text of the global ``networks`` dictionary.

    Args:
        networks_data: Maps network size to ``(depth, num_comparisons, layers)``.

    Returns:
        Python source for ``networks = {...}``. Entries are emitted in
        ascending size order, each preceded by a ``# Size ...`` comment and
        followed by a blank line. A single-layer network is written on one
        line; otherwise each layer gets its own line.
    """
    # NOTE(review): whitespace inside the literals appears collapsed in the
    # reviewed text; a conventional 4-space entry / 8-space layer layout is
    # assumed here — confirm against the upstream generator.
    entry_indent = " " * 4
    layer_indent = " " * 8

    lines = ["networks = {"]

    for size, (depth, num_comparisons, layers) in sorted(networks_data.items()):
        if len(layers) == 1:
            network_str = f"[{layers[0]}]"
        else:
            layer_block = ",\n".join(f"{layer_indent}{layer}" for layer in layers)
            network_str = f"[\n{layer_block}\n{entry_indent}]"

        lines.append(f"{entry_indent}# Size {size}: {num_comparisons} CEs, depth {depth}")
        lines.append(f"{entry_indent}{size}: {network_str},")
        lines.append("")

    lines.append("}")
    return "\n".join(lines)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def generate_optimal_sort_function() -> str:
    """Generate the single optimal_sort function that looks up networks by size.

    Returns:
        Source text of the ``optimal_sort`` function for the generated
        module. The text refers to ``cute``, ``cutlass``, ``networks`` and
        ``compare_and_swap``, which the generated module must provide.
    """
    # The template is emitted verbatim, so its indentation must already be
    # valid Python: the function body is nested one level, the loops deeper.
    return """@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    \"\"\"
    Optimal sorting network dispatcher.

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and available in networks)
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    \"\"\"
    assert n in networks
    for level in networks[n]:
        for i, j in level:
            compare_and_swap(arr, start + i, start + j, ascending)
"""
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def generate_sorting_networks_file(max_size: int = 64):
    """Generate a complete sorting networks file with optimal networks up to max_size.

    Writes ``sorting_networks.py`` next to this script. The file consists of
    a fixed header, the ``networks`` dictionary and the ``optimal_sort``
    function. Only networks already present in OPTIMAL_NETWORKS whose size
    lies in ``2..max_size`` are written.

    Args:
        max_size: Largest network size to include.

    Returns:
        The list of sizes written, in ascending order.
    """

    output_file = os.path.join(os.path.dirname(__file__), "sorting_networks.py")

    # Header
    header = '''# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
"""
Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html

This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
"""

# fmt: off
# ruff: noqa
# isort: skip_file

import cutlass
import cutlass.cute as cute

from quack.sort.utils import compare_and_swap


'''

    # Restrict the emitted dictionary to the requested sizes so the file
    # content agrees with the sizes reported below (previously max_size only
    # affected the report, not what was written).
    sizes = [n for n in range(2, max_size + 1) if n in OPTIMAL_NETWORKS]
    selected_networks = {n: OPTIMAL_NETWORKS[n] for n in sizes}
    networks_dict = generate_networks_dict(selected_networks)
    optimal_sort_func = generate_optimal_sort_function()

    # Combine everything
    content = header + networks_dict + "\n\n\n" + optimal_sort_func

    # Explicit encoding keeps the output independent of the platform locale.
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(content)

    print(f"Generated optimal sorting networks for sizes {sizes}")
    print(f"Output written to: {output_file}")
    return sizes
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def initialize_networks():
    """Rebuild OPTIMAL_NETWORKS from scratch out of the NETWORK_STRINGS table.

    Every entry is parsed and registered via add_network_from_string; an
    entry that cannot be registered is reported with a warning and skipped.
    """
    # The dictionary is emptied in place, so existing references stay valid.
    OPTIMAL_NETWORKS.clear()

    for size in NETWORK_STRINGS:
        registered = add_network_from_string(
            size, NETWORK_STRINGS[size], f"Size {size} optimal network"
        )
        if registered:
            continue
        print(f"Warning: Failed to parse network for size {size}")
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def main():
    """Command-line entry point: parse options, load networks, write the file.

    Options:
        --max-size / -m: largest network size to generate (default 64).
        --stats / -s: additionally print a per-size statistics table.
    """
    parser = argparse.ArgumentParser(
        description="Generate optimal sorting network code from bertdobbelaere.github.io data"
    )
    parser.add_argument(
        "--max-size",
        "-m",
        type=int,
        default=64,
        # The help text previously advertised 32, contradicting the default.
        help="Maximum sorting network size to generate (default: 64)",
    )
    parser.add_argument(
        "--stats", "-s", action="store_true", help="Print statistics about the optimal networks"
    )

    args = parser.parse_args()

    # Initialize networks from strings
    initialize_networks()

    if args.stats:
        print("Optimal Sorting Network Statistics:")
        print("Size\tDepth\tComparisons\tLayers")
        print("-" * 35)
        for n in sorted(OPTIMAL_NETWORKS):
            if n <= args.max_size:
                depth, comparisons, layers = OPTIMAL_NETWORKS[n]
                print(f"{n}\t{depth}\t{comparisons}\t\t{len(layers)}")

    # Generate the sorting networks file
    sizes = generate_sorting_networks_file(args.max_size)

    print(f"\nGenerated optimal sorting networks for {len(sizes)} sizes")
    print(f"Total networks: {len(sizes)}")
    # max() raises ValueError on an empty list, which happens when
    # --max-size is below the smallest available network.
    if sizes:
        print(f"Max network size: {max(sizes)}")
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
if __name__ == "__main__":
|
|
326
|
+
main()
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
|
|
2
|
+
"""
|
|
3
|
+
Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html
|
|
4
|
+
|
|
5
|
+
This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# fmt: off
|
|
9
|
+
# ruff: noqa
|
|
10
|
+
# isort: skip_file
|
|
11
|
+
|
|
12
|
+
import cutlass
|
|
13
|
+
import cutlass.cute as cute
|
|
14
|
+
|
|
15
|
+
from quack.sort.utils import compare_and_swap
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
networks = {
|
|
19
|
+
# Size 2: 1 CEs, depth 1
|
|
20
|
+
2: [[(0, 1)]],
|
|
21
|
+
|
|
22
|
+
# Size 4: 5 CEs, depth 3
|
|
23
|
+
4: [
|
|
24
|
+
[(0, 2), (1, 3)],
|
|
25
|
+
[(0, 1), (2, 3)],
|
|
26
|
+
[(1, 2)]
|
|
27
|
+
],
|
|
28
|
+
|
|
29
|
+
# Size 8: 19 CEs, depth 6
|
|
30
|
+
8: [
|
|
31
|
+
[(0, 2), (1, 3), (4, 6), (5, 7)],
|
|
32
|
+
[(0, 4), (1, 5), (2, 6), (3, 7)],
|
|
33
|
+
[(0, 1), (2, 3), (4, 5), (6, 7)],
|
|
34
|
+
[(2, 4), (3, 5)],
|
|
35
|
+
[(1, 4), (3, 6)],
|
|
36
|
+
[(1, 2), (3, 4), (5, 6)]
|
|
37
|
+
],
|
|
38
|
+
|
|
39
|
+
# Size 16: 60 CEs, depth 10
|
|
40
|
+
16: [
|
|
41
|
+
[(0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10)],
|
|
42
|
+
[(0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12)],
|
|
43
|
+
[(0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15)],
|
|
44
|
+
[(0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15)],
|
|
45
|
+
[(1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14)],
|
|
46
|
+
[(1, 4), (2, 6), (5, 8), (7, 10), (9, 13), (11, 14)],
|
|
47
|
+
[(2, 4), (3, 6), (9, 12), (11, 13)],
|
|
48
|
+
[(3, 5), (6, 8), (7, 9), (10, 12)],
|
|
49
|
+
[(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)],
|
|
50
|
+
[(6, 7), (8, 9)]
|
|
51
|
+
],
|
|
52
|
+
|
|
53
|
+
# Size 32: 185 CEs, depth 14
|
|
54
|
+
32: [
|
|
55
|
+
[(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31)],
|
|
56
|
+
[(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31)],
|
|
57
|
+
[(0, 4), (1, 5), (2, 6), (3, 7), (8, 12), (9, 13), (10, 14), (11, 15), (16, 20), (17, 21), (18, 22), (19, 23), (24, 28), (25, 29), (26, 30), (27, 31)],
|
|
58
|
+
[(0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (16, 24), (17, 25), (18, 26), (19, 27), (20, 28), (21, 29), (22, 30), (23, 31)],
|
|
59
|
+
[(0, 16), (1, 8), (2, 4), (3, 12), (5, 10), (6, 9), (7, 14), (11, 13), (15, 31), (17, 24), (18, 20), (19, 28), (21, 26), (22, 25), (23, 30), (27, 29)],
|
|
60
|
+
[(1, 2), (3, 5), (4, 8), (6, 22), (7, 11), (9, 25), (10, 12), (13, 14), (17, 18), (19, 21), (20, 24), (23, 27), (26, 28), (29, 30)],
|
|
61
|
+
[(1, 17), (2, 18), (3, 19), (4, 20), (5, 10), (7, 23), (8, 24), (11, 27), (12, 28), (13, 29), (14, 30), (21, 26)],
|
|
62
|
+
[(3, 17), (4, 16), (5, 21), (6, 18), (7, 9), (8, 20), (10, 26), (11, 23), (13, 25), (14, 28), (15, 27), (22, 24)],
|
|
63
|
+
[(1, 4), (3, 8), (5, 16), (7, 17), (9, 21), (10, 22), (11, 19), (12, 20), (14, 24), (15, 26), (23, 28), (27, 30)],
|
|
64
|
+
[(2, 5), (7, 8), (9, 18), (11, 17), (12, 16), (13, 22), (14, 20), (15, 19), (23, 24), (26, 29)],
|
|
65
|
+
[(2, 4), (6, 12), (9, 16), (10, 11), (13, 17), (14, 18), (15, 22), (19, 25), (20, 21), (27, 29)],
|
|
66
|
+
[(5, 6), (8, 12), (9, 10), (11, 13), (14, 16), (15, 17), (18, 20), (19, 23), (21, 22), (25, 26)],
|
|
67
|
+
[(3, 5), (6, 7), (8, 9), (10, 12), (11, 14), (13, 16), (15, 18), (17, 20), (19, 21), (22, 23), (24, 25), (26, 28)],
|
|
68
|
+
[(3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28)]
|
|
69
|
+
],
|
|
70
|
+
|
|
71
|
+
# Size 64: 521 CEs, depth 21
|
|
72
|
+
64: [
|
|
73
|
+
[(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31), (32, 34), (33, 35), (36, 38), (37, 39), (40, 42), (41, 43), (44, 46), (45, 47), (48, 50), (49, 51), (52, 54), (53, 55), (56, 58), (57, 59), (60, 62), (61, 63)],
|
|
74
|
+
[(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31), (32, 33), (34, 35), (36, 37), (38, 39), (40, 41), (42, 43), (44, 45), (46, 47), (48, 49), (50, 51), (52, 53), (54, 55), (56, 57), (58, 59), (60, 61), (62, 63)],
|
|
75
|
+
[(0, 52), (1, 2), (3, 55), (4, 48), (5, 6), (7, 51), (8, 60), (9, 10), (11, 63), (12, 56), (13, 14), (15, 59), (16, 32), (17, 18), (19, 35), (20, 24), (21, 22), (23, 27), (25, 26), (28, 44), (29, 30), (31, 47), (33, 34), (36, 40), (37, 38), (39, 43), (41, 42), (45, 46), (49, 50), (53, 54), (57, 58), (61, 62)],
|
|
76
|
+
[(0, 20), (1, 53), (2, 54), (3, 23), (4, 28), (5, 49), (6, 50), (7, 31), (8, 36), (9, 61), (10, 62), (11, 39), (12, 16), (13, 57), (14, 58), (15, 19), (17, 33), (18, 34), (21, 25), (22, 26), (24, 52), (27, 55), (29, 45), (30, 46), (32, 56), (35, 59), (37, 41), (38, 42), (40, 60), (43, 63), (44, 48), (47, 51)],
|
|
77
|
+
[(0, 4), (1, 21), (2, 22), (3, 7), (5, 29), (6, 30), (8, 12), (9, 37), (10, 38), (11, 15), (13, 17), (14, 18), (16, 20), (19, 23), (24, 32), (25, 53), (26, 54), (27, 35), (28, 36), (31, 39), (33, 57), (34, 58), (40, 44), (41, 61), (42, 62), (43, 47), (45, 49), (46, 50), (48, 52), (51, 55), (56, 60), (59, 63)],
|
|
78
|
+
[(0, 8), (1, 5), (2, 6), (3, 11), (4, 12), (7, 15), (9, 13), (10, 14), (16, 40), (17, 21), (18, 22), (19, 43), (20, 44), (23, 47), (24, 28), (25, 33), (26, 34), (27, 31), (29, 37), (30, 38), (32, 36), (35, 39), (41, 45), (42, 46), (48, 56), (49, 53), (50, 54), (51, 59), (52, 60), (55, 63), (57, 61), (58, 62)],
|
|
79
|
+
[(1, 9), (2, 10), (4, 8), (5, 13), (6, 14), (7, 11), (12, 48), (15, 51), (16, 24), (17, 41), (18, 42), (19, 27), (20, 28), (21, 45), (22, 46), (23, 31), (25, 29), (26, 30), (32, 40), (33, 37), (34, 38), (35, 43), (36, 44), (39, 47), (49, 57), (50, 58), (52, 56), (53, 61), (54, 62), (55, 59)],
|
|
80
|
+
[(4, 16), (5, 9), (6, 10), (7, 19), (8, 24), (11, 27), (13, 49), (14, 50), (17, 25), (18, 26), (20, 32), (21, 29), (22, 30), (23, 35), (28, 40), (31, 43), (33, 41), (34, 42), (36, 52), (37, 45), (38, 46), (39, 55), (44, 56), (47, 59), (53, 57), (54, 58)],
|
|
81
|
+
[(1, 4), (5, 17), (6, 18), (8, 16), (9, 25), (10, 26), (11, 19), (12, 24), (15, 27), (21, 33), (22, 34), (29, 41), (30, 42), (36, 48), (37, 53), (38, 54), (39, 51), (44, 52), (45, 57), (46, 58), (47, 55), (59, 62)],
|
|
82
|
+
[(2, 8), (9, 17), (10, 18), (12, 20), (13, 25), (14, 26), (15, 23), (24, 32), (27, 35), (28, 36), (31, 39), (37, 49), (38, 50), (40, 48), (43, 51), (45, 53), (46, 54), (55, 61)],
|
|
83
|
+
[(2, 4), (12, 16), (13, 21), (14, 22), (15, 19), (20, 24), (23, 27), (25, 33), (26, 34), (28, 32), (29, 37), (30, 38), (31, 35), (36, 40), (39, 43), (41, 49), (42, 50), (44, 48), (47, 51), (59, 61)],
|
|
84
|
+
[(4, 16), (5, 20), (10, 40), (13, 17), (14, 18), (21, 25), (22, 26), (23, 53), (24, 28), (27, 31), (29, 33), (30, 34), (32, 36), (35, 39), (37, 41), (38, 42), (43, 58), (45, 49), (46, 50), (47, 59)],
|
|
85
|
+
[(3, 17), (6, 36), (7, 21), (8, 32), (9, 24), (11, 41), (13, 28), (14, 44), (15, 45), (18, 48), (19, 49), (22, 52), (25, 29), (26, 30), (27, 57), (31, 55), (33, 37), (34, 38), (35, 50), (39, 54), (42, 56), (46, 60)],
|
|
86
|
+
[(6, 20), (8, 16), (10, 24), (11, 25), (14, 28), (15, 29), (17, 33), (18, 32), (21, 37), (22, 36), (26, 42), (27, 41), (30, 46), (31, 45), (34, 48), (35, 49), (38, 52), (39, 53), (43, 57), (47, 55)],
|
|
87
|
+
[(3, 18), (5, 8), (6, 12), (7, 22), (15, 21), (17, 32), (19, 33), (23, 37), (26, 40), (30, 44), (31, 46), (41, 56), (42, 48), (45, 60), (51, 57), (55, 58)],
|
|
88
|
+
[(3, 16), (7, 20), (11, 26), (18, 24), (19, 25), (22, 28), (23, 29), (27, 33), (30, 36), (34, 40), (35, 41), (37, 52), (38, 44), (39, 45), (43, 56), (47, 60)],
|
|
89
|
+
[(3, 9), (7, 13), (10, 16), (11, 17), (14, 20), (15, 30), (19, 34), (21, 36), (23, 38), (25, 40), (26, 32), (27, 42), (29, 44), (31, 37), (33, 48), (43, 49), (46, 52), (47, 53), (50, 56), (54, 60)],
|
|
90
|
+
[(3, 8), (7, 10), (9, 12), (11, 18), (13, 14), (15, 24), (17, 22), (19, 28), (21, 26), (23, 25), (27, 34), (29, 36), (30, 32), (31, 33), (35, 44), (37, 42), (38, 40), (39, 48), (41, 46), (45, 52), (49, 50), (51, 54), (53, 56), (55, 60)],
|
|
91
|
+
[(3, 6), (7, 12), (11, 16), (15, 17), (18, 20), (19, 24), (21, 22), (23, 30), (25, 32), (26, 28), (27, 29), (31, 38), (33, 40), (34, 36), (35, 37), (39, 44), (41, 42), (43, 45), (46, 48), (47, 52), (51, 56), (57, 60)],
|
|
92
|
+
[(3, 5), (6, 8), (7, 9), (10, 12), (11, 13), (14, 16), (15, 18), (17, 20), (19, 21), (22, 24), (23, 26), (25, 28), (27, 30), (29, 32), (31, 34), (33, 36), (35, 38), (37, 40), (39, 41), (42, 44), (43, 46), (45, 48), (47, 49), (50, 52), (51, 53), (54, 56), (55, 57), (58, 60)],
|
|
93
|
+
[(3, 4), (7, 8), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28), (29, 30), (31, 32), (33, 34), (35, 36), (37, 38), (39, 40), (41, 42), (43, 44), (45, 46), (47, 48), (49, 50), (51, 52), (55, 56), (59, 60)]
|
|
94
|
+
],
|
|
95
|
+
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    """
    Optimal sorting network dispatcher.

    Looks up the network registered for ``n`` in ``networks`` and applies
    its comparators layer by layer through ``compare_and_swap``.

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and available in networks)
        start: Starting index (default 0); added to every comparator index
        ascending: Sort in ascending order (default True); forwarded to
            compare_and_swap

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    """
    # Only sizes that have an entry in `networks` are supported.
    assert n in networks
    # Layers are applied in the order they are listed in `networks[n]`.
    for level in networks[n]:
        for i, j in level:
            compare_and_swap(arr, start + i, start + j, ascending)
|
quack/sort/utils.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import cutlass
|
|
2
|
+
import cutlass.cute as cute
|
|
3
|
+
|
|
4
|
+
import quack.utils as utils
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@cute.jit
def compare_and_swap(
    arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False
) -> None:
    """Compare and swap elements at indices i and j in ascending or descending order.

    Args:
        arr: Tensor whose elements at ``i`` and ``j`` are compared; it is
            written in place.
        i: Index of the first element.
        j: Index of the second element.
        ascending: When True the smaller value is placed at ``i``; otherwise
            the larger value is placed at ``i``.
        use_selection: When True use an explicit compare-then-swap; when
            False (default) write the min/max of the pair to both slots.
    """
    if cutlass.const_expr(use_selection):
        a, b = arr[i], arr[j]
        # XOR with `not ascending` inverts the swap condition for descending
        # order (equal elements are then swapped as well).
        if (a > b) ^ (not ascending):
            arr[i] = b
            arr[j] = a
        # Earlier variant that branched on `ascending` explicitly, kept for
        # reference:
        # if cutlass.const_expr(ascending):
        #     if a > b:
        #         arr[i] = b
        #         arr[j] = a
        # else:
        #     if a < b:
        #         arr[i] = b
        #         arr[j] = a
    else:
        # Float32 tensors use utils.fmin / cute.arch.fmax; every other element
        # type uses the builtin min / max.
        # NOTE(review): presumably the Float32 variants map to dedicated
        # floating-point min/max operations — confirm.
        min_fn = min if cutlass.const_expr(arr.element_type != cutlass.Float32) else utils.fmin
        max_fn = max if cutlass.const_expr(arr.element_type != cutlass.Float32) else cute.arch.fmax
        if cutlass.const_expr(ascending):
            arr[i], arr[j] = min_fn(arr[i], arr[j]), max_fn(arr[i], arr[j])
        else:
            arr[i], arr[j] = max_fn(arr[i], arr[j]), min_fn(arr[i], arr[j])
|