quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/softmax.py CHANGED
@@ -9,7 +9,9 @@ import cutlass.cute as cute
9
9
  from cutlass.cute.runtime import from_dlpack
10
10
 
11
11
  import quack.utils as utils
12
- from quack.reduction_base import ReductionBase, torch2cute_dtype_map
12
+ from quack.reduce import row_reduce, online_softmax_reduce
13
+ from quack.reduction_base import ReductionBase
14
+ from quack.cute_dsl_utils import torch2cute_dtype_map
13
15
 
14
16
 
15
17
  class Softmax(ReductionBase):
@@ -147,7 +149,7 @@ class Softmax(ReductionBase):
147
149
  x = tXrX.load().to(cute.Float32)
148
150
  threads_per_row = tv_layout.shape[0][0]
149
151
  if cutlass.const_expr(not self.online_softmax):
150
- max_x = utils.row_reduce(
152
+ max_x = row_reduce(
151
153
  x,
152
154
  cute.ReductionOp.MAX,
153
155
  threads_per_row,
@@ -158,7 +160,7 @@ class Softmax(ReductionBase):
158
160
  )
159
161
  log2_e = math.log2(math.e)
160
162
  exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
161
- denom = utils.row_reduce(
163
+ denom = row_reduce(
162
164
  exp_x,
163
165
  cute.ReductionOp.ADD,
164
166
  threads_per_row,
@@ -167,7 +169,7 @@ class Softmax(ReductionBase):
167
169
  init_val=0.0,
168
170
  )
169
171
  else:
170
- max_x, denom, exp_x = utils.online_softmax_reduce(
172
+ max_x, denom, exp_x = online_softmax_reduce(
171
173
  x,
172
174
  threads_per_row,
173
175
  reduction_buffer[None, None, 0],
@@ -186,7 +188,8 @@ class Softmax(ReductionBase):
186
188
  cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
187
189
 
188
190
 
189
- def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
191
+ @torch.library.custom_op("quack::_softmax_fwd", mutates_args={"out"})
192
+ def _softmax_fwd(x: torch.Tensor, out: torch.Tensor) -> None:
190
193
  """Softmax forward pass.
191
194
  Args:
192
195
  x: Input tensor of shape (M, N)
@@ -196,8 +199,7 @@ def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
196
199
  assert x.dim() == 2, "Input must be 2D"
197
200
  assert x.is_cuda, "Tensor must be on CUDA device"
198
201
  assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
199
- M, N = x.shape
200
- out = torch.empty_like(x)
202
+ N = x.size(1)
201
203
  dtype = torch2cute_dtype_map[x.dtype]
202
204
  convert_from_dlpack = lambda tensor: (
203
205
  from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -213,12 +215,17 @@ def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
213
215
  softmax_op, x_tensor, out_tensor, current_stream
214
216
  )
215
217
  _softmax_fwd.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
216
- return out
217
218
 
218
219
 
219
220
  _softmax_fwd.compile_cache = {}
220
221
 
221
222
 
223
+ def softmax_fwd(x: torch.Tensor) -> torch.Tensor:
224
+ out = torch.empty_like(x)
225
+ _softmax_fwd(x, out)
226
+ return out
227
+
228
+
222
229
  class SoftmaxBackward(ReductionBase):
223
230
  def __init__(self, dtype: Type[cutlass.Numeric], N: int):
224
231
  # 1 stage for computing dot product
@@ -372,7 +379,7 @@ class SoftmaxBackward(ReductionBase):
372
379
 
373
380
  # Compute dot product: dot = Σⱼ dy_j × y_j
374
381
  threads_per_row = tv_layout.shape[0][0]
375
- dot = utils.row_reduce(
382
+ dot = row_reduce(
376
383
  dy * y,
377
384
  cute.ReductionOp.ADD,
378
385
  threads_per_row,
@@ -394,7 +401,8 @@ class SoftmaxBackward(ReductionBase):
394
401
  cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)
395
402
 
396
403
 
397
- def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
404
+ @torch.library.custom_op("quack::_softmax_backward", mutates_args={"dx"})
405
+ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor) -> None:
398
406
  """Softmax backward pass.
399
407
  Args:
400
408
  dy: Upstream gradients tensor of shape (M, N)
@@ -409,8 +417,7 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
409
417
  assert dy.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
410
418
  assert y.dtype == dy.dtype, "dy and y must have same dtype"
411
419
 
412
- M, N = dy.shape
413
- dx = torch.empty_like(dy)
420
+ N = dy.size(1)
414
421
  dtype = torch2cute_dtype_map[dy.dtype]
415
422
  convert_from_dlpack = lambda tensor: (
416
423
  from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -427,23 +434,28 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
427
434
  softmax_backward_op, dy_tensor, y_tensor, dx_tensor, current_stream
428
435
  )
429
436
  _softmax_backward.compile_cache[compile_key](dy_tensor, y_tensor, dx_tensor, current_stream)
430
- return dx
431
437
 
432
438
 
433
439
  _softmax_backward.compile_cache = {}
434
440
 
435
441
 
442
+ def softmax_bwd(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
443
+ dx = torch.empty_like(dy)
444
+ _softmax_backward(dy, y, dx)
445
+ return dx
446
+
447
+
436
448
  class SoftmaxFunction(torch.autograd.Function):
437
449
  @staticmethod
438
450
  def forward(ctx, x):
439
- y = _softmax_fwd(x)
451
+ y = softmax_fwd(x)
440
452
  ctx.save_for_backward(y)
441
453
  return y
442
454
 
443
455
  @staticmethod
444
456
  def backward(ctx, dy):
445
457
  (y,) = ctx.saved_tensors
446
- dx = _softmax_backward(dy, y)
458
+ dx = softmax_bwd(dy, y)
447
459
  return dx
448
460
 
449
461
 
@@ -0,0 +1,126 @@
1
+ # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import cutlass
7
+ import cutlass.cute as cute
8
+
9
+ import quack.utils as utils
10
+ from quack.sort.utils import compare_and_swap
11
+ from quack.sort.sorting_networks import optimal_sort
12
+
13
+
14
+ @cute.jit
15
+ def bitonic_merge(
16
+ arr: cute.Tensor,
17
+ n: cutlass.Constexpr[int],
18
+ start: cutlass.Constexpr[int],
19
+ ascending: cutlass.Constexpr[bool] = True,
20
+ ) -> None:
21
+ """Merge a bitonic sequence into a sorted sequence using an iterative approach."""
22
+ if cutlass.const_expr(n > 1):
23
+ num_levels = int(math.log2(n))
24
+ assert n == 2**num_levels, "n must be a power of 2"
25
+ # This one must be range_constexpr otherwise it's very slow for n = 128
26
+ for level in cutlass.range_constexpr(num_levels):
27
+ length = n >> level # n // (2^level)
28
+ step = length // 2
29
+ for i in cutlass.range(n // length, unroll_full=True):
30
+ start_i = start + i * length
31
+ for j in cutlass.range(step, unroll_full=True):
32
+ compare_and_swap(arr, start_i + j, start_i + j + step, ascending)
33
+
34
+
35
+ @cute.jit
36
+ def bitonic_sort(
37
+ arr: cute.Tensor,
38
+ n: Optional[cutlass.Constexpr[int]] = None,
39
+ start: cutlass.Constexpr[int] = 0,
40
+ ascending: cutlass.Constexpr[bool] = True,
41
+ ) -> None:
42
+ """
43
+ Bitonic sort for small arrays of size N (power of 2, N <= 128).
44
+
45
+ Args:
46
+ arr: Array to sort
47
+ n: Size of array (must be power of 2 and <= 128)
48
+ start: Starting index (default 0)
49
+ ascending: Sort in ascending order (default True)
50
+ """
51
+ if cutlass.const_expr(n is None):
52
+ n = cute.size(arr.shape)
53
+ assert n <= 128
54
+ if cutlass.const_expr(n > 1):
55
+ if cutlass.const_expr(n in [2, 4, 8, 16, 32, 64]):
56
+ optimal_sort(arr, n, start, ascending)
57
+ else: # Fall back to bitonic sort
58
+ assert n % 2 == 0
59
+ # Sort first half in ascending order
60
+ bitonic_sort(arr, n // 2, start, True)
61
+ # Sort second half in descending order
62
+ bitonic_sort(arr, n // 2, start + n // 2, False)
63
+ # Merge the whole sequence
64
+ bitonic_merge(arr, n, start, ascending)
65
+
66
+
67
+ @cute.jit
68
+ def bitonic_topk_merge(
69
+ arr0: cute.Tensor,
70
+ arr1: cute.Tensor,
71
+ k: Optional[cutlass.Constexpr[int]] = None,
72
+ start0: cutlass.Constexpr[int] = 0,
73
+ start1: cutlass.Constexpr[int] = 0,
74
+ ascending: cutlass.Constexpr[bool] = False,
75
+ ) -> None:
76
+ if cutlass.const_expr(k is None):
77
+ k = cute.size(arr0.shape)
78
+ if cutlass.const_expr(arr0.element_type == cutlass.Float32):
79
+ minmax_fn = utils.fmin if ascending else cute.arch.fmax
80
+ else:
81
+ minmax_fn = min if ascending else max
82
+ # Write the top k elements to the first half of the array
83
+ for i in cutlass.range(k, unroll_full=True):
84
+ arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
85
+ # Now the 1st half is bitonic, we just need to merge it
86
+ bitonic_merge(arr0, k, start0, ascending)
87
+
88
+
89
+ @cute.jit
90
+ def bitonic_topk(
91
+ arr: cute.Tensor,
92
+ k: cutlass.Constexpr[int],
93
+ ascending: cutlass.Constexpr[bool] = False,
94
+ warp_width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
95
+ ) -> cute.Tensor:
96
+ """
97
+ Bitonic top-k for small arrays of size N (power of 2, N <= 128).
98
+
99
+ Args:
100
+ arr: Array to sort
101
+ k: must be power of 2 and <= 128
102
+ ascending: Sort in ascending order (default False)
103
+ """
104
+ assert arr.element_type in [cutlass.Float32, cutlass.Int32]
105
+ n = cute.size(arr.shape)
106
+ assert k == 1 << int(math.log2(k)), "k must be a power of 2"
107
+ assert n % k == 0, "n must be divisible by k"
108
+ topk_vals = cute.make_fragment(k, arr.element_type)
109
+ for v in cutlass.range(k, unroll_full=True):
110
+ topk_vals[v] = arr[v]
111
+ bitonic_sort(topk_vals, ascending=ascending)
112
+ other_vals = cute.make_fragment(k, arr.element_type)
113
+ for i in cutlass.range(1, n // k, unroll_full=True):
114
+ for v in cutlass.range(k, unroll_full=True):
115
+ other_vals[v] = arr[i * k + v]
116
+ bitonic_sort(other_vals, ascending=ascending)
117
+ # Merge 2 sorted top-k sequences to get a new top-k sequence
118
+ bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
119
+ # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps
120
+ # do duplicate work.
121
+ for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
122
+ other_vals = cute.make_fragment(k, arr.element_type)
123
+ for v in cutlass.range(k, unroll_full=True):
124
+ other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
125
+ bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
126
+ return topk_vals
@@ -0,0 +1,326 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate optimized sorting network code from the optimal sorting network data.
4
+ Based on data from: https://bertdobbelaere.github.io/sorting_networks.html
5
+
6
+ This script generates CUTE DSL functions for optimal sorting networks of various sizes.
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import re
12
+ from typing import List, Tuple, Dict
13
+
14
+ # Network strings from bertdobbelaere.github.io/sorting_networks.html
15
+ # Copy-paste network strings here, then run initialize_networks() to parse them
16
+ NETWORK_STRINGS = {
17
+ # Size 2: 1 CE, depth 1
18
+ 2: """
19
+ [(0,1)]
20
+ """,
21
+ # Size 4: 5 CEs, depth 3
22
+ 4: """
23
+ [(0,2),(1,3)]
24
+ [(0,1),(2,3)]
25
+ [(1,2)]
26
+ """,
27
+ # Size 8: 19 CEs, depth 6
28
+ 8: """
29
+ [(0,2),(1,3),(4,6),(5,7)]
30
+ [(0,4),(1,5),(2,6),(3,7)]
31
+ [(0,1),(2,3),(4,5),(6,7)]
32
+ [(2,4),(3,5)]
33
+ [(1,4),(3,6)]
34
+ [(1,2),(3,4),(5,6)]
35
+ """,
36
+ # Size 16: 60 CEs, depth 10
37
+ 16: """
38
+ [(0,13),(1,12),(2,15),(3,14),(4,8),(5,6),(7,11),(9,10)]
39
+ [(0,5),(1,7),(2,9),(3,4),(6,13),(8,14),(10,15),(11,12)]
40
+ [(0,1),(2,3),(4,5),(6,8),(7,9),(10,11),(12,13),(14,15)]
41
+ [(0,2),(1,3),(4,10),(5,11),(6,7),(8,9),(12,14),(13,15)]
42
+ [(1,2),(3,12),(4,6),(5,7),(8,10),(9,11),(13,14)]
43
+ [(1,4),(2,6),(5,8),(7,10),(9,13),(11,14)]
44
+ [(2,4),(3,6),(9,12),(11,13)]
45
+ [(3,5),(6,8),(7,9),(10,12)]
46
+ [(3,4),(5,6),(7,8),(9,10),(11,12)]
47
+ [(6,7),(8,9)]
48
+ """,
49
+ # Size 32: 185 CEs, depth 14
50
+ 32: """
51
+ [(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31)]
52
+ [(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31)]
53
+ [(0,4),(1,5),(2,6),(3,7),(8,12),(9,13),(10,14),(11,15),(16,20),(17,21),(18,22),(19,23),(24,28),(25,29),(26,30),(27,31)]
54
+ [(0,8),(1,9),(2,10),(3,11),(4,12),(5,13),(6,14),(7,15),(16,24),(17,25),(18,26),(19,27),(20,28),(21,29),(22,30),(23,31)]
55
+ [(0,16),(1,8),(2,4),(3,12),(5,10),(6,9),(7,14),(11,13),(15,31),(17,24),(18,20),(19,28),(21,26),(22,25),(23,30),(27,29)]
56
+ [(1,2),(3,5),(4,8),(6,22),(7,11),(9,25),(10,12),(13,14),(17,18),(19,21),(20,24),(23,27),(26,28),(29,30)]
57
+ [(1,17),(2,18),(3,19),(4,20),(5,10),(7,23),(8,24),(11,27),(12,28),(13,29),(14,30),(21,26)]
58
+ [(3,17),(4,16),(5,21),(6,18),(7,9),(8,20),(10,26),(11,23),(13,25),(14,28),(15,27),(22,24)]
59
+ [(1,4),(3,8),(5,16),(7,17),(9,21),(10,22),(11,19),(12,20),(14,24),(15,26),(23,28),(27,30)]
60
+ [(2,5),(7,8),(9,18),(11,17),(12,16),(13,22),(14,20),(15,19),(23,24),(26,29)]
61
+ [(2,4),(6,12),(9,16),(10,11),(13,17),(14,18),(15,22),(19,25),(20,21),(27,29)]
62
+ [(5,6),(8,12),(9,10),(11,13),(14,16),(15,17),(18,20),(19,23),(21,22),(25,26)]
63
+ [(3,5),(6,7),(8,9),(10,12),(11,14),(13,16),(15,18),(17,20),(19,21),(22,23),(24,25),(26,28)]
64
+ [(3,4),(5,6),(7,8),(9,10),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28)]
65
+ """,
66
+ # Size 64: 512 CEs, depth 21
67
+ 64: """
68
+ [(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31),(32,34),(33,35),(36,38),(37,39),(40,42),(41,43),(44,46),(45,47),(48,50),(49,51),(52,54),(53,55),(56,58),(57,59),(60,62),(61,63)]
69
+ [(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31),(32,33),(34,35),(36,37),(38,39),(40,41),(42,43),(44,45),(46,47),(48,49),(50,51),(52,53),(54,55),(56,57),(58,59),(60,61),(62,63)]
70
+ [(0,52),(1,2),(3,55),(4,48),(5,6),(7,51),(8,60),(9,10),(11,63),(12,56),(13,14),(15,59),(16,32),(17,18),(19,35),(20,24),(21,22),(23,27),(25,26),(28,44),(29,30),(31,47),(33,34),(36,40),(37,38),(39,43),(41,42),(45,46),(49,50),(53,54),(57,58),(61,62)]
71
+ [(0,20),(1,53),(2,54),(3,23),(4,28),(5,49),(6,50),(7,31),(8,36),(9,61),(10,62),(11,39),(12,16),(13,57),(14,58),(15,19),(17,33),(18,34),(21,25),(22,26),(24,52),(27,55),(29,45),(30,46),(32,56),(35,59),(37,41),(38,42),(40,60),(43,63),(44,48),(47,51)]
72
+ [(0,4),(1,21),(2,22),(3,7),(5,29),(6,30),(8,12),(9,37),(10,38),(11,15),(13,17),(14,18),(16,20),(19,23),(24,32),(25,53),(26,54),(27,35),(28,36),(31,39),(33,57),(34,58),(40,44),(41,61),(42,62),(43,47),(45,49),(46,50),(48,52),(51,55),(56,60),(59,63)]
73
+ [(0,8),(1,5),(2,6),(3,11),(4,12),(7,15),(9,13),(10,14),(16,40),(17,21),(18,22),(19,43),(20,44),(23,47),(24,28),(25,33),(26,34),(27,31),(29,37),(30,38),(32,36),(35,39),(41,45),(42,46),(48,56),(49,53),(50,54),(51,59),(52,60),(55,63),(57,61),(58,62)]
74
+ [(1,9),(2,10),(4,8),(5,13),(6,14),(7,11),(12,48),(15,51),(16,24),(17,41),(18,42),(19,27),(20,28),(21,45),(22,46),(23,31),(25,29),(26,30),(32,40),(33,37),(34,38),(35,43),(36,44),(39,47),(49,57),(50,58),(52,56),(53,61),(54,62),(55,59)]
75
+ [(4,16),(5,9),(6,10),(7,19),(8,24),(11,27),(13,49),(14,50),(17,25),(18,26),(20,32),(21,29),(22,30),(23,35),(28,40),(31,43),(33,41),(34,42),(36,52),(37,45),(38,46),(39,55),(44,56),(47,59),(53,57),(54,58)]
76
+ [(1,4),(5,17),(6,18),(8,16),(9,25),(10,26),(11,19),(12,24),(15,27),(21,33),(22,34),(29,41),(30,42),(36,48),(37,53),(38,54),(39,51),(44,52),(45,57),(46,58),(47,55),(59,62)]
77
+ [(2,8),(9,17),(10,18),(12,20),(13,25),(14,26),(15,23),(24,32),(27,35),(28,36),(31,39),(37,49),(38,50),(40,48),(43,51),(45,53),(46,54),(55,61)]
78
+ [(2,4),(12,16),(13,21),(14,22),(15,19),(20,24),(23,27),(25,33),(26,34),(28,32),(29,37),(30,38),(31,35),(36,40),(39,43),(41,49),(42,50),(44,48),(47,51),(59,61)]
79
+ [(4,16),(5,20),(10,40),(13,17),(14,18),(21,25),(22,26),(23,53),(24,28),(27,31),(29,33),(30,34),(32,36),(35,39),(37,41),(38,42),(43,58),(45,49),(46,50),(47,59)]
80
+ [(3,17),(6,36),(7,21),(8,32),(9,24),(11,41),(13,28),(14,44),(15,45),(18,48),(19,49),(22,52),(25,29),(26,30),(27,57),(31,55),(33,37),(34,38),(35,50),(39,54),(42,56),(46,60)]
81
+ [(6,20),(8,16),(10,24),(11,25),(14,28),(15,29),(17,33),(18,32),(21,37),(22,36),(26,42),(27,41),(30,46),(31,45),(34,48),(35,49),(38,52),(39,53),(43,57),(47,55)]
82
+ [(3,18),(5,8),(6,12),(7,22),(15,21),(17,32),(19,33),(23,37),(26,40),(30,44),(31,46),(41,56),(42,48),(45,60),(51,57),(55,58)]
83
+ [(3,16),(7,20),(11,26),(18,24),(19,25),(22,28),(23,29),(27,33),(30,36),(34,40),(35,41),(37,52),(38,44),(39,45),(43,56),(47,60)]
84
+ [(3,9),(7,13),(10,16),(11,17),(14,20),(15,30),(19,34),(21,36),(23,38),(25,40),(26,32),(27,42),(29,44),(31,37),(33,48),(43,49),(46,52),(47,53),(50,56),(54,60)]
85
+ [(3,8),(7,10),(9,12),(11,18),(13,14),(15,24),(17,22),(19,28),(21,26),(23,25),(27,34),(29,36),(30,32),(31,33),(35,44),(37,42),(38,40),(39,48),(41,46),(45,52),(49,50),(51,54),(53,56),(55,60)]
86
+ [(3,6),(7,12),(11,16),(15,17),(18,20),(19,24),(21,22),(23,30),(25,32),(26,28),(27,29),(31,38),(33,40),(34,36),(35,37),(39,44),(41,42),(43,45),(46,48),(47,52),(51,56),(57,60)]
87
+ [(3,5),(6,8),(7,9),(10,12),(11,13),(14,16),(15,18),(17,20),(19,21),(22,24),(23,26),(25,28),(27,30),(29,32),(31,34),(33,36),(35,38),(37,40),(39,41),(42,44),(43,46),(45,48),(47,49),(50,52),(51,53),(54,56),(55,57),(58,60)]
88
+ [(3,4),(7,8),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28),(29,30),(31,32),(33,34),(35,36),(37,38),(39,40),(41,42),(43,44),(45,46),(47,48),(49,50),(51,52),(55,56),(59,60)]
89
+ """,
90
+ }
91
+
92
+ # This will be populated by initialize_networks()
93
+ OPTIMAL_NETWORKS: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]] = {}
94
+
95
+
96
+ def parse_network_string(network_str: str) -> List[List[Tuple[int, int]]]:
97
+ """
98
+ Parse a sorting network string from bertdobbelaere.github.io format.
99
+
100
+ Examples:
101
+ Input: "[(0,2),(1,3)], [(0,1),(2,3)], [(1,2)]"
102
+ Output: [[(0, 2), (1, 3)], [(0, 1), (2, 3)], [(1, 2)]]
103
+
104
+ Input: "[(0,1)], [(1,2)], [(0,1)]"
105
+ Output: [[(0, 1)], [(1, 2)], [(0, 1)]]
106
+ """
107
+ # Remove whitespace and split by '], ['
108
+ network_str = network_str.strip()
109
+ if not network_str:
110
+ return []
111
+
112
+ # Split into layer strings
113
+ layer_pattern = r"\[((?:\(\d+,\d+\)(?:,\(\d+,\d+\))*)?)\]"
114
+ layers = []
115
+
116
+ for match in re.finditer(layer_pattern, network_str):
117
+ layer_str = match.group(1)
118
+ if not layer_str.strip():
119
+ layers.append([])
120
+ continue
121
+
122
+ # Parse comparisons in this layer: (i,j), (k,l), ...
123
+ comparisons = []
124
+ comp_pattern = r"\((\d+),(\d+)\)"
125
+
126
+ for comp_match in re.finditer(comp_pattern, layer_str):
127
+ i, j = int(comp_match.group(1)), int(comp_match.group(2))
128
+ comparisons.append((i, j))
129
+
130
+ layers.append(comparisons)
131
+
132
+ return layers
133
+
134
+
135
+ def calculate_network_stats(layers: List[List[Tuple[int, int]]]) -> Tuple[int, int, int]:
136
+ """Calculate depth, total comparisons, and max index from network layers."""
137
+ depth = len(layers)
138
+ total_comparisons = sum(len(layer) for layer in layers)
139
+
140
+ # Find maximum index to determine network size
141
+ max_index = 0
142
+ for layer in layers:
143
+ for i, j in layer:
144
+ max_index = max(max_index, i, j)
145
+
146
+ network_size = max_index + 1 # Since indices are 0-based
147
+ return depth, total_comparisons, network_size
148
+
149
+
150
+ def add_network_from_string(size: int, network_str: str, description: str = ""):
151
+ """
152
+ Add a network from a string representation to the OPTIMAL_NETWORKS dictionary.
153
+
154
+ Args:
155
+ size: Size of the network (number of elements)
156
+ network_str: Network string in bertdobbelaere.github.io format
157
+ description: Optional description for debugging
158
+ """
159
+ try:
160
+ layers = parse_network_string(network_str)
161
+ depth, comparisons, detected_size = calculate_network_stats(layers)
162
+
163
+ if detected_size != size:
164
+ print(f"Warning: Network size mismatch! Expected {size}, detected {detected_size}")
165
+ print(f"Network string: {network_str[:100]}...")
166
+ return False
167
+
168
+ OPTIMAL_NETWORKS[size] = (depth, comparisons, layers)
169
+
170
+ if description:
171
+ print(f"Added network for size {size}: {description}")
172
+ print(f" Depth: {depth}, Comparisons: {comparisons}")
173
+ return True
174
+
175
+ except Exception as e:
176
+ print(f"Error parsing network for size {size}: {e}")
177
+ print(f"Network string: {network_str[:100]}...")
178
+ return False
179
+
180
+
181
+ def generate_networks_dict(
182
+ networks_data: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]]
183
+ ) -> str:
184
+ """Generate the global networks dictionary."""
185
+ lines = ["networks = {"]
186
+
187
+ for size, (depth, num_comparisons, layers) in sorted(networks_data.items()):
188
+ # Format the network with proper indentation and newlines
189
+ network_lines = []
190
+ for i, layer in enumerate(layers):
191
+ if i == 0:
192
+ network_lines.append(f" {layer}")
193
+ else:
194
+ network_lines.append(f",\n {layer}")
195
+
196
+ if len(layers) == 1:
197
+ network_str = f"[{network_lines[0].strip()}]"
198
+ else:
199
+ network_str = "[\n" + "".join(network_lines) + "\n ]"
200
+
201
+ lines.append(f" # Size {size}: {num_comparisons} CEs, depth {depth}")
202
+ lines.append(f" {size}: {network_str},")
203
+ lines.append("")
204
+
205
+ lines.append("}")
206
+ return "\n".join(lines)
207
+
208
+
209
+ def generate_optimal_sort_function() -> str:
210
+ """Generate the single optimal_sort function that looks up networks by size."""
211
+ return """@cute.jit
212
+ def optimal_sort(
213
+ arr: cute.Tensor,
214
+ n: cutlass.Constexpr[int],
215
+ start: cutlass.Constexpr[int] = 0,
216
+ ascending: cutlass.Constexpr[bool] = True
217
+ ) -> None:
218
+ \"\"\"
219
+ Optimal sorting network dispatcher.
220
+
221
+ Args:
222
+ arr: Array to sort
223
+ n: Size of array (must be power of 2 and available in networks)
224
+ start: Starting index (default 0)
225
+ ascending: Sort in ascending order (default True)
226
+
227
+ Source: https://bertdobbelaere.github.io/sorting_networks.html
228
+ \"\"\"
229
+ assert n in networks
230
+ for level in networks[n]:
231
+ for i, j in level:
232
+ compare_and_swap(arr, start + i, start + j, ascending)
233
+ """
234
+
235
+
236
+ def generate_sorting_networks_file(max_size: int = 64):
237
+ """Generate a complete sorting networks file with optimal networks up to max_size."""
238
+
239
+ output_file = os.path.join(os.path.dirname(__file__), "sorting_networks.py")
240
+
241
+ # Header
242
+ header = '''# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
243
+ """
244
+ Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html
245
+
246
+ This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
247
+ """
248
+
249
+ # fmt: off
250
+ # ruff: noqa
251
+ # isort: skip_file
252
+
253
+ import cutlass
254
+ import cutlass.cute as cute
255
+
256
+ from quack.sort.utils import compare_and_swap
257
+
258
+
259
+ '''
260
+
261
+ # Generate networks dictionary and optimal_sort function
262
+ sizes = [n for n in range(2, max_size + 1) if n in OPTIMAL_NETWORKS]
263
+ networks_dict = generate_networks_dict(OPTIMAL_NETWORKS)
264
+ optimal_sort_func = generate_optimal_sort_function()
265
+
266
+ # Combine everything
267
+ content = header + networks_dict + "\n\n\n" + optimal_sort_func
268
+
269
+ with open(output_file, "w") as f:
270
+ f.write(content)
271
+
272
+ print(f"Generated optimal sorting networks for sizes {sizes}")
273
+ print(f"Output written to: {output_file}")
274
+ return sizes
275
+
276
+
277
+ def initialize_networks():
278
+ """Initialize the OPTIMAL_NETWORKS dictionary by parsing NETWORK_STRINGS."""
279
+ global OPTIMAL_NETWORKS
280
+ OPTIMAL_NETWORKS.clear()
281
+
282
+ for size, network_str in NETWORK_STRINGS.items():
283
+ success = add_network_from_string(size, network_str, f"Size {size} optimal network")
284
+ if not success:
285
+ print(f"Warning: Failed to parse network for size {size}")
286
+
287
+
288
+ def main():
289
+ parser = argparse.ArgumentParser(
290
+ description="Generate optimal sorting network code from bertdobbelaere.github.io data"
291
+ )
292
+ parser.add_argument(
293
+ "--max-size",
294
+ "-m",
295
+ type=int,
296
+ default=64,
297
+ help="Maximum sorting network size to generate (default: 32)",
298
+ )
299
+ parser.add_argument(
300
+ "--stats", "-s", action="store_true", help="Print statistics about the optimal networks"
301
+ )
302
+
303
+ args = parser.parse_args()
304
+
305
+ # Initialize networks from strings
306
+ initialize_networks()
307
+
308
+ if args.stats:
309
+ print("Optimal Sorting Network Statistics:")
310
+ print("Size\tDepth\tComparisons\tLayers")
311
+ print("-" * 35)
312
+ for n in sorted(OPTIMAL_NETWORKS.keys()):
313
+ if n <= args.max_size:
314
+ depth, comparisons, layers = OPTIMAL_NETWORKS[n]
315
+ print(f"{n}\t{depth}\t{comparisons}\t\t{len(layers)}")
316
+
317
+ # Generate the sorting networks file
318
+ sizes = generate_sorting_networks_file(args.max_size)
319
+
320
+ print(f"\nGenerated optimal sorting networks for {len(sizes)} sizes")
321
+ print(f"Total networks: {len(sizes)}")
322
+ print(f"Max network size: {max(sizes)}")
323
+
324
+
325
+ if __name__ == "__main__":
326
+ main()