quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/softmax.py
CHANGED
|
@@ -9,7 +9,9 @@ import cutlass.cute as cute
|
|
|
9
9
|
from cutlass.cute.runtime import from_dlpack
|
|
10
10
|
|
|
11
11
|
import quack.utils as utils
|
|
12
|
-
from quack.
|
|
12
|
+
from quack.reduce import row_reduce, online_softmax_reduce
|
|
13
|
+
from quack.reduction_base import ReductionBase
|
|
14
|
+
from quack.cute_dsl_utils import torch2cute_dtype_map
|
|
13
15
|
|
|
14
16
|
|
|
15
17
|
class Softmax(ReductionBase):
|
|
@@ -147,7 +149,7 @@ class Softmax(ReductionBase):
|
|
|
147
149
|
x = tXrX.load().to(cute.Float32)
|
|
148
150
|
threads_per_row = tv_layout.shape[0][0]
|
|
149
151
|
if cutlass.const_expr(not self.online_softmax):
|
|
150
|
-
max_x =
|
|
152
|
+
max_x = row_reduce(
|
|
151
153
|
x,
|
|
152
154
|
cute.ReductionOp.MAX,
|
|
153
155
|
threads_per_row,
|
|
@@ -158,7 +160,7 @@ class Softmax(ReductionBase):
|
|
|
158
160
|
)
|
|
159
161
|
log2_e = math.log2(math.e)
|
|
160
162
|
exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
|
|
161
|
-
denom =
|
|
163
|
+
denom = row_reduce(
|
|
162
164
|
exp_x,
|
|
163
165
|
cute.ReductionOp.ADD,
|
|
164
166
|
threads_per_row,
|
|
@@ -167,7 +169,7 @@ class Softmax(ReductionBase):
|
|
|
167
169
|
init_val=0.0,
|
|
168
170
|
)
|
|
169
171
|
else:
|
|
170
|
-
max_x, denom, exp_x =
|
|
172
|
+
max_x, denom, exp_x = online_softmax_reduce(
|
|
171
173
|
x,
|
|
172
174
|
threads_per_row,
|
|
173
175
|
reduction_buffer[None, None, 0],
|
|
@@ -186,7 +188,8 @@ class Softmax(ReductionBase):
|
|
|
186
188
|
cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
|
|
187
189
|
|
|
188
190
|
|
|
189
|
-
|
|
191
|
+
@torch.library.custom_op("quack::_softmax_fwd", mutates_args={"out"})
|
|
192
|
+
def _softmax_fwd(x: torch.Tensor, out: torch.Tensor) -> None:
|
|
190
193
|
"""Softmax forward pass.
|
|
191
194
|
Args:
|
|
192
195
|
x: Input tensor of shape (M, N)
|
|
@@ -196,8 +199,7 @@ def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
|
|
|
196
199
|
assert x.dim() == 2, "Input must be 2D"
|
|
197
200
|
assert x.is_cuda, "Tensor must be on CUDA device"
|
|
198
201
|
assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
|
|
199
|
-
|
|
200
|
-
out = torch.empty_like(x)
|
|
202
|
+
N = x.size(1)
|
|
201
203
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
202
204
|
convert_from_dlpack = lambda tensor: (
|
|
203
205
|
from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
@@ -213,12 +215,17 @@ def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
|
|
|
213
215
|
softmax_op, x_tensor, out_tensor, current_stream
|
|
214
216
|
)
|
|
215
217
|
_softmax_fwd.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
|
|
216
|
-
return out
|
|
217
218
|
|
|
218
219
|
|
|
219
220
|
_softmax_fwd.compile_cache = {}
|
|
220
221
|
|
|
221
222
|
|
|
223
|
+
def softmax_fwd(x: torch.Tensor) -> torch.Tensor:
    """Allocate an output tensor and run the softmax forward custom op into it.

    Args:
        x: Input tensor of shape (M, N); must be CUDA and fp16/bf16/fp32
            (validated inside the underlying op).

    Returns:
        A freshly allocated tensor holding the softmax of ``x``.
    """
    result = torch.empty_like(x)
    _softmax_fwd(x, result)
    return result
|
|
227
|
+
|
|
228
|
+
|
|
222
229
|
class SoftmaxBackward(ReductionBase):
|
|
223
230
|
def __init__(self, dtype: Type[cutlass.Numeric], N: int):
|
|
224
231
|
# 1 stage for computing dot product
|
|
@@ -372,7 +379,7 @@ class SoftmaxBackward(ReductionBase):
|
|
|
372
379
|
|
|
373
380
|
# Compute dot product: dot = Σⱼ dy_j × y_j
|
|
374
381
|
threads_per_row = tv_layout.shape[0][0]
|
|
375
|
-
dot =
|
|
382
|
+
dot = row_reduce(
|
|
376
383
|
dy * y,
|
|
377
384
|
cute.ReductionOp.ADD,
|
|
378
385
|
threads_per_row,
|
|
@@ -394,7 +401,8 @@ class SoftmaxBackward(ReductionBase):
|
|
|
394
401
|
cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)
|
|
395
402
|
|
|
396
403
|
|
|
397
|
-
|
|
404
|
+
@torch.library.custom_op("quack::_softmax_backward", mutates_args={"dx"})
|
|
405
|
+
def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor) -> None:
|
|
398
406
|
"""Softmax backward pass.
|
|
399
407
|
Args:
|
|
400
408
|
dy: Upstream gradients tensor of shape (M, N)
|
|
@@ -409,8 +417,7 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
|
409
417
|
assert dy.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
|
|
410
418
|
assert y.dtype == dy.dtype, "dy and y must have same dtype"
|
|
411
419
|
|
|
412
|
-
|
|
413
|
-
dx = torch.empty_like(dy)
|
|
420
|
+
N = dy.size(1)
|
|
414
421
|
dtype = torch2cute_dtype_map[dy.dtype]
|
|
415
422
|
convert_from_dlpack = lambda tensor: (
|
|
416
423
|
from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
@@ -427,23 +434,28 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
|
427
434
|
softmax_backward_op, dy_tensor, y_tensor, dx_tensor, current_stream
|
|
428
435
|
)
|
|
429
436
|
_softmax_backward.compile_cache[compile_key](dy_tensor, y_tensor, dx_tensor, current_stream)
|
|
430
|
-
return dx
|
|
431
437
|
|
|
432
438
|
|
|
433
439
|
_softmax_backward.compile_cache = {}
|
|
434
440
|
|
|
435
441
|
|
|
442
|
+
def softmax_bwd(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Allocate the gradient tensor and run the softmax backward custom op.

    Args:
        dy: Upstream gradients of shape (M, N).
        y: Softmax output from the forward pass (same shape and dtype as ``dy``,
            enforced by the underlying op).

    Returns:
        A freshly allocated tensor with the gradient w.r.t. the softmax input.
    """
    grad_x = torch.empty_like(dy)
    _softmax_backward(dy, y, grad_x)
    return grad_x
|
|
446
|
+
|
|
447
|
+
|
|
436
448
|
class SoftmaxFunction(torch.autograd.Function):
    """Autograd bridge wiring softmax_fwd / softmax_bwd into torch.autograd."""

    @staticmethod
    def forward(ctx, x):
        # Only the output is saved: the softmax backward kernel is expressed in
        # terms of y (it computes dot = sum_j dy_j * y_j), not the input x.
        out = softmax_fwd(x)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, dy):
        (saved_y,) = ctx.saved_tensors
        return softmax_bwd(dy, saved_y)
|
|
448
460
|
|
|
449
461
|
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
import cutlass
|
|
7
|
+
import cutlass.cute as cute
|
|
8
|
+
|
|
9
|
+
import quack.utils as utils
|
|
10
|
+
from quack.sort.utils import compare_and_swap
|
|
11
|
+
from quack.sort.sorting_networks import optimal_sort
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@cute.jit
def bitonic_merge(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int],
    ascending: cutlass.Constexpr[bool] = True,
) -> None:
    """Merge a bitonic sequence into a sorted sequence using iterative approach.

    Args:
        arr: Tensor whose span [start, start + n) holds a bitonic sequence.
        n: Span length; must be a power of 2 (asserted below).
        start: First index of the span inside ``arr``.
        ascending: Direction of the resulting sorted order.
    """
    if cutlass.const_expr(n > 1):
        num_levels = int(math.log2(n))
        assert n == 2**num_levels, "n must be a power of 2"
        # This one must be range_constexpr otherwise it's very slow for n = 128
        for level in cutlass.range_constexpr(num_levels):
            length = n >> level  # n // (2^level)
            step = length // 2
            # Within each length-sized sub-span, compare elements `step` apart.
            for i in cutlass.range(n // length, unroll_full=True):
                start_i = start + i * length
                for j in cutlass.range(step, unroll_full=True):
                    compare_and_swap(arr, start_i + j, start_i + j + step, ascending)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@cute.jit
def bitonic_sort(
    arr: cute.Tensor,
    n: Optional[cutlass.Constexpr[int]] = None,
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True,
) -> None:
    """
    Bitonic sort for small arrays of size N (power of 2, N <= 128).

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and <= 128); defaults to the full
            size of ``arr``
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)
    """
    if cutlass.const_expr(n is None):
        n = cute.size(arr.shape)
    assert n <= 128
    if cutlass.const_expr(n > 1):
        # For these sizes a precomputed optimal sorting network is available
        # (see quack/sort/sorting_networks.py); it uses fewer comparators.
        if cutlass.const_expr(n in [2, 4, 8, 16, 32, 64]):
            optimal_sort(arr, n, start, ascending)
        else:  # Fall back to bitonic sort
            assert n % 2 == 0
            # Sort first half in ascending order
            bitonic_sort(arr, n // 2, start, True)
            # Sort second half in descending order
            bitonic_sort(arr, n // 2, start + n // 2, False)
            # Merge the whole sequence
            bitonic_merge(arr, n, start, ascending)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@cute.jit
def bitonic_topk_merge(
    arr0: cute.Tensor,
    arr1: cute.Tensor,
    k: Optional[cutlass.Constexpr[int]] = None,
    start0: cutlass.Constexpr[int] = 0,
    start1: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = False,
) -> None:
    """Merge two sorted length-k runs, leaving the top-k in ``arr0``.

    ``arr0[start0:start0+k]`` and ``arr1[start1:start1+k]`` must each already be
    sorted in the ``ascending`` direction. Taking the element-wise min/max of
    the first run against the *reversed* second run keeps the k best values and
    leaves the first run bitonic, so a single bitonic_merge finishes the job.

    Args:
        arr0: First sorted run; receives the merged top-k result in place.
        arr1: Second sorted run; only read.
        k: Run length (power of 2); defaults to the full size of ``arr0``.
        start0: Offset of the run inside ``arr0``.
        start1: Offset of the run inside ``arr1``.
        ascending: False (default) keeps the k largest values, True the smallest.
    """
    if cutlass.const_expr(k is None):
        k = cute.size(arr0.shape)
    if cutlass.const_expr(arr0.element_type == cutlass.Float32):
        # NOTE(review): the min side uses utils.fmin while the max side uses
        # cute.arch.fmax — presumably both give the intended float/NaN
        # semantics; confirm against quack.utils.
        minmax_fn = utils.fmin if ascending else cute.arch.fmax
    else:
        minmax_fn = min if ascending else max
    # Write the top k elements to the first half of the array.
    # Bug fix: the keyword was misspelled ``unfoll_full``; cutlass.range takes
    # ``unroll_full`` (as used everywhere else in this file), so the original
    # call would be rejected.
    for i in cutlass.range(k, unroll_full=True):
        arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
    # Now the 1st half is bitonic, we just need to merge it
    bitonic_merge(arr0, k, start0, ascending)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@cute.jit
def bitonic_topk(
    arr: cute.Tensor,
    k: cutlass.Constexpr[int],
    ascending: cutlass.Constexpr[bool] = False,
    warp_width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Tensor:
    """
    Bitonic top-k for small arrays of size N (power of 2, N <= 128).

    Args:
        arr: Array to sort
        k: must be power of 2 and <= 128
        ascending: Sort in ascending order (default False)
        warp_width: number of lanes taking part in the cross-lane butterfly
            exchange (default WARP_SIZE)

    Returns:
        A k-element fragment holding this lane's merged top-k candidates.
    """
    assert arr.element_type in [cutlass.Float32, cutlass.Int32]
    n = cute.size(arr.shape)
    assert k == 1 << int(math.log2(k)), "k must be a power of 2"
    assert n % k == 0, "n must be divisible by k"
    # Seed the candidate list with the first k elements of this lane, sorted.
    topk_vals = cute.make_fragment(k, arr.element_type)
    for v in cutlass.range(k, unroll_full=True):
        topk_vals[v] = arr[v]
    bitonic_sort(topk_vals, ascending=ascending)
    # Fold each remaining k-sized chunk of `arr` into the candidate list.
    other_vals = cute.make_fragment(k, arr.element_type)
    for i in cutlass.range(1, n // k, unroll_full=True):
        for v in cutlass.range(k, unroll_full=True):
            other_vals[v] = arr[i * k + v]
        bitonic_sort(other_vals, ascending=ascending)
        # Merge 2 sorted top-k sequences to get a new top-k sequence
        bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
    # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps
    # do duplicate work.
    # Butterfly-exchange candidate lists across lanes (XOR offsets 1, 2, 4, ...)
    # so after log2(warp_width) rounds every lane holds the warp-wide top-k.
    for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
        other_vals = cute.make_fragment(k, arr.element_type)
        for v in cutlass.range(k, unroll_full=True):
            other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
        bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
    return topk_vals
|
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Generate optimized sorting network code from the optimal sorting network data.
|
|
4
|
+
Based on data from: https://bertdobbelaere.github.io/sorting_networks.html
|
|
5
|
+
|
|
6
|
+
This script generates CUTE DSL functions for optimal sorting networks of various sizes.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import argparse
|
|
10
|
+
import os
|
|
11
|
+
import re
|
|
12
|
+
from typing import List, Tuple, Dict
|
|
13
|
+
|
|
14
|
+
# Network strings from bertdobbelaere.github.io/sorting_networks.html
# Copy-paste network strings here, then run initialize_networks() to parse them
# Each value lists one layer per line; a layer is a "[(i,j),...]" list of
# comparator pairs (parsed by parse_network_string, which ignores whitespace).
NETWORK_STRINGS: Dict[int, str] = {
    # Size 2: 1 CE, depth 1
    2: """
