quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +4 -1
- quack/autotuner.py +309 -0
- quack/cross_entropy.py +2 -5
- quack/cute_dsl_utils.py +40 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +2474 -0
- quack/fast_math.py +97 -0
- quack/gemm_config.py +61 -0
- quack/gemm_interface.py +321 -0
- quack/linear.py +176 -0
- quack/lse.py +62 -0
- quack/mlp.py +204 -0
- quack/pipeline.py +166 -0
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2088 -0
- quack/tensormap_manager.py +114 -0
- quack/tile_scheduler.py +935 -0
- quack/topk.py +221 -0
- quack/utils.py +237 -19
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/METADATA +3 -3
- quack_kernels-0.1.11.dist-info/RECORD +31 -0
- quack_kernels-0.1.9.dist-info/RECORD +0 -12
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/top_level.txt +0 -0
quack/lse.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) 2025, Tri Dao.
# TODO: we probably don't need this kernel, just use torch.logsumexp
import torch

import triton
import triton.language as tl


@triton.jit
def _lse_kernel(
    lse_ptr,
    logits_ptr,
    n_rows,
    n_cols,
    logits_row_stride,
    logits_col_stride,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    row_start = tl.program_id(0) * BLOCK_SIZE_M
    rows = row_start + tl.arange(0, BLOCK_SIZE_M)
    cols = tl.arange(0, BLOCK_SIZE_N)
    logits = tl.load(
        logits_ptr + rows[:, None] * logits_row_stride + cols[None, :] * logits_col_stride,
        mask=(rows[:, None] < n_rows) & (cols[None, :] < n_cols),
        other=-float("inf"),
    ).to(tl.float32)
    m = tl.max(logits, 1)
    lse = tl.log(tl.sum(tl.exp(logits - m[:, None]), 1)) + m
    tl.store(lse_ptr + rows, lse, mask=rows < n_rows)


def logsumexp(logits):
    n_rows, n_cols = logits.shape
    BLOCK_SIZE_M = 32 if logits.stride(1) != 1 else 1
    MAX_BLOCK_SIZE = 64 * 1024
    # BLOCK_SIZE_N = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE // BLOCK_SIZE_M)
    BLOCK_SIZE_N = triton.next_power_of_2(n_cols)
    assert (
        BLOCK_SIZE_M * BLOCK_SIZE_N <= MAX_BLOCK_SIZE
    ), f"Only support max dimension {MAX_BLOCK_SIZE // BLOCK_SIZE_M}"
    num_warps = (
        4
        if BLOCK_SIZE_N < 2048
        else (8 if BLOCK_SIZE_N < 8192 else (16 if BLOCK_SIZE_N < 128 * 1024 else 32))
    )
    lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(logits.device.index):
        _lse_kernel[(triton.cdiv(n_rows, BLOCK_SIZE_M),)](
            lse,
            logits,
            n_rows,
            n_cols,  # shapes
            logits.stride(0),  # strides
            logits.stride(1),
            BLOCK_SIZE_M=BLOCK_SIZE_M,  # constants
            BLOCK_SIZE_N=BLOCK_SIZE_N,  # constants
            num_warps=num_warps,
        )
    return lse
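For orientation, logsumexp computes a row-wise LSE in float32 using the usual max-subtraction trick and launches one program per block of rows. A minimal sanity check against torch.logsumexp (hypothetical usage sketch; assumes a CUDA device and the quack.lse module path added above):

import torch

from quack.lse import logsumexp

logits = torch.randn(4, 1000, device="cuda", dtype=torch.float16)
lse = logsumexp(logits)                        # (4,) float32, one value per row
ref = torch.logsumexp(logits.float(), dim=-1)
print(torch.allclose(lse, ref, atol=1e-3))     # expected: True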
quack/mlp.py
ADDED
@@ -0,0 +1,204 @@
# Copyright (c) 2025, Tri Dao
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.amp import custom_fwd, custom_bwd

from einops import rearrange

from gemm_cublas import gemm as gemm_cb, gemm_add_ as gemm_add_cb_
# from gemm_cublas.interface import gemm_tuned as gemm_cb, gemm_add_tuned_ as gemm_add_cb_

from quack import gemm, gemm_swiglu, gemm_dswiglu  # TODO: implement these


class MLPSwiGLUFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, weight1, weight2, fuse_grad_accum=False):
        """
        x: (..., in_features)
        weight1: (2 * intermediate_features, in_features)
        weight2: (out_features, intermediate_features)
        out: (..., out_features)
        Note that we do swiglu on the even and odd indices of the intermediate output,
        i.e. silu(y[..., ::2]) * y[..., 1::2].
        This is different from the usual swiglu implementation that does: y1, y2 = y.chunk(2, dim=-1); silu(y1) * y2
        """
        needs_weight1_grad = weight1.requires_grad
        needs_weight2_grad = weight2.requires_grad
        needs_input_grad = x.requires_grad
        ctx.weight1_dtype = weight1.dtype
        ctx.weight2_dtype = weight2.dtype
        autocast_dtype = torch.get_autocast_dtype("cuda")
        if torch.is_autocast_enabled():
            x = x.to(dtype=autocast_dtype)
        weight1_og = weight1
        weight2_og = weight2
        if torch.is_autocast_enabled():
            weight1 = weight1.to(dtype=autocast_dtype)
            weight2 = weight2.to(dtype=autocast_dtype)
        batch_shape = x.shape[:-1]
        x = x.reshape(-1, x.shape[-1])
        # don't need preact if not computing gradient
        store_preact = needs_input_grad or needs_weight1_grad or needs_weight2_grad
        # (batch, inter_dim) & (batch, 2 * inter_dim)
        y, preact = gemm_swiglu(x, weight1.T, store_preact=store_preact)
        # out = F.linear(y, weight2)
        out = gemm(y, weight2.T)
        if not needs_input_grad:
            weight1, weight1_og = None, None
        if not needs_weight1_grad:
            x = None
        if not needs_input_grad and not needs_weight1_grad and not needs_weight2_grad:
            weight2, weight2_og = None, None
            preact = None
        ctx.save_for_backward(
            x,
            preact,
            weight1,
            weight2,
            *((weight1_og, weight2_og) if fuse_grad_accum else (None, None)),
        )
        ctx.fuse_grad_accum = fuse_grad_accum
        return out.reshape(*batch_shape, out.shape[-1])

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, dout):
        """
        dout: (..., out_features)
        """
        if not torch.compiler.is_dynamo_compiling():
            assert dout.stride(-1) == 1
        # weight1_og and weight2_og are None if not ctx.fused_grad_accum
        x, preact, weight1, weight2, weight1_og, weight2_og = ctx.saved_tensors
        batch_shape = dout.shape[:-1]
        dout = dout.reshape(-1, dout.shape[-1])
        if (
            not ctx.needs_input_grad[0]
            and not ctx.needs_weight1_grad[0]
            and not ctx.needs_weight2_grad[0]
        ):
            return (None,) * 4
        assert preact is not None
        # (batch, 2 * inter_dim) and (batch, inter_dim)
        # dpreact, y = gemm_dswiglu(dout, weight2, preact)
        dpreact, y = gemm_dswiglu(dout, weight2, preact, sm_carveout=16)
        if ctx.needs_input_grad[2]:
            # fuse_grad_accum is not compatible with torch.compile
            if not ctx.fuse_grad_accum or weight2_og.grad is None or torch.compiler.is_compiling():
                dweight2 = gemm_cb(dout.T, y, out_dtype=ctx.weight2_dtype)
                # dweight2 = gemm_cb(dout.T, y, out_dtype=ctx.weight2_dtype, sm_carveout=16)
            else:
                # print("Using fuse grad accum in MLP 2", dout.shape, y.shape, weight2_og.grad.shape)
                gemm_add_cb_(dout.T, y, weight2_og.grad)
                # gemm_add_cb_(dout.T, y, weight2_og.grad, sm_carveout=16)
                dweight2 = weight2_og.grad
                weight2_og.grad = (
                    None  # So that pytorch doesn't add dweight to weight2_og.grad again
                )
        else:
            dweight2 = None
        if ctx.needs_input_grad[0]:
            dx = dpreact @ weight1  # (batch, in_features)
            # dx = gemm(dpreact, weight1)  # (batch, in_features)
            dx = dx.reshape(*batch_shape, dx.shape[-1])
        else:
            dx = None
        if ctx.needs_input_grad[1]:
            # fuse_grad_accum is not compatible with torch.compile
            if not ctx.fuse_grad_accum or weight1_og.grad is None or torch.compiler.is_compiling():
                dweight1 = gemm_cb(dpreact.T, x, out_dtype=ctx.weight1_dtype)
            else:
                # print("Using fuse grad accum in MLP 1", dpreact.shape, x.shape, weight1_og.grad.shape)
                gemm_add_cb_(dpreact.T, x, weight1_og.grad)
                dweight1 = weight1_og.grad
                weight1_og.grad = (
                    None  # So that pytorch doesn't add dweight to weight1_og.grad again
                )
        else:
            dweight1 = None
        return dx, dweight1, dweight2, None


def mlp_swiglu_func(x, weight1, weight2, fuse_grad_accum=False):
    return MLPSwiGLUFunc.apply(x, weight1, weight2, fuse_grad_accum)


class MLPSwiGLU(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=False,
        bias2=False,
        multiple_of=128,
        device=None,
        dtype=None,
        fuse_grad_accum: bool = False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.fc1.weight._muon_reshape_functions = (
            lambda w: rearrange(w, "(d two) e -> two d e", two=2),
            lambda w: rearrange(w, "two d e -> (d two) e"),
        )
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
        self.fuse_grad_accum = fuse_grad_accum

    def forward(self, input: Tensor) -> Tensor:
        if (
            self.fc1.bias is None
            and self.fc2.bias is None
            and input.is_cuda
            and input.stride(-1) == 1
            and self.fc1.in_features % 8 == 0
            and self.fc1.out_features % 16 == 0
            and self.fc2.out_features % 8 == 0
        ):
            return mlp_swiglu_func(
                input,
                self.fc1.weight,
                self.fc2.weight,
                fuse_grad_accum=self.fuse_grad_accum,
            )
        else:
            y = self.fc1(input)
            return self.fc2(F.silu(y[..., ::2]) * y[..., 1::2])


class MLPSwiGLURef(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=False,
        bias2=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, input: Tensor) -> Tensor:
        y = self.fc1(input)
        y1, y2 = y.chunk(2, dim=-1)
        return self.fc2(F.silu(y1) * y2)
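The forward docstring above calls out that MLPSwiGLU gates on interleaved (even/odd) columns of the fc1 output, while MLPSwiGLURef uses the usual chunked layout. The distinction is easy to see in plain PyTorch (illustration only; no quack kernels involved):

import torch
import torch.nn.functional as F

y = torch.randn(2, 8)  # stand-in for the fc1 pre-activation

interleaved = F.silu(y[..., ::2]) * y[..., 1::2]  # what MLPSwiGLU computes
y1, y2 = y.chunk(2, dim=-1)
chunked = F.silu(y1) * y2                         # what MLPSwiGLURef computes
# Both are (2, 4), but they only agree if fc1's output rows are reordered
# between the two conventions.
print(interleaved.shape, chunked.shape)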
quack/pipeline.py
ADDED
@@ -0,0 +1,166 @@
# Copyright (c) 2025, Tri Dao.

from typing import Optional
from dataclasses import dataclass

import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, Int32, if_generate
from cutlass.pipeline import CooperativeGroup, PipelineOp, pipeline_init_wait
from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType

from cutlass.cutlass_dsl import dsl_user_op
from cutlass._mlir.dialects import nvvm


@dsl_user_op
def cp_async_mbarrier_arrive_shared(
    mbar_ptr: cute.Pointer, noinc: bool = False, *, loc=None, ip=None
) -> None:
    nvvm.cp_async_mbarrier_arrive_shared(
        mbar_ptr.llvm_ptr,
        noinc=noinc,
        loc=loc,
        ip=ip,
    )


class PipelineStateWAdvance(PipelineState):
    def advance_iters(self, num_iterations: Int32):
        self._count += Int32(num_iterations)
        new_index = self._index + Int32(num_iterations)
        # How many times did we cross the stages boundary
        num_crossings = new_index // self.stages
        self._phase ^= num_crossings
        self._index = new_index % self.stages

    # This can be overridden by derived classes
    def __new_from_mlir_values__(self, values):
        return PipelineStateWAdvance(
            self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
        )


def make_pipeline_state(type: PipelineUserType, stages: int):
    """
    Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
    """
    if type is PipelineUserType.Producer:
        return PipelineStateWAdvance(
            stages,
            Int32(0),
            Int32(0),
            Int32(1),
        )
    elif type is PipelineUserType.Consumer:
        return PipelineStateWAdvance(
            stages,
            Int32(0),
            Int32(0),
            Int32(0),
        )
    else:
        assert False, "Error: invalid PipelineUserType specified for make_pipeline_state."


@dataclass(frozen=True)
class PipelineTmaCpAsync(PipelineTmaAsync):
    """
    PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers
    """

    @staticmethod
    def create(
        *,
        num_stages: int,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        tx_count: int,
        barrier_storage: cute.Pointer = None,
        cta_layout_vmnk: Optional[cute.Layout] = None,
        tidx: Optional[Int32] = None,
    ):
        """
        This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: Int32
        :param producer_group: CooperativeGroup for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: CooperativeGroup for the consumer agent
        :type consumer_group: CooperativeGroup
        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
        :type tx_count: int
        :param cta_layout_vmnk: Layout of the cluster shape
        :type cta_layout_vmnk: cute.Layout | None
        :param tidx: thread index to consumer async threads
        :type tidx: Int32 | None
        """
        if not isinstance(barrier_storage, cute.Pointer):
            raise ValueError(
                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
            )

        producer_type = PipelineOp.TmaLoad
        consumer_type = PipelineOp.AsyncThread

        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)

        sync_object_full = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8), num_stages, producer, tx_count
        )
        sync_object_empty = PipelineAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )
        if tidx is None:
            tidx, _, _ = cute.arch.thread_idx()
        if cta_layout_vmnk is None:
            cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
        (
            dst_rank,
            is_signalling_thread,
        ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
            dst_rank = None
        else:
            dst_rank = dst_rank

        producer_mask = None

        pipeline_init_wait(cta_layout_vmnk)

        return PipelineTmaCpAsync(
            sync_object_full,
            sync_object_empty,
            num_stages,
            producer_mask,
            dst_rank,
            is_signalling_thread,
        )

    def producer_acquire(
        self,
        state: PipelineState,
        try_acquire_token: Optional[Boolean] = None,
        is_tma_warp: Optional[Boolean] = True,
    ):
        """
        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier.
        """
        if_generate(
            try_acquire_token is None or try_acquire_token == 0,
            lambda: self.sync_object_empty.wait(state.index, state.phase),
        )
        # This is the difference between this and PipelineTmaAsync: we could have multiple
        # warps calling this, but only 1 warp should do the arrive on the full barrier
        if_generate(
            is_tma_warp,
            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
        )

    def producer_commit(self, state: PipelineState):
        """
        We need the mbarrier to track the completion of cp.async
        """
        cp_async_mbarrier_arrive_shared(self.producer_get_barrier(state), noinc=True)
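The index/phase update in PipelineStateWAdvance.advance_iters can be checked with a plain-Python model of the same arithmetic (integers only, without the cutlass Int32 wrappers; the helper name is for illustration):

def advance_iters(index: int, phase: int, stages: int, num_iterations: int):
    new_index = index + num_iterations
    num_crossings = new_index // stages  # times we crossed the stage boundary
    return new_index % stages, phase ^ num_crossings

# A 4-stage pipeline at index 2, phase 0, advanced by 3 iterations wraps once:
# the index becomes 1 and the phase bit flips.
print(advance_iters(2, 0, 4, 3))  # (1, 1)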
quack/sort/bitonic_sort.py
ADDED
@@ -0,0 +1,126 @@
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.

import math
from typing import Optional

import cutlass
import cutlass.cute as cute

import quack.utils as utils
from quack.sort.utils import compare_and_swap
from quack.sort.sorting_networks import optimal_sort


@cute.jit
def bitonic_merge(
    arr: cute.Tensor,
    n: cutlass.Constexpr[int],
    start: cutlass.Constexpr[int],
    ascending: cutlass.Constexpr[bool] = True,
) -> None:
    """Merge a bitonic sequence into a sorted sequence using iterative approach."""
    if cutlass.const_expr(n > 1):
        num_levels = int(math.log2(n))
        assert n == 2**num_levels, "n must be a power of 2"
        # This one must be range_constexpr otherwise it's very slow for n = 128
        for level in cutlass.range_constexpr(num_levels):
            length = n >> level  # n // (2^level)
            step = length // 2
            for i in cutlass.range(n // length, unroll_full=True):
                start_i = start + i * length
                for j in cutlass.range(step, unroll_full=True):
                    compare_and_swap(arr, start_i + j, start_i + j + step, ascending)


@cute.jit
def bitonic_sort(
    arr: cute.Tensor,
    n: Optional[cutlass.Constexpr[int]] = None,
    start: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = True,
) -> None:
    """
    Bitonic sort for small arrays of size N (power of 2, N <= 128).

    Args:
        arr: Array to sort
        n: Size of array (must be power of 2 and <= 128)
        start: Starting index (default 0)
        ascending: Sort in ascending order (default True)
    """
    if cutlass.const_expr(n is None):
        n = cute.size(arr.shape)
    assert n <= 128
    if cutlass.const_expr(n > 1):
        if cutlass.const_expr(n in [2, 4, 8, 16, 32, 64]):
            optimal_sort(arr, n, start, ascending)
        else:  # Fall back to bitonic sort
            assert n % 2 == 0
            # Sort first half in ascending order
            bitonic_sort(arr, n // 2, start, True)
            # Sort second half in descending order
            bitonic_sort(arr, n // 2, start + n // 2, False)
            # Merge the whole sequence
            bitonic_merge(arr, n, start, ascending)


@cute.jit
def bitonic_topk_merge(
    arr0: cute.Tensor,
    arr1: cute.Tensor,
    k: Optional[cutlass.Constexpr[int]] = None,
    start0: cutlass.Constexpr[int] = 0,
    start1: cutlass.Constexpr[int] = 0,
    ascending: cutlass.Constexpr[bool] = False,
) -> None:
    if cutlass.const_expr(k is None):
        k = cute.size(arr0.shape)
    if cutlass.const_expr(arr0.element_type == cutlass.Float32):
        minmax_fn = utils.fmin if ascending else cute.arch.fmax
    else:
        minmax_fn = min if ascending else max
    # Write the top k elements to the first half of the array
    for i in cutlass.range(k, unroll_full=True):
        arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
    # Now the 1st half is bitonic, we just need to merge it
    bitonic_merge(arr0, k, start0, ascending)


@cute.jit
def bitonic_topk(
    arr: cute.Tensor,
    k: cutlass.Constexpr[int],
    ascending: cutlass.Constexpr[bool] = False,
    warp_width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Tensor:
    """
    Bitonic top-k for small arrays of size N (power of 2, N <= 128).

    Args:
        arr: Array to sort
        k: must be power of 2 and <= 128
        ascending: Sort in ascending order (default False)
    """
    assert arr.element_type in [cutlass.Float32, cutlass.Int32]
    n = cute.size(arr.shape)
    assert k == 1 << int(math.log2(k)), "k must be a power of 2"
    assert n % k == 0, "n must be divisible by k"
    topk_vals = cute.make_fragment(k, arr.element_type)
    for v in cutlass.range(k, unroll_full=True):
        topk_vals[v] = arr[v]
    bitonic_sort(topk_vals, ascending=ascending)
    other_vals = cute.make_fragment(k, arr.element_type)
    for i in cutlass.range(1, n // k, unroll_full=True):
        for v in cutlass.range(k, unroll_full=True):
            other_vals[v] = arr[i * k + v]
        bitonic_sort(other_vals, ascending=ascending)
        # Merge 2 sorted top-k sequences to get a new top-k sequence
        bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
    # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps
    # do duplicate work.
    for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
        other_vals = cute.make_fragment(k, arr.element_type)
        for v in cutlass.range(k, unroll_full=True):
            other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
        bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
    return topk_vals
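bitonic_topk_merge above relies on a standard identity: if a and b are each length-k and sorted descending, then max(a[i], b[k-1-i]) is a bitonic sequence that contains the k largest elements of the combined input, so a single bitonic merge finishes the job. A plain-Python illustration of that identity (not the CuTe DSL code; the helper name is hypothetical):

def topk_merge_reference(a, b):
    k = len(a)
    merged = [max(a[i], b[k - 1 - i]) for i in range(k)]  # bitonic, holds the top-k
    return sorted(merged, reverse=True)  # stands in for the final bitonic_merge

a = sorted([9, 4, 7, 1], reverse=True)          # [9, 7, 4, 1]
b = sorted([8, 3, 6, 5], reverse=True)          # [8, 6, 5, 3]
print(topk_merge_reference(a, b))               # [9, 8, 7, 6]
print(sorted(a + b, reverse=True)[:4])          # [9, 8, 7, 6]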