quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/lse.py ADDED
@@ -0,0 +1,62 @@
+# Copyright (c) 2025, Tri Dao.
+# TODO: we probably don't need this kernel, just use torch.logsumexp
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _lse_kernel(
+    lse_ptr,
+    logits_ptr,
+    n_rows,
+    n_cols,
+    logits_row_stride,
+    logits_col_stride,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+):
+    row_start = tl.program_id(0) * BLOCK_SIZE_M
+    rows = row_start + tl.arange(0, BLOCK_SIZE_M)
+    cols = tl.arange(0, BLOCK_SIZE_N)
+    logits = tl.load(
+        logits_ptr + rows[:, None] * logits_row_stride + cols[None, :] * logits_col_stride,
+        mask=(rows[:, None] < n_rows) & (cols[None, :] < n_cols),
+        other=-float("inf"),
+    ).to(tl.float32)
+    m = tl.max(logits, 1)
+    lse = tl.log(tl.sum(tl.exp(logits - m[:, None]), 1)) + m
+    tl.store(lse_ptr + rows, lse, mask=rows < n_rows)
+
+
+def logsumexp(logits):
+    n_rows, n_cols = logits.shape
+    BLOCK_SIZE_M = 32 if logits.stride(1) != 1 else 1
+    MAX_BLOCK_SIZE = 64 * 1024
+    # BLOCK_SIZE_N = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE // BLOCK_SIZE_M)
+    BLOCK_SIZE_N = triton.next_power_of_2(n_cols)
+    assert (
+        BLOCK_SIZE_M * BLOCK_SIZE_N <= MAX_BLOCK_SIZE
+    ), f"Only support max dimension {MAX_BLOCK_SIZE // BLOCK_SIZE_M}"
+    num_warps = (
+        4
+        if BLOCK_SIZE_N < 2048
+        else (8 if BLOCK_SIZE_N < 8192 else (16 if BLOCK_SIZE_N < 128 * 1024 else 32))
+    )
+    lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+    # Need this, otherwise Triton tries to launch from cuda:0 and we get
+    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+    with torch.cuda.device(logits.device.index):
+        _lse_kernel[(triton.cdiv(n_rows, BLOCK_SIZE_M),)](
+            lse,
+            logits,
+            n_rows,
+            n_cols,  # shapes
+            logits.stride(0),  # strides
+            logits.stride(1),
+            BLOCK_SIZE_M=BLOCK_SIZE_M,  # constants
+            BLOCK_SIZE_N=BLOCK_SIZE_N,  # constants
+            num_warps=num_warps,
+        )
+    return lse
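
Note: as the TODO at the top of the file says, the kernel computes a row-wise logsumexp in float32, so it can be cross-checked against torch.logsumexp. A minimal sanity-check sketch (an illustration, not part of the package; it assumes a CUDA device and imports logsumexp from quack.lse as defined above):

import torch
from quack.lse import logsumexp

logits = torch.randn(128, 1000, device="cuda", dtype=torch.bfloat16)
lse = logsumexp(logits)                        # (128,) float32, one value per row
ref = torch.logsumexp(logits.float(), dim=-1)  # reference computed in float32
torch.testing.assert_close(lse, ref, rtol=1e-3, atol=1e-3)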
quack/mlp.py ADDED
@@ -0,0 +1,204 @@
+# Copyright (c) 2025, Tri Dao
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.amp import custom_fwd, custom_bwd
+
+from einops import rearrange
+
+from gemm_cublas import gemm as gemm_cb, gemm_add_ as gemm_add_cb_
+# from gemm_cublas.interface import gemm_tuned as gemm_cb, gemm_add_tuned_ as gemm_add_cb_
+
+from quack import gemm, gemm_swiglu, gemm_dswiglu  # TODO: implement these
+
+
+class MLPSwiGLUFunc(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(ctx, x, weight1, weight2, fuse_grad_accum=False):
+        """
+        x: (..., in_features)
+        weight1: (2 * intermediate_features, in_features)
+        weight2: (out_features, intermediate_features)
+        out: (..., out_features)
+        Note that we do swiglu on the even and odd indices of the intermediate output,
+        i.e. silu(y[..., ::2]) * y[..., 1::2].
+        This is different from the usual swiglu implementation that does: y1, y2 = y.chunk(2, dim=-1); silu(y1) * y2
+        """
+        needs_weight1_grad = weight1.requires_grad
+        needs_weight2_grad = weight2.requires_grad
+        needs_input_grad = x.requires_grad
+        ctx.weight1_dtype = weight1.dtype
+        ctx.weight2_dtype = weight2.dtype
+        autocast_dtype = torch.get_autocast_dtype("cuda")
+        if torch.is_autocast_enabled():
+            x = x.to(dtype=autocast_dtype)
+        weight1_og = weight1
+        weight2_og = weight2
+        if torch.is_autocast_enabled():
+            weight1 = weight1.to(dtype=autocast_dtype)
+            weight2 = weight2.to(dtype=autocast_dtype)
+        batch_shape = x.shape[:-1]
+        x = x.reshape(-1, x.shape[-1])
+        # don't need preact if not computing gradient
+        store_preact = needs_input_grad or needs_weight1_grad or needs_weight2_grad
+        # (batch, inter_dim) & (batch, 2 * inter_dim)
+        y, preact = gemm_swiglu(x, weight1.T, store_preact=store_preact)
+        # out = F.linear(y, weight2)
+        out = gemm(y, weight2.T)
+        if not needs_input_grad:
+            weight1, weight1_og = None, None
+        if not needs_weight1_grad:
+            x = None
+        if not needs_input_grad and not needs_weight1_grad and not needs_weight2_grad:
+            weight2, weight2_og = None, None
+            preact = None
+        ctx.save_for_backward(
+            x,
+            preact,
+            weight1,
+            weight2,
+            *((weight1_og, weight2_og) if fuse_grad_accum else (None, None)),
+        )
+        ctx.fuse_grad_accum = fuse_grad_accum
+        return out.reshape(*batch_shape, out.shape[-1])
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, dout):
+        """
+        dout: (..., out_features)
+        """
+        if not torch.compiler.is_dynamo_compiling():
+            assert dout.stride(-1) == 1
+        # weight1_og and weight2_og are None if not ctx.fuse_grad_accum
+        x, preact, weight1, weight2, weight1_og, weight2_og = ctx.saved_tensors
+        batch_shape = dout.shape[:-1]
+        dout = dout.reshape(-1, dout.shape[-1])
+        if (
+            not ctx.needs_input_grad[0]
+            and not ctx.needs_input_grad[1]
+            and not ctx.needs_input_grad[2]
+        ):
+            return (None,) * 4
+        assert preact is not None
+        # (batch, 2 * inter_dim) and (batch, inter_dim)
+        # dpreact, y = gemm_dswiglu(dout, weight2, preact)
+        dpreact, y = gemm_dswiglu(dout, weight2, preact, sm_carveout=16)
+        if ctx.needs_input_grad[2]:
+            # fuse_grad_accum is not compatible with torch.compile
+            if not ctx.fuse_grad_accum or weight2_og.grad is None or torch.compiler.is_compiling():
+                dweight2 = gemm_cb(dout.T, y, out_dtype=ctx.weight2_dtype)
+                # dweight2 = gemm_cb(dout.T, y, out_dtype=ctx.weight2_dtype, sm_carveout=16)
+            else:
+                # print("Using fuse grad accum in MLP 2", dout.shape, y.shape, weight2_og.grad.shape)
+                gemm_add_cb_(dout.T, y, weight2_og.grad)
+                # gemm_add_cb_(dout.T, y, weight2_og.grad, sm_carveout=16)
+                dweight2 = weight2_og.grad
+                weight2_og.grad = (
+                    None  # So that pytorch doesn't add dweight to weight2_og.grad again
+                )
+        else:
+            dweight2 = None
+        if ctx.needs_input_grad[0]:
+            dx = dpreact @ weight1  # (batch, in_features)
+            # dx = gemm(dpreact, weight1)  # (batch, in_features)
+            dx = dx.reshape(*batch_shape, dx.shape[-1])
+        else:
+            dx = None
+        if ctx.needs_input_grad[1]:
+            # fuse_grad_accum is not compatible with torch.compile
+            if not ctx.fuse_grad_accum or weight1_og.grad is None or torch.compiler.is_compiling():
+                dweight1 = gemm_cb(dpreact.T, x, out_dtype=ctx.weight1_dtype)
+            else:
+                # print("Using fuse grad accum in MLP 1", dpreact.shape, x.shape, weight1_og.grad.shape)
+                gemm_add_cb_(dpreact.T, x, weight1_og.grad)
+                dweight1 = weight1_og.grad
+                weight1_og.grad = (
+                    None  # So that pytorch doesn't add dweight to weight1_og.grad again
+                )
+        else:
+            dweight1 = None
+        return dx, dweight1, dweight2, None
+
+
+def mlp_swiglu_func(x, weight1, weight2, fuse_grad_accum=False):
+    return MLPSwiGLUFunc.apply(x, weight1, weight2, fuse_grad_accum)
+
+
+class MLPSwiGLU(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        bias1=False,
+        bias2=False,
+        multiple_of=128,
+        device=None,
+        dtype=None,
+        fuse_grad_accum: bool = False,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        out_features = out_features if out_features is not None else in_features
+        hidden_features = (
+            hidden_features if hidden_features is not None else int(8 * in_features / 3)
+        )
+        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
+        self.fc1.weight._muon_reshape_functions = (
+            lambda w: rearrange(w, "(d two) e -> two d e", two=2),
+            lambda w: rearrange(w, "two d e -> (d two) e"),
+        )
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
+        self.fuse_grad_accum = fuse_grad_accum
+
+    def forward(self, input: Tensor) -> Tensor:
+        if (
+            self.fc1.bias is None
+            and self.fc2.bias is None
+            and input.is_cuda
+            and input.stride(-1) == 1
+            and self.fc1.in_features % 8 == 0
+            and self.fc1.out_features % 16 == 0
+            and self.fc2.out_features % 8 == 0
+        ):
+            return mlp_swiglu_func(
+                input,
+                self.fc1.weight,
+                self.fc2.weight,
+                fuse_grad_accum=self.fuse_grad_accum,
+            )
+        else:
+            y = self.fc1(input)
+            return self.fc2(F.silu(y[..., ::2]) * y[..., 1::2])
+
+
+class MLPSwiGLURef(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        bias1=False,
+        bias2=False,
+        multiple_of=128,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        out_features = out_features if out_features is not None else in_features
+        hidden_features = (
+            hidden_features if hidden_features is not None else int(8 * in_features / 3)
+        )
+        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
+
+    def forward(self, input: Tensor) -> Tensor:
+        y = self.fc1(input)
+        y1, y2 = y.chunk(2, dim=-1)
+        return self.fc2(F.silu(y1) * y2)
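
Note: the docstring in MLPSwiGLUFunc.forward spells out the interleaved gating convention, silu(y[..., ::2]) * y[..., 1::2], as opposed to the chunked y1/y2 split used by MLPSwiGLURef. A small pure-PyTorch reference of that convention can be handy when testing the fused path; mlp_swiglu_interleaved_ref below is a hypothetical helper for illustration, not part of the package:

import torch
import torch.nn.functional as F

def mlp_swiglu_interleaved_ref(x, weight1, weight2):
    # weight1: (2 * hidden, in), weight2: (out, hidden); biases are assumed absent,
    # matching the conditions under which MLPSwiGLU.forward takes the fused path.
    y = F.linear(x, weight1)  # (..., 2 * hidden)
    return F.linear(F.silu(y[..., ::2]) * y[..., 1::2], weight2)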
quack/pipeline.py ADDED
@@ -0,0 +1,166 @@
+# Copyright (c) 2025, Tri Dao.
+
+from typing import Optional
+from dataclasses import dataclass
+
+import cutlass.cute as cute
+from cutlass.cutlass_dsl import Boolean, Int32, if_generate
+from cutlass.pipeline import CooperativeGroup, PipelineOp, pipeline_init_wait
+from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
+
+from cutlass.cutlass_dsl import dsl_user_op
+from cutlass._mlir.dialects import nvvm
+
+
+@dsl_user_op
+def cp_async_mbarrier_arrive_shared(
+    mbar_ptr: cute.Pointer, noinc: bool = False, *, loc=None, ip=None
+) -> None:
+    nvvm.cp_async_mbarrier_arrive_shared(
+        mbar_ptr.llvm_ptr,
+        noinc=noinc,
+        loc=loc,
+        ip=ip,
+    )
+
+
+class PipelineStateWAdvance(PipelineState):
+    def advance_iters(self, num_iterations: Int32):
+        self._count += Int32(num_iterations)
+        new_index = self._index + Int32(num_iterations)
+        # How many times did we cross the stages boundary
+        num_crossings = new_index // self.stages
+        self._phase ^= num_crossings
+        self._index = new_index % self.stages
+
+    # This can be overridden by derived classes
+    def __new_from_mlir_values__(self, values):
+        return PipelineStateWAdvance(
+            self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
+        )
+
+
+def make_pipeline_state(type: PipelineUserType, stages: int):
+    """
+    Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
+    """
+    if type is PipelineUserType.Producer:
+        return PipelineStateWAdvance(
+            stages,
+            Int32(0),
+            Int32(0),
+            Int32(1),
+        )
+    elif type is PipelineUserType.Consumer:
+        return PipelineStateWAdvance(
+            stages,
+            Int32(0),
+            Int32(0),
+            Int32(0),
+        )
+    else:
+        assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."
+
+
+@dataclass(frozen=True)
+class PipelineTmaCpAsync(PipelineTmaAsync):
+    """
+    PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers.
+    """
+
+    @staticmethod
+    def create(
+        *,
+        num_stages: int,
+        producer_group: CooperativeGroup,
+        consumer_group: CooperativeGroup,
+        tx_count: int,
+        barrier_storage: cute.Pointer = None,
+        cta_layout_vmnk: Optional[cute.Layout] = None,
+        tidx: Optional[Int32] = None,
+    ):
+        """
+        This helper function computes any necessary attributes and returns an instance of PipelineTmaCpAsync.
+        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
+        :type barrier_storage: cute.Pointer
+        :param num_stages: Number of buffer stages for this pipeline
+        :type num_stages: Int32
+        :param producer_group: CooperativeGroup for the producer agent
+        :type producer_group: CooperativeGroup
+        :param consumer_group: CooperativeGroup for the consumer agent
+        :type consumer_group: CooperativeGroup
+        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
+        :type tx_count: int
+        :param cta_layout_vmnk: Layout of the cluster shape
+        :type cta_layout_vmnk: cute.Layout | None
+        :param tidx: thread index of the consumer async threads
+        :type tidx: Int32 | None
+        """
+        if not isinstance(barrier_storage, cute.Pointer):
+            raise ValueError(
+                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
+            )
+
+        producer_type = PipelineOp.TmaLoad
+        consumer_type = PipelineOp.AsyncThread
+
+        producer = (producer_type, producer_group)
+        consumer = (consumer_type, consumer_group)
+
+        sync_object_full = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8), num_stages, producer, tx_count
+        )
+        sync_object_empty = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
+        )
+        if tidx is None:
+            tidx, _, _ = cute.arch.thread_idx()
+        if cta_layout_vmnk is None:
+            cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
+        (
+            dst_rank,
+            is_signalling_thread,
+        ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
+        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+            dst_rank = None
+        else:
+            dst_rank = dst_rank
+
+        producer_mask = None
+
+        pipeline_init_wait(cta_layout_vmnk)
+
+        return PipelineTmaCpAsync(
+            sync_object_full,
+            sync_object_empty,
+            num_stages,
+            producer_mask,
+            dst_rank,
+            is_signalling_thread,
+        )
+
+    def producer_acquire(
+        self,
+        state: PipelineState,
+        try_acquire_token: Optional[Boolean] = None,
+        is_tma_warp: Optional[Boolean] = True,
+    ):
+        """
+        TMA producer acquire conditionally waits on buffer empty and sets the transaction barrier.
+        """
+        if_generate(
+            try_acquire_token is None or try_acquire_token == 0,
+            lambda: self.sync_object_empty.wait(state.index, state.phase),
+        )
+        # This is the difference between this and PipelineTmaAsync: we could have multiple
+        # warps calling this, but only 1 warp should do the arrive on the full barrier
+        if_generate(
+            is_tma_warp,
+            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+        )
+
+    def producer_commit(self, state: PipelineState):
+        """
+        We need the mbarrier to track the completion of cp.async
+        """
+        cp_async_mbarrier_arrive_shared(self.producer_get_barrier(state), noinc=True)
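
Note: for readers unfamiliar with the stage/phase bookkeeping that PipelineStateWAdvance.advance_iters performs, here is a plain-Python model of the same arithmetic (a sketch outside the CuTe DSL, not the package code; it assumes the phase is a single bit, so only the parity of the boundary-crossing count matters):

def advance_iters(index: int, phase: int, stages: int, n: int) -> tuple[int, int]:
    new_index = index + n
    crossings = new_index // stages   # how many times the stage boundary was crossed
    phase ^= crossings & 1            # flip the phase bit once per crossing (parity)
    return new_index % stages, phase

# e.g. with 4 stages, advancing 2 iterations from index 3 wraps to index 1 and flips the phase:
assert advance_iters(3, 0, stages=4, n=2) == (1, 1)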
@@ -0,0 +1,126 @@
+# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
+
+import math
+from typing import Optional
+
+import cutlass
+import cutlass.cute as cute
+
+import quack.utils as utils
+from quack.sort.utils import compare_and_swap
+from quack.sort.sorting_networks import optimal_sort
+
+
+@cute.jit
+def bitonic_merge(
+    arr: cute.Tensor,
+    n: cutlass.Constexpr[int],
+    start: cutlass.Constexpr[int],
+    ascending: cutlass.Constexpr[bool] = True,
+) -> None:
+    """Merge a bitonic sequence into a sorted sequence using an iterative approach."""
+    if cutlass.const_expr(n > 1):
+        num_levels = int(math.log2(n))
+        assert n == 2**num_levels, "n must be a power of 2"
+        # This one must be range_constexpr otherwise it's very slow for n = 128
+        for level in cutlass.range_constexpr(num_levels):
+            length = n >> level  # n // (2^level)
+            step = length // 2
+            for i in cutlass.range(n // length, unroll_full=True):
+                start_i = start + i * length
+                for j in cutlass.range(step, unroll_full=True):
+                    compare_and_swap(arr, start_i + j, start_i + j + step, ascending)
+
+
+@cute.jit
+def bitonic_sort(
+    arr: cute.Tensor,
+    n: Optional[cutlass.Constexpr[int]] = None,
+    start: cutlass.Constexpr[int] = 0,
+    ascending: cutlass.Constexpr[bool] = True,
+) -> None:
+    """
+    Bitonic sort for small arrays of size N (power of 2, N <= 128).
+
+    Args:
+        arr: Array to sort
+        n: Size of array (must be power of 2 and <= 128)
+        start: Starting index (default 0)
+        ascending: Sort in ascending order (default True)
+    """
+    if cutlass.const_expr(n is None):
+        n = cute.size(arr.shape)
+    assert n <= 128
+    if cutlass.const_expr(n > 1):
+        if cutlass.const_expr(n in [2, 4, 8, 16, 32, 64]):
+            optimal_sort(arr, n, start, ascending)
+        else:  # Fall back to bitonic sort
+            assert n % 2 == 0
+            # Sort first half in ascending order
+            bitonic_sort(arr, n // 2, start, True)
+            # Sort second half in descending order
+            bitonic_sort(arr, n // 2, start + n // 2, False)
+            # Merge the whole sequence
+            bitonic_merge(arr, n, start, ascending)
+
+
+@cute.jit
+def bitonic_topk_merge(
+    arr0: cute.Tensor,
+    arr1: cute.Tensor,
+    k: Optional[cutlass.Constexpr[int]] = None,
+    start0: cutlass.Constexpr[int] = 0,
+    start1: cutlass.Constexpr[int] = 0,
+    ascending: cutlass.Constexpr[bool] = False,
+) -> None:
+    if cutlass.const_expr(k is None):
+        k = cute.size(arr0.shape)
+    if cutlass.const_expr(arr0.element_type == cutlass.Float32):
+        minmax_fn = utils.fmin if ascending else cute.arch.fmax
+    else:
+        minmax_fn = min if ascending else max
+    # Write the top k elements to the first half of the array
+    for i in cutlass.range(k, unroll_full=True):
+        arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
+    # Now the 1st half is bitonic, we just need to merge it
+    bitonic_merge(arr0, k, start0, ascending)
+
+
+@cute.jit
+def bitonic_topk(
+    arr: cute.Tensor,
+    k: cutlass.Constexpr[int],
+    ascending: cutlass.Constexpr[bool] = False,
+    warp_width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
+) -> cute.Tensor:
+    """
+    Bitonic top-k for small arrays of size N (power of 2, N <= 128).
+
+    Args:
+        arr: Array to sort
+        k: must be power of 2 and <= 128
+        ascending: Sort in ascending order (default False)
+    """
+    assert arr.element_type in [cutlass.Float32, cutlass.Int32]
+    n = cute.size(arr.shape)
+    assert k == 1 << int(math.log2(k)), "k must be a power of 2"
+    assert n % k == 0, "n must be divisible by k"
+    topk_vals = cute.make_fragment(k, arr.element_type)
+    for v in cutlass.range(k, unroll_full=True):
+        topk_vals[v] = arr[v]
+    bitonic_sort(topk_vals, ascending=ascending)
+    other_vals = cute.make_fragment(k, arr.element_type)
+    for i in cutlass.range(1, n // k, unroll_full=True):
+        for v in cutlass.range(k, unroll_full=True):
+            other_vals[v] = arr[i * k + v]
+        bitonic_sort(other_vals, ascending=ascending)
+        # Merge 2 sorted top-k sequences to get a new top-k sequence
+        bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
+    # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warp
+    # do duplicate work.
+    for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
+        other_vals = cute.make_fragment(k, arr.element_type)
+        for v in cutlass.range(k, unroll_full=True):
+            other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
+        bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
+    return topk_vals
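
Note: the bitonic_sort / bitonic_merge pair above follows the textbook recursion: sort the two halves in opposite directions, then merge the resulting bitonic sequence with compare-and-swaps at stride length // 2. A plain-Python sketch of the same structure (hypothetical helpers for illustration, not part of the package):

def bitonic_merge_ref(a, n, start=0, ascending=True):
    length = n
    while length > 1:
        step = length // 2
        for block in range(start, start + n, length):
            for j in range(step):
                lo, hi = block + j, block + j + step
                if (a[lo] > a[hi]) == ascending:  # out of order for this direction
                    a[lo], a[hi] = a[hi], a[lo]
        length //= 2

def bitonic_sort_ref(a, n, start=0, ascending=True):
    if n > 1:
        bitonic_sort_ref(a, n // 2, start, True)            # first half ascending
        bitonic_sort_ref(a, n // 2, start + n // 2, False)  # second half descending
        bitonic_merge_ref(a, n, start, ascending)           # merge the bitonic whole

xs = [5, 1, 7, 3, 6, 2, 8, 4]
bitonic_sort_ref(xs, len(xs))
assert xs == sorted(xs)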