[(0,1)]
""",
    # Size 4: 5 CEs, depth 3
    4: """
[(0,2),(1,3)]
[(0,1),(2,3)]
[(1,2)]
""",
    # Size 8: 19 CEs, depth 6
    8: """
[(0,2),(1,3),(4,6),(5,7)]
[(0,4),(1,5),(2,6),(3,7)]
[(0,1),(2,3),(4,5),(6,7)]
[(2,4),(3,5)]
[(1,4),(3,6)]
[(1,2),(3,4),(5,6)]
""",
    # Size 16: 60 CEs, depth 10
    16: """
[(0,13),(1,12),(2,15),(3,14),(4,8),(5,6),(7,11),(9,10)]
[(0,5),(1,7),(2,9),(3,4),(6,13),(8,14),(10,15),(11,12)]
[(0,1),(2,3),(4,5),(6,8),(7,9),(10,11),(12,13),(14,15)]
[(0,2),(1,3),(4,10),(5,11),(6,7),(8,9),(12,14),(13,15)]
[(1,2),(3,12),(4,6),(5,7),(8,10),(9,11),(13,14)]
[(1,4),(2,6),(5,8),(7,10),(9,13),(11,14)]
[(2,4),(3,6),(9,12),(11,13)]
[(3,5),(6,8),(7,9),(10,12)]
[(3,4),(5,6),(7,8),(9,10),(11,12)]
[(6,7),(8,9)]
""",
    # Size 32: 185 CEs, depth 14
    32: """
