quack-kernels 0.2.2-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/softmax.py
CHANGED
@@ -1,14 +1,20 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
 import math
-import torch
 from typing import Type
+from functools import partial
+
+import torch
 
 import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
-from cutlass
+from cutlass import Int64, Float32, const_expr
 
 import quack.utils as utils
+import quack.copy_utils as copy_utils
+from quack.compile_utils import make_fake_tensor as fake_tensor
 from quack.reduce import row_reduce, online_softmax_reduce
 from quack.reduction_base import ReductionBase
 from quack.cute_dsl_utils import torch2cute_dtype_map
@@ -21,45 +27,28 @@ class Softmax(ReductionBase):
             dtype,
             N,
             stage=2 if not online_softmax else 1,
-            reduction_dtype=
+            reduction_dtype=Float32 if not online_softmax else Int64,
         )
         self.online_softmax = online_softmax
 
-    def
+    def _threads_per_row(self):
         N = self.N
-            16
-            if N <= 128
-            else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
-            )
-        )
+        for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
+            if N <= limit:
+                return threads
+        return 256
 
     def _set_cluster_n(self):
         N = self.N
-        if
-            )
-        else: # fp32
-            cluster_n = (
-                1
-                if N <= 32 * 1024
-                else (
-                    2
-                    if N <= 64 * 1024
-                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
-                )
-            )
-        self.cluster_n = cluster_n
+        if const_expr(self.dtype.width == 16):
+            thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
+        else:
+            thresholds = [(32 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
+        for limit, cluster in thresholds:
+            if N <= limit:
+                self.cluster_n = cluster
+                return
+        self.cluster_n = 16
 
     @cute.jit
     def __call__(
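Both `_threads_per_row` and `_set_cluster_n` now walk a small (limit, value) table instead of nesting conditional expressions. A minimal sketch of that lookup pattern in plain Python, outside the cute DSL (the table values are the ones from the forward kernel above; the standalone function name is illustrative):

```python
def threads_per_row(N: int) -> int:
    # Walk the (upper bound on N, thread count) table in increasing order and
    # return the first bucket that fits; fall back to the widest configuration.
    for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
        if N <= limit:
            return threads
    return 256

assert threads_per_row(1024) == 32
assert threads_per_row(65536) == 256
```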
@@ -69,16 +58,16 @@ class Softmax(ReductionBase):
         stream: cuda.CUstream,
     ):
         assert mX.element_type == self.dtype
-        assert mO.element_type == self.dtype
         self._set_cluster_n()
+        largest_dtype_width = const_expr(max(t.element_type.width for t in [mX, mO]))
+        tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(
+            vecsize=128 // largest_dtype_width
+        )
+        num_threads = tiled_copy.size
+        self.kernel(mX, mO, tiler_mn, tiled_copy, threads_per_row).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if
-            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
             stream=stream,
         )
 
@@ -87,23 +76,20 @@ class Softmax(ReductionBase):
         self,
         mX: cute.Tensor,
         mO: cute.Tensor,
-        tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
+        tiled_copy: cute.TiledCopy,
+        threads_per_row: cutlass.Constexpr[int],
     ):
+        tv_layout = tiled_copy.layout_tv_tiled
+
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
-        if
-            cluster_y = cute.arch.block_idx()[1]
-        else:
-            cluster_y = cutlass.const_expr(0)
+        cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
 
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
-        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
-        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+        gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
@@ -111,52 +97,45 @@ class Softmax(ReductionBase):
         )
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
-        )
-        copy_atom_store_O = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128
-        )
-
-        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_X = tiled_copy.get_slice(tidx)
 
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
-        tXgO =
+        tXgO = thr_copy_X.partition_D(gO)
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
-        # allocate fragments for gmem->rmem
         tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
 
-        self._initialize_cluster(tidx, mbar_ptr, num_warps)
-
-        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
         tXpX = (
+            None
+            if is_even_N
+            else copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
         )
+        # Each copy will use the same predicate
+        copy = partial(copy_utils.copy, pred=tXpX)
+
+        num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
         if tXcX[0][0] < shape[0]:
+            copy(tXgX, tXsX, is_async=True)
         cute.arch.cp_async_commit_group()
         cute.arch.cp_async_wait_group(0)
         # Fill OOB values with -inf
-        if
+        if const_expr(not is_even_N):
             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
 
         cute.autovec_copy(tXsX, tXrX)
         x = tXrX.load().to(cute.Float32)
-
-        if cutlass.const_expr(not self.online_softmax):
+        if const_expr(not self.online_softmax):
             max_x = row_reduce(
                 x,
                 cute.ReductionOp.MAX,
                 threads_per_row,
                 reduction_buffer[None, None, 0],
-                mbar_ptr + 0 if
-                init_val=-
-                hook_fn=cute.arch.cluster_wait if
+                mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
+                init_val=-Float32.inf,
+                hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
             )
             log2_e = math.log2(math.e)
             exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
@@ -165,7 +144,7 @@ class Softmax(ReductionBase):
                 cute.ReductionOp.ADD,
                 threads_per_row,
                 reduction_buffer[None, None, 1],
-                mbar_ptr + 1 if
+                mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
                 init_val=0.0,
             )
         else:
@@ -174,18 +153,14 @@ class Softmax(ReductionBase):
                 threads_per_row,
                 reduction_buffer[None, None, 0],
                 mbar_ptr,
-                hook_fn=cute.arch.cluster_wait if
+                hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
                 return_exp_x=True,
             )
-        y = exp_x * (1.0 / denom)
+        # y = exp_x * (1.0 / denom)
+        y = exp_x * cute.arch.rcp_approx(denom)
         tXrO.store(y.to(tXrO.element_type))
-        tOpO = (
-            utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
-            if cutlass.const_expr(not is_even_N)
-            else None
-        )
         if tXcX[0][0] < shape[0]:
+            copy(tXrO, tXgO)
 
 
 @torch.library.custom_op("quack::_softmax_fwd", mutates_args={"out"})
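For reference, the quantity the forward kernel computes per row is the standard numerically stable softmax; the kernel expresses it through `exp2` with a `log2(e)` scale and an approximate reciprocal of the denominator. A rough PyTorch equivalent of the same math (illustrative only, not part of the package):

```python
import math
import torch

def softmax_fwd_reference(x: torch.Tensor) -> torch.Tensor:
    # Subtract the row max for stability, exponentiate via exp2(x * log2(e)),
    # then normalize by the row sum -- the same formulation as the kernel,
    # minus the tiling, cp.async staging, and cluster reductions.
    xf = x.float()
    max_x = xf.amax(dim=-1, keepdim=True)
    exp_x = torch.exp2((xf - max_x) * math.log2(math.e))
    return (exp_x / exp_x.sum(dim=-1, keepdim=True)).to(x.dtype)
```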
@@ -200,21 +175,21 @@ def _softmax_fwd(x: torch.Tensor, out: torch.Tensor) -> None:
     assert x.is_cuda, "Tensor must be on CUDA device"
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     N = x.size(1)
-    dtype = torch2cute_dtype_map[
-        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
-            mode=0, stride_order=(0, 1)
-        )
-    )
-    x_tensor, out_tensor = [convert_from_dlpack(tensor) for tensor in (x, out)]
-    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-    compile_key = (dtype, N)
+    dtype, out_dtype = [torch2cute_dtype_map[t.dtype] for t in [x, out]]
+    compile_key = (dtype, out_dtype, N)
     if compile_key not in _softmax_fwd.compile_cache:
+        batch_sym = cute.sym_int()
+        div = math.gcd(128 // dtype.width, N)
+        x_cute, out_cute = [fake_tensor(dt, (batch_sym, N), div) for dt in [dtype, out_dtype]]
         softmax_op = Softmax(dtype, N)
         _softmax_fwd.compile_cache[compile_key] = cute.compile(
-            softmax_op,
+            softmax_op,
+            x_cute,
+            out_cute,
+            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+            options="--enable-tvm-ffi",
         )
-    _softmax_fwd.compile_cache[compile_key](
+    _softmax_fwd.compile_cache[compile_key](x, out)
 
 
 _softmax_fwd.compile_cache = {}
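`_softmax_fwd` now keys its compile cache on both input and output dtypes plus N, and compiles against fake tensors with a symbolic batch dimension so one compilation serves any batch size. The caching itself is the usual memoize-by-key shape; a generic sketch of that pattern in plain Python (names are hypothetical, not the quack API):

```python
from typing import Any, Callable, Dict, Hashable, Tuple

_compile_cache: Dict[Tuple[Hashable, ...], Callable[..., Any]] = {}

def get_or_compile(key: Tuple[Hashable, ...],
                   build: Callable[[], Callable[..., Any]]) -> Callable[..., Any]:
    # Compile once per (in_dtype, out_dtype, N)-style key, then reuse the
    # compiled kernel for every later call that maps to the same key.
    if key not in _compile_cache:
        _compile_cache[key] = build()
    return _compile_cache[key]
```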
@@ -229,55 +204,30 @@ def softmax_fwd(x: torch.Tensor) -> torch.Tensor:
 class SoftmaxBackward(ReductionBase):
     def __init__(self, dtype: Type[cutlass.Numeric], N: int):
         # 1 stage for computing dot product
-        super().__init__(dtype, N, stage=1, reduction_dtype=
+        super().__init__(dtype, N, stage=1, reduction_dtype=Float32)
 
-    def
+    def _threads_per_row(self):
         N = self.N
-            16
-            if N <= 128
-            else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 8192 else 256)))
-            )
-        )
+        for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (8192, 128)]:
+            if N <= limit:
+                return threads
+        return 256
 
     def _set_cluster_n(self):
         N = self.N
-        if
-        cluster_n = (
-            1
-            if N <= 16 * 1024
-            else (
-                2
-                if N <= 32 * 1024
-                else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
-            )
-        )
-        self.cluster_n = cluster_n
-
-    def _get_num_threads(self):
+        if const_expr(self.dtype.width == 16):
+            thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
+        else:
+            thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
+        for limit, cluster in thresholds:
+            if N <= limit:
+                self.cluster_n = cluster
+                return
+        self.cluster_n = 16
+
+    def _num_threads(self):
         return 128 if self.N <= 8192 else 256
 
-    def _smem_size_in_bytes(self, tiler_mn, num_warps):
-        return (
-            # Multiply by 2 since we need space for Y and dY
-            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
-            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
-            + self.stage * (cutlass.Int64.width // 8)
-        )
-
     @cute.jit
     def __call__(
         self,
@@ -287,17 +237,16 @@ class SoftmaxBackward(ReductionBase):
         stream: cuda.CUstream,
     ):
         assert mdY.element_type == self.dtype
-        assert mY.element_type == self.dtype
-        assert mdX.element_type == self.dtype
         self._set_cluster_n()
+        largest_dtype_width = const_expr(max(t.element_type.width for t in [mdY, mY, mdX]))
+        tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(
+            vecsize=128 // largest_dtype_width
+        )
+        num_threads = tiled_copy.size
+        self.kernel(mdY, mY, mdX, tiler_mn, tiled_copy, threads_per_row).launch(
             grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if
-            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
             stream=stream,
         )
 
@@ -307,24 +256,21 @@ class SoftmaxBackward(ReductionBase):
         mdY: cute.Tensor,
         mY: cute.Tensor,
         mdX: cute.Tensor,
-        tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
+        tiled_copy: cute.TiledCopy,
+        threads_per_row: cutlass.Constexpr[int],
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
-        if
-        else:
-            cluster_y = cutlass.const_expr(0)
+        cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
+        tv_layout = tiled_copy.layout_tv_tiled
 
         shape = mdY.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
+        gdY, gY, gdX, cX = [
+            cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
         ]
-        gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)]
-        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 
         smem = cutlass.utils.SmemAllocator()
         sdY = smem.allocate_tensor(
@@ -335,42 +281,32 @@ class SoftmaxBackward(ReductionBase):
         )
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
-        copy_atom_load = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mdY.element_type, num_bits_per_copy=128
-        )
-        copy_atom_store = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
-        )
+        thr_copy = tiled_copy.get_slice(tidx)
 
-        tYsY = thr_copy_load.partition_D(sY)
-        tdXgdX = thr_copy_store.partition_D(gdX)
-        tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
-
-        # allocate fragments for gmem->rmem
+        tdYgdY = thr_copy.partition_S(gdY)
+        tdYsdY = thr_copy.partition_D(sdY)
+        tYgY = thr_copy.partition_S(gY)
+        tYsY = thr_copy.partition_D(sY)
+        tdXgdX = thr_copy.partition_D(gdX)
+        tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
         tdYrdY, tYrY, tdXrdX = [cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)]
 
-        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
-        tdYpdY = (
-            utils.predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
-            if cutlass.const_expr(not is_even_N)
-            else None
+        is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tXpX = (
+            None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
        )
+        # Each copy will use the same predicate
+        copy = partial(copy_utils.copy, pred=tXpX)
+
+        num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
 
         if tXcX[0][0] < shape[0]:
+            copy(tdYgdY, tdYsdY, is_async=True)
+            copy(tYgY, tYsY, is_async=True)
         cute.arch.cp_async_commit_group()
         cute.arch.cp_async_wait_group(0)
+        # Don't need fill_oob since cp.async will automatically fills OOB elements with zeros
 
         cute.autovec_copy(tdYsdY, tdYrdY)
         cute.autovec_copy(tYsY, tYrY)
@@ -378,27 +314,21 @@ class SoftmaxBackward(ReductionBase):
         y = tYrY.load().to(cute.Float32)
 
         # Compute dot product: dot = Σⱼ dy_j × y_j
-        threads_per_row = tv_layout.shape[0][0]
         dot = row_reduce(
             dy * y,
             cute.ReductionOp.ADD,
             threads_per_row,
             reduction_buffer[None, None, 0],
-            mbar_ptr if
+            mbar_ptr if const_expr(self.cluster_n > 1) else None,
             init_val=0.0,
-            hook_fn=cute.arch.cluster_wait if
+            hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
         )
 
         # Compute gradient: dx_i = y_i × (dy_i - dot)
         dx = y * (dy - dot)
         tdXrdX.store(dx.to(tdXrdX.element_type))
-        tdXpdX = (
-            utils.predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
-            if cutlass.const_expr(not is_even_N)
-            else None
-        )
         if tXcX[0][0] < shape[0]:
+            copy(tdXrdX, tdXgdX)
 
 
 @torch.library.custom_op("quack::_softmax_backward", mutates_args={"dx"})
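The backward kernel's math is the usual softmax Jacobian-vector product: one row-wise dot product followed by an elementwise update. A rough PyTorch reference of the same identity (illustrative only, not the package API):

```python
import torch

def softmax_bwd_reference(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # dot = sum_j dy_j * y_j per row, then dx_i = y_i * (dy_i - dot),
    # matching the kernel's row_reduce(dy * y) followed by y * (dy - dot).
    dyf, yf = dy.float(), y.float()
    dot = (dyf * yf).sum(dim=-1, keepdim=True)
    return (yf * (dyf - dot)).to(y.dtype)
```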
@@ -418,22 +348,24 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor) -> None:
     assert y.dtype == dy.dtype, "dy and y must have same dtype"
 
     N = dy.size(1)
-    dtype = torch2cute_dtype_map[
-        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
-            mode=0, stride_order=(0, 1)
-        )
-    )
-    dy_tensor, y_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (dy, y, dx)]
-    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
-    compile_key = (dtype, N)
+    dtype, y_dtype, dx_dtype = [torch2cute_dtype_map[t.dtype] for t in [dy, y, dx]]
+    compile_key = (dtype, y_dtype, dx_dtype, N)
     if compile_key not in _softmax_backward.compile_cache:
+        batch_sym = cute.sym_int()
+        div = math.gcd(128 // dtype.width, N)
+        dy_cute, y_cute, dx_cute = [
+            fake_tensor(dt, (batch_sym, N), div) for dt in [dtype, y_dtype, dx_dtype]
+        ]
         softmax_backward_op = SoftmaxBackward(dtype, N)
         _softmax_backward.compile_cache[compile_key] = cute.compile(
-            softmax_backward_op,
+            softmax_backward_op,
+            dy_cute,
+            y_cute,
+            dx_cute,
+            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+            options="--enable-tvm-ffi",
         )
-    _softmax_backward.compile_cache[compile_key](
+    _softmax_backward.compile_cache[compile_key](dy, y, dx)
 
 
 _softmax_backward.compile_cache = {}
quack/sort/bitonic_sort.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Optional
 
 import cutlass
 import cutlass.cute as cute
+from cutlass import Int32, Float32, const_expr
 
 import quack.utils as utils
 from quack.sort.utils import compare_and_swap
@@ -14,12 +15,14 @@ from quack.sort.sorting_networks import optimal_sort
 @cute.jit
 def bitonic_merge(
     arr: cute.Tensor,
-    n: cutlass.Constexpr[int],
-    start: cutlass.Constexpr[int],
+    n: Optional[cutlass.Constexpr[int]] = None,
+    start: cutlass.Constexpr[int] = 0,
     ascending: cutlass.Constexpr[bool] = True,
 ) -> None:
     """Merge a bitonic sequence into a sorted sequence using iterative approach."""
-    if
+    if const_expr(n is None):
+        n = cute.size(arr.shape)
+    if const_expr(n > 1):
         num_levels = int(math.log2(n))
         assert n == 2**num_levels, "n must be a power of 2"
         # This one must be range_constexpr otherwise it's very slow for n = 128
@@ -48,11 +51,11 @@ def bitonic_sort(
         start: Starting index (default 0)
         ascending: Sort in ascending order (default True)
     """
-    if
+    if const_expr(n is None):
         n = cute.size(arr.shape)
     assert n <= 128
-    if
-    if
+    if const_expr(n > 1):
+        if const_expr(n in [2, 4, 8, 16, 32, 64]):
             optimal_sort(arr, n, start, ascending)
         else:  # Fall back to bitonic sort
             assert n % 2 == 0
@@ -73,9 +76,9 @@ def bitonic_topk_merge(
     start1: cutlass.Constexpr[int] = 0,
     ascending: cutlass.Constexpr[bool] = False,
 ) -> None:
-    if
+    if const_expr(k is None):
         k = cute.size(arr0.shape)
-    if
+    if const_expr(arr0.element_type == Float32):
         minmax_fn = utils.fmin if ascending else cute.arch.fmax
     else:
         minmax_fn = min if ascending else max
@@ -101,7 +104,7 @@ def bitonic_topk(
         k: must be power of 2 and <= 128
         ascending: Sort in ascending order (default False)
     """
-    assert arr.element_type in [
+    assert arr.element_type in [Float32, Int32]
     n = cute.size(arr.shape)
     assert k == 1 << int(math.log2(k)), "k must be a power of 2"
     assert n % k == 0, "n must be divisible by k"
@@ -109,8 +112,8 @@ def bitonic_topk(
     for v in cutlass.range(k, unroll_full=True):
         topk_vals[v] = arr[v]
     bitonic_sort(topk_vals, ascending=ascending)
-    other_vals = cute.make_fragment(k, arr.element_type)
     for i in cutlass.range(1, n // k, unroll_full=True):
+        other_vals = cute.make_fragment(k, arr.element_type)
         for v in cutlass.range(k, unroll_full=True):
             other_vals[v] = arr[i * k + v]
         bitonic_sort(other_vals, ascending=ascending)
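`bitonic_topk` keeps a sorted running top-k and folds in one sorted block of k elements at a time (the change above just moves the `other_vals` fragment allocation inside that loop). A plain-Python sketch of the same block-wise top-k idea, using `sorted` in place of the in-register sorting networks (illustrative, not the kernel code):

```python
def blockwise_topk(vals, k, ascending=False):
    # Sort the first block of k, then for each later block: sort it, merge it
    # with the running top-k, and keep only the best k of the merged 2k values.
    assert len(vals) % k == 0, "n must be divisible by k"
    topk = sorted(vals[:k], reverse=not ascending)
    for i in range(1, len(vals) // k):
        block = sorted(vals[i * k:(i + 1) * k], reverse=not ascending)
        topk = sorted(topk + block, reverse=not ascending)[:k]
    return topk

assert blockwise_topk([3, 1, 4, 1, 5, 9, 2, 6], k=2) == [9, 6]
```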
quack/sort/utils.py
CHANGED
@@ -1,5 +1,5 @@
-import cutlass
 import cutlass.cute as cute
+from cutlass import Float32, const_expr
 
 import quack.utils as utils
 
@@ -9,12 +9,12 @@ def compare_and_swap(
     arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False
 ) -> None:
     """Compare and swap elements at indices i and j in ascending or descending order."""
-    if
+    if const_expr(use_selection):
         a, b = arr[i], arr[j]
         if (a > b) ^ (not ascending):
             arr[i] = b
             arr[j] = a
-        # if
+        # if const_expr(ascending):
         #     if a > b:
         #         arr[i] = b
         #         arr[j] = a
@@ -23,9 +23,9 @@ def compare_and_swap(
         #         arr[i] = b
         #         arr[j] = a
     else:
-        min_fn = min if
-        max_fn = max if
-        if
+        min_fn = min if const_expr(arr.element_type != Float32) else utils.fmin
+        max_fn = max if const_expr(arr.element_type != Float32) else cute.arch.fmax
+        if const_expr(ascending):
             arr[i], arr[j] = min_fn(arr[i], arr[j]), max_fn(arr[i], arr[j])
         else:
             arr[i], arr[j] = max_fn(arr[i], arr[j]), min_fn(arr[i], arr[j])