[(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31)]
[(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31)]
[(0,4),(1,5),(2,6),(3,7),(8,12),(9,13),(10,14),(11,15),(16,20),(17,21),(18,22),(19,23),(24,28),(25,29),(26,30),(27,31)]
[(0,8),(1,9),(2,10),(3,11),(4,12),(5,13),(6,14),(7,15),(16,24),(17,25),(18,26),(19,27),(20,28),(21,29),(22,30),(23,31)]
[(0,16),(1,8),(2,4),(3,12),(5,10),(6,9),(7,14),(11,13),(15,31),(17,24),(18,20),(19,28),(21,26),(22,25),(23,30),(27,29)]
[(1,2),(3,5),(4,8),(6,22),(7,11),(9,25),(10,12),(13,14),(17,18),(19,21),(20,24),(23,27),(26,28),(29,30)]
[(1,17),(2,18),(3,19),(4,20),(5,10),(7,23),(8,24),(11,27),(12,28),(13,29),(14,30),(21,26)]
[(3,17),(4,16),(5,21),(6,18),(7,9),(8,20),(10,26),(11,23),(13,25),(14,28),(15,27),(22,24)]
[(1,4),(3,8),(5,16),(7,17),(9,21),(10,22),(11,19),(12,20),(14,24),(15,26),(23,28),(27,30)]
[(2,5),(7,8),(9,18),(11,17),(12,16),(13,22),(14,20),(15,19),(23,24),(26,29)]
[(2,4),(6,12),(9,16),(10,11),(13,17),(14,18),(15,22),(19,25),(20,21),(27,29)]
[(5,6),(8,12),(9,10),(11,13),(14,16),(15,17),(18,20),(19,23),(21,22),(25,26)]
[(3,5),(6,7),(8,9),(10,12),(11,14),(13,16),(15,18),(17,20),(19,21),(22,23),(24,25),(26,28)]
[(3,4),(5,6),(7,8),(9,10),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28)]
""",
    # Size 64: 512 CEs, depth 21
    64: """
[(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31),(32,34),(33,35),(36,38),(37,39),(40,42),(41,43),(44,46),(45,47),(48,50),(49,51),(52,54),(53,55),(56,58),(57,59),(60,62),(61,63)]
[(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31),(32,33),(34,35),(36,37),(38,39),(40,41),(42,43),(44,45),(46,47),(48,49),(50,51),(52,53),(54,55),(56,57),(58,59),(60,61),(62,63)]
[(0,52),(1,2),(3,55),(4,48),(5,6),(7,51),(8,60),(9,10),(11,63),(12,56),(13,14),(15,59),(16,32),(17,18),(19,35),(20,24),(21,22),(23,27),(25,26),(28,44),(29,30),(31,47),(33,34),(36,40),(37,38),(39,43),(41,42),(45,46),(49,50),(53,54),(57,58),(61,62)]
[(0,20),(1,53),(2,54),(3,23),(4,28),(5,49),(6,50),(7,31),(8,36),(9,61),(10,62),(11,39),(12,16),(13,57),(14,58),(15,19),(17,33),(18,34),(21,25),(22,26),(24,52),(27,55),(29,45),(30,46),(32,56),(35,59),(37,41),(38,42),(40,60),(43,63),(44,48),(47,51)]
[(0,4),(1,21),(2,22),(3,7),(5,29),(6,30),(8,12),(9,37),(10,38),(11,15),(13,17),(14,18),(16,20),(19,23),(24,32),(25,53),(26,54),(27,35),(28,36),(31,39),(33,57),(34,58),(40,44),(41,61),(42,62),(43,47),(45,49),(46,50),(48,52),(51,55),(56,60),(59,63)]
[(0,8),(1,5),(2,6),(3,11),(4,12),(7,15),(9,13),(10,14),(16,40),(17,21),(18,22),(19,43),(20,44),(23,47),(24,28),(25,33),(26,34),(27,31),(29,37),(30,38),(32,36),(35,39),(41,45),(42,46),(48,56),(49,53),(50,54),(51,59),(52,60),(55,63),(57,61),(58,62)]
[(1,9),(2,10),(4,8),(5,13),(6,14),(7,11),(12,48),(15,51),(16,24),(17,41),(18,42),(19,27),(20,28),(21,45),(22,46),(23,31),(25,29),(26,30),(32,40),(33,37),(34,38),(35,43),(36,44),(39,47),(49,57),(50,58),(52,56),(53,61),(54,62),(55,59)]
[(4,16),(5,9),(6,10),(7,19),(8,24),(11,27),(13,49),(14,50),(17,25),(18,26),(20,32),(21,29),(22,30),(23,35),(28,40),(31,43),(33,41),(34,42),(36,52),(37,45),(38,46),(39,55),(44,56),(47,59),(53,57),(54,58)]
[(1,4),(5,17),(6,18),(8,16),(9,25),(10,26),(11,19),(12,24),(15,27),(21,33),(22,34),(29,41),(30,42),(36,48),(37,53),(38,54),(39,51),(44,52),(45,57),(46,58),(47,55),(59,62)]
[(2,8),(9,17),(10,18),(12,20),(13,25),(14,26),(15,23),(24,32),(27,35),(28,36),(31,39),(37,49),(38,50),(40,48),(43,51),(45,53),(46,54),(55,61)]
[(2,4),(12,16),(13,21),(14,22),(15,19),(20,24),(23,27),(25,33),(26,34),(28,32),(29,37),(30,38),(31,35),(36,40),(39,43),(41,49),(42,50),(44,48),(47,51),(59,61)]
[(4,16),(5,20),(10,40),(13,17),(14,18),(21,25),(22,26),(23,53),(24,28),(27,31),(29,33),(30,34),(32,36),(35,39),(37,41),(38,42),(43,58),(45,49),(46,50),(47,59)]
[(3,17),(6,36),(7,21),(8,32),(9,24),(11,41),(13,28),(14,44),(15,45),(18,48),(19,49),(22,52),(25,29),(26,30),(27,57),(31,55),(33,37),(34,38),(35,50),(39,54),(42,56),(46,60)]
[(6,20),(8,16),(10,24),(11,25),(14,28),(15,29),(17,33),(18,32),(21,37),(22,36),(26,42),(27,41),(30,46),(31,45),(34,48),(35,49),(38,52),(39,53),(43,57),(47,55)]
[(3,18),(5,8),(6,12),(7,22),(15,21),(17,32),(19,33),(23,37),(26,40),(30,44),(31,46),(41,56),(42,48),(45,60),(51,57),(55,58)]
[(3,16),(7,20),(11,26),(18,24),(19,25),(22,28),(23,29),(27,33),(30,36),(34,40),(35,41),(37,52),(38,44),(39,45),(43,56),(47,60)]
[(3,9),(7,13),(10,16),(11,17),(14,20),(15,30),(19,34),(21,36),(23,38),(25,40),(26,32),(27,42),(29,44),(31,37),(33,48),(43,49),(46,52),(47,53),(50,56),(54,60)]
[(3,8),(7,10),(9,12),(11,18),(13,14),(15,24),(17,22),(19,28),(21,26),(23,25),(27,34),(29,36),(30,32),(31,33),(35,44),(37,42),(38,40),(39,48),(41,46),(45,52),(49,50),(51,54),(53,56),(55,60)]
[(3,6),(7,12),(11,16),(15,17),(18,20),(19,24),(21,22),(23,30),(25,32),(26,28),(27,29),(31,38),(33,40),(34,36),(35,37),(39,44),(41,42),(43,45),(46,48),(47,52),(51,56),(57,60)]
[(3,5),(6,8),(7,9),(10,12),(11,13),(14,16),(15,18),(17,20),(19,21),(22,24),(23,26),(25,28),(27,30),(29,32),(31,34),(33,36),(35,38),(37,40),(39,41),(42,44),(43,46),(45,48),(47,49),(50,52),(51,53),(54,56),(55,57),(58,60)]
[(3,4),(7,8),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28),(29,30),(31,32),(33,34),(35,36),(37,38),(39,40),(41,42),(43,44),(45,46),(47,48),(49,50),(51,52),(55,56),(59,60)]
""",
}

# This will be populated by initialize_networks()
OPTIMAL_NETWORKS: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]] = {}
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def parse_network_string(network_str: str) -> List[List[Tuple[int, int]]]:
    """
    Parse a sorting network string from bertdobbelaere.github.io format.

    Each "[...]" group is one layer; inside it, "(i,j)" pairs are comparators.
    An empty or whitespace-only input yields []; an empty "[]" layer yields [].

    Examples:
        Input: "[(0,2),(1,3)], [(0,1),(2,3)], [(1,2)]"
        Output: [[(0, 2), (1, 3)], [(0, 1), (2, 3)], [(1, 2)]]

        Input: "[(0,1)], [(1,2)], [(0,1)]"
        Output: [[(0, 1)], [(1, 2)], [(0, 1)]]
    """
    text = network_str.strip()
    if not text:
        return []

    pair_re = re.compile(r"\((\d+),(\d+)\)")
    layer_re = re.compile(r"\[((?:\(\d+,\d+\)(?:,\(\d+,\d+\))*)?)\]")

    parsed: List[List[Tuple[int, int]]] = []
    for bracket in layer_re.finditer(text):
        inner = bracket.group(1)
        # findall on an empty/blank inner string naturally yields an empty layer.
        parsed.append([(int(a), int(b)) for a, b in pair_re.findall(inner)])
    return parsed
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def calculate_network_stats(layers: List[List[Tuple[int, int]]]) -> Tuple[int, int, int]:
    """Calculate depth, total comparisons, and network size from network layers.

    The network size is the maximum wire index plus one (indices are 0-based);
    an empty network reports size 1, since the max index defaults to 0.
    """
    total_comparisons = sum(map(len, layers))
    highest_index = max(
        (max(pair) for layer in layers for pair in layer),
        default=0,
    )
    return len(layers), total_comparisons, highest_index + 1
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def add_network_from_string(size: int, network_str: str, description: str = ""):
    """
    Parse a network string and register it in OPTIMAL_NETWORKS.

    Args:
        size: Expected size of the network (number of elements).
        network_str: Network string in bertdobbelaere.github.io format.
        description: Optional description; when given, a summary is printed.

    Returns:
        True when the network parsed and its detected size matched ``size``;
        False otherwise (a warning/error is printed, nothing is registered).
    """
    try:
        parsed_layers = parse_network_string(network_str)
        depth, comparisons, detected_size = calculate_network_stats(parsed_layers)

        if detected_size != size:
            # Size mismatch means the pasted string is for a different network.
            print(f"Warning: Network size mismatch! Expected {size}, detected {detected_size}")
            print(f"Network string: {network_str[:100]}...")
            return False

        OPTIMAL_NETWORKS[size] = (depth, comparisons, parsed_layers)

        if description:
            print(f"Added network for size {size}: {description}")
            print(f"  Depth: {depth}, Comparisons: {comparisons}")
        return True

    except Exception as e:
        print(f"Error parsing network for size {size}: {e}")
        print(f"Network string: {network_str[:100]}...")
        return False
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def generate_networks_dict(
    networks_data: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]]
) -> str:
    """Generate the global networks dictionary.

    Renders Python source text of the form ``networks = { size: [layers], ... }``
    with a "# Size N: C CEs, depth D" comment above each entry, suitable for
    writing into the generated sorting_networks.py module.
    """
    lines = ["networks = {"]

    for size, (depth, num_comparisons, layers) in sorted(networks_data.items()):
        # Format the network with proper indentation and newlines
        network_lines = []
        for i, layer in enumerate(layers):
            if i == 0:
                network_lines.append(f"        {layer}")
            else:
                network_lines.append(f",\n        {layer}")

        if len(layers) == 1:
            # A single-layer network fits on one line.
            network_str = f"[{network_lines[0].strip()}]"
        else:
            network_str = "[\n" + "".join(network_lines) + "\n    ]"

        lines.append(f"    # Size {size}: {num_comparisons} CEs, depth {depth}")
        lines.append(f"    {size}: {network_str},")
        lines.append("")

    lines.append("}")
    return "\n".join(lines)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def generate_optimal_sort_function() -> str:
    """Generate the single optimal_sort function that looks up networks by size.

    Returns the source text of a ``@cute.jit`` dispatcher that indexes the
    generated ``networks`` dict and applies each comparator via
    ``compare_and_swap`` (both defined in the generated module's scope).
    """
    return """@cute.jit
def optimal_sort(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True
) -> None:
    \"\"\"
    Optimal sorting network dispatcher.

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and available in networks)
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)

    Source: https://bertdobbelaere.github.io/sorting_networks.html
    \"\"\"
    assert n in networks
    for level in networks[n]:
        for i, j in level:
            compare_and_swap(arr, start + i, start + j, ascending)
"""
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def generate_sorting_networks_file(max_size: int = 64):
    """Generate a complete sorting networks file with optimal networks up to max_size.

    Writes quack/sort/sorting_networks.py (next to this script) from the
    already-populated OPTIMAL_NETWORKS table — call initialize_networks() first.
    Returns the list of network sizes that were emitted.
    """

    output_file = os.path.join(os.path.dirname(__file__), "sorting_networks.py")

    # Header
    header = '''# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
"""
Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html

This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
"""

# fmt: off
# ruff: noqa
# isort: skip_file

import cutlass
import cutlass.cute as cute

from quack.sort.utils import compare_and_swap


'''

    # Generate networks dictionary and optimal_sort function
    sizes = [n for n in range(2, max_size + 1) if n in OPTIMAL_NETWORKS]
    networks_dict = generate_networks_dict(OPTIMAL_NETWORKS)
    optimal_sort_func = generate_optimal_sort_function()

    # Combine everything
    content = header + networks_dict + "\n\n\n" + optimal_sort_func

    with open(output_file, "w") as f:
        f.write(content)

    print(f"Generated optimal sorting networks for sizes {sizes}")
    print(f"Output written to: {output_file}")
    return sizes
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def initialize_networks():
    """Initialize the OPTIMAL_NETWORKS dictionary by parsing NETWORK_STRINGS.

    Clears any previously registered networks first, so repeated calls are
    idempotent.
    """
    global OPTIMAL_NETWORKS
    OPTIMAL_NETWORKS.clear()

    for size, network_str in NETWORK_STRINGS.items():
        if not add_network_from_string(size, network_str, f"Size {size} optimal network"):
            print(f"Warning: Failed to parse network for size {size}")
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def main():
    """CLI entry point: parse arguments, build the network tables, and emit
    the generated sorting_networks.py module."""
    parser = argparse.ArgumentParser(
        description="Generate optimal sorting network code from bertdobbelaere.github.io data"
    )
    parser.add_argument(
        "--max-size",
        "-m",
        type=int,
        default=64,
        # Fix: the help text previously hard-coded "(default: 32)" while the
        # actual default is 64; %(default)s keeps the message in sync.
        help="Maximum sorting network size to generate (default: %(default)s)",
    )
    parser.add_argument(
        "--stats", "-s", action="store_true", help="Print statistics about the optimal networks"
    )

    args = parser.parse_args()

    # Initialize networks from strings
    initialize_networks()

    if args.stats:
        print("Optimal Sorting Network Statistics:")
        print("Size\tDepth\tComparisons\tLayers")
        print("-" * 35)
        for n in sorted(OPTIMAL_NETWORKS.keys()):
            if n <= args.max_size:
                depth, comparisons, layers = OPTIMAL_NETWORKS[n]
                print(f"{n}\t{depth}\t{comparisons}\t\t{len(layers)}")

    # Generate the sorting networks file
    sizes = generate_sorting_networks_file(args.max_size)

    print(f"\nGenerated optimal sorting networks for {len(sizes)} sizes")
    print(f"Total networks: {len(sizes)}")
    print(f"Max network size: {max(sizes)}")


if __name__ == "__main__":
    main()
|