quack-kernels 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -1
- quack/cross_entropy.py +197 -166
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +211 -181
- quack/softmax.py +409 -156
- quack/utils.py +197 -39
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/METADATA +4 -1
- quack_kernels-0.1.3.dist-info/RECORD +11 -0
- quack_kernels-0.1.2.dist-info/RECORD +0 -10
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/cross_entropy.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import math
|
|
2
2
|
import torch
|
|
3
|
-
import
|
|
4
|
-
from typing import Callable, Union, Optional
|
|
3
|
+
from typing import Optional, Type
|
|
5
4
|
|
|
6
5
|
import cuda.bindings.driver as cuda
|
|
7
6
|
|
|
@@ -10,169 +9,192 @@ import cutlass.cute as cute
|
|
|
10
9
|
from cutlass.cute.runtime import from_dlpack
|
|
11
10
|
|
|
12
11
|
import quack.utils as utils
|
|
12
|
+
from quack.reduction_base import ReductionBase, torch2cute_dtype_map
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
|
|
16
|
-
def
|
|
17
|
-
mX: cute.Tensor, # (M, N)
|
|
18
|
-
mTarget: cute.Tensor, # (M,)
|
|
19
|
-
mLoss: cute.Tensor, # (M,)
|
|
20
|
-
mLSE: Optional[cute.Tensor], # (M,)
|
|
21
|
-
tv_layout: cute.Layout,
|
|
22
|
-
tiler_mn: cute.Shape,
|
|
23
|
-
cluster_n: cutlass.Constexpr = 1,
|
|
24
|
-
):
|
|
25
|
-
tidx, _, _ = cute.arch.thread_idx()
|
|
26
|
-
bidx, cluster_y, _ = cute.arch.block_idx()
|
|
27
|
-
|
|
28
|
-
shape: cute.Shape = mX.shape
|
|
29
|
-
idX = cute.make_identity_tensor(shape)
|
|
30
|
-
# slice for CTAs
|
|
31
|
-
gX, cX = [
|
|
32
|
-
cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
|
|
33
|
-
for mT in (mX, idX)
|
|
34
|
-
]
|
|
35
|
-
|
|
36
|
-
smem = cutlass.utils.SmemAllocator()
|
|
37
|
-
sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
|
|
38
|
-
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
39
|
-
warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
|
|
40
|
-
reduction_buffer_layout = cute.make_ordered_layout(
|
|
15
|
+
class CrossEntropy(ReductionBase):
|
|
16
|
+
def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
|
|
41
17
|
# 2 stages: 1 for max, 1 for sum
|
|
42
|
-
(
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
if cluster_n > 1:
|
|
63
|
-
if tidx < 2:
|
|
64
|
-
cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
|
|
65
|
-
cute.arch.mbarrier_init_fence()
|
|
66
|
-
if tidx < 2:
|
|
67
|
-
cute.arch.mbarrier_init_tx_bytes(mbar_ptr + tidx, num_warps * cluster_n * cutlass.Float32.width // 8)
|
|
68
|
-
# Cluster arrive after barrier init
|
|
69
|
-
cute.arch.cluster_arrive_relaxed()
|
|
70
|
-
|
|
71
|
-
row = tXcX[0][0]
|
|
72
|
-
target = cute.Int32.zero
|
|
73
|
-
if row < shape[0] and tXcX[0][1] == 0:
|
|
74
|
-
target = cute.Int32(mTarget[row])
|
|
75
|
-
|
|
76
|
-
is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * cluster_n)
|
|
77
|
-
tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
|
|
78
|
-
if row < shape[0]:
|
|
79
|
-
cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
|
|
80
|
-
cute.arch.cp_async_commit_group()
|
|
81
|
-
cute.arch.cp_async_wait_group(0)
|
|
82
|
-
cute.autovec_copy(tXsX, tXrX)
|
|
83
|
-
x = tXrX.load().to(cute.Float32)
|
|
84
|
-
# Fill OOB values with -inf
|
|
85
|
-
if cutlass.const_expr(not is_even_N):
|
|
86
|
-
tXrX_fp32 = cute.make_fragment_like(tXrX, cutlass.Float32)
|
|
87
|
-
tXrX_fp32.store(x)
|
|
88
|
-
for rest_v in range(tXpX.shape[0]):
|
|
89
|
-
for rest_k in range(tXpX.shape[2]):
|
|
90
|
-
if not tXpX[rest_v, 0, rest_k]:
|
|
91
|
-
tXrX_fp32[(None, rest_v), None, rest_k].fill(-cutlass.Float32.inf)
|
|
92
|
-
x = tXrX_fp32.load()
|
|
93
|
-
|
|
94
|
-
target_logit = cute.Float32.zero
|
|
95
|
-
if row < shape[0] and tXcX[0][1] == 0:
|
|
96
|
-
target_logit = cute.Float32(mX[row, target])
|
|
97
|
-
|
|
98
|
-
threads_per_row = tv_layout.shape[0][0]
|
|
99
|
-
max_x = utils.row_reduce(
|
|
100
|
-
x,
|
|
101
|
-
cute.ReductionOp.MAX,
|
|
102
|
-
threads_per_row,
|
|
103
|
-
reduction_buffer[None, None, 0],
|
|
104
|
-
mbar_ptr + 0 if cluster_n > 1 else None,
|
|
105
|
-
init_val=-cutlass.Float32.inf,
|
|
106
|
-
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
|
|
107
|
-
)
|
|
108
|
-
log2_e = math.log2(math.e)
|
|
109
|
-
# exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
|
|
110
|
-
exp_x = utils.exp2f((x - max_x) * log2_e) # a bit faster, idk why
|
|
111
|
-
denom = utils.row_reduce(
|
|
112
|
-
exp_x,
|
|
113
|
-
cute.ReductionOp.ADD,
|
|
114
|
-
threads_per_row,
|
|
115
|
-
reduction_buffer[None, None, 1],
|
|
116
|
-
mbar_ptr + 1 if cluster_n > 1 else None,
|
|
117
|
-
init_val=0.0,
|
|
118
|
-
)
|
|
18
|
+
super().__init__(
|
|
19
|
+
dtype,
|
|
20
|
+
N,
|
|
21
|
+
stage=2 if not online_softmax else 1,
|
|
22
|
+
reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
|
|
23
|
+
)
|
|
24
|
+
self.online_softmax = online_softmax
|
|
25
|
+
self.reload_from = None if N <= 16384 or online_softmax else "smem"
|
|
26
|
+
|
|
27
|
+
def _calculate_threads_per_row(self):
|
|
28
|
+
N = self.N
|
|
29
|
+
return (
|
|
30
|
+
8
|
|
31
|
+
if N <= 64
|
|
32
|
+
else (
|
|
33
|
+
16
|
|
34
|
+
if N <= 128
|
|
35
|
+
else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
|
|
36
|
+
)
|
|
37
|
+
)
|
|
119
38
|
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
39
|
+
def _set_cluster_n(self):
|
|
40
|
+
N = self.N
|
|
41
|
+
if cutlass.const_expr(self.dtype.width == 16):
|
|
42
|
+
cluster_n = (
|
|
43
|
+
1
|
|
44
|
+
if N <= 16 * 1024
|
|
45
|
+
else (
|
|
46
|
+
2
|
|
47
|
+
if N <= 32 * 1024
|
|
48
|
+
else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
|
|
49
|
+
)
|
|
50
|
+
)
|
|
51
|
+
else: # fp32
|
|
52
|
+
cluster_n = (
|
|
53
|
+
1
|
|
54
|
+
if N <= 16 * 1024
|
|
55
|
+
else (
|
|
56
|
+
2
|
|
57
|
+
if N <= 64 * 1024
|
|
58
|
+
else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
|
|
59
|
+
)
|
|
60
|
+
)
|
|
61
|
+
self.cluster_n = cluster_n
|
|
62
|
+
|
|
63
|
+
@cute.jit
|
|
64
|
+
def __call__(
|
|
65
|
+
self,
|
|
66
|
+
mX: cute.Tensor,
|
|
67
|
+
mTarget: cute.Tensor,
|
|
68
|
+
mLoss: cute.Tensor,
|
|
69
|
+
mLSE: Optional[cute.Tensor],
|
|
70
|
+
stream: cuda.CUstream,
|
|
71
|
+
):
|
|
72
|
+
assert mX.element_type == self.dtype
|
|
73
|
+
self._set_cluster_n()
|
|
74
|
+
tiler_mn, tv_layout = self._get_tv_layout()
|
|
75
|
+
num_threads = cute.size(tv_layout, mode=[0])
|
|
76
|
+
num_warps = num_threads // cute.arch.WARP_SIZE
|
|
77
|
+
self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
|
|
78
|
+
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
79
|
+
block=[num_threads, 1, 1],
|
|
80
|
+
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
|
|
81
|
+
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
82
|
+
stream=stream,
|
|
83
|
+
)
|
|
159
84
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
85
|
+
@cute.kernel
|
|
86
|
+
def kernel(
|
|
87
|
+
self,
|
|
88
|
+
mX: cute.Tensor, # (M, N)
|
|
89
|
+
mTarget: cute.Tensor, # (M,)
|
|
90
|
+
mLoss: cute.Tensor, # (M,)
|
|
91
|
+
mLSE: Optional[cute.Tensor], # (M,)
|
|
92
|
+
tv_layout: cute.Layout,
|
|
93
|
+
tiler_mn: cute.Shape,
|
|
94
|
+
):
|
|
95
|
+
tidx, _, _ = cute.arch.thread_idx()
|
|
96
|
+
bidx, cluster_y, _ = cute.arch.block_idx()
|
|
97
|
+
|
|
98
|
+
shape: cute.Shape = mX.shape
|
|
99
|
+
idX = cute.make_identity_tensor(shape)
|
|
100
|
+
# slice for CTAs
|
|
101
|
+
gX, cX = [
|
|
102
|
+
cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
|
|
103
|
+
for mT in (mX, idX)
|
|
104
|
+
]
|
|
105
|
+
|
|
106
|
+
smem = cutlass.utils.SmemAllocator()
|
|
107
|
+
sX = smem.allocate_tensor(
|
|
108
|
+
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
|
|
109
|
+
)
|
|
110
|
+
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
|
169
111
|
|
|
112
|
+
# declare the atoms which will be used later for memory copy
|
|
113
|
+
copy_atom_load_X = cute.make_copy_atom(
|
|
114
|
+
cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
|
|
115
|
+
)
|
|
116
|
+
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
|
|
117
|
+
|
|
118
|
+
#### Thread View
|
|
119
|
+
tXgX = thr_copy_X.partition_S(gX)
|
|
120
|
+
tXsX = thr_copy_X.partition_D(sX)
|
|
121
|
+
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
122
|
+
tXrX = cute.make_fragment_like(tXgX)
|
|
123
|
+
|
|
124
|
+
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
125
|
+
self._initialize_cluster(tidx, mbar_ptr, num_warps)
|
|
170
126
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
127
|
+
row = tXcX[0][0]
|
|
128
|
+
target = cute.Int32.zero
|
|
129
|
+
if row < shape[0] and tXcX[0][1] == 0:
|
|
130
|
+
target = cute.Int32(mTarget[row])
|
|
131
|
+
|
|
132
|
+
is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
|
|
133
|
+
tXpX = (
|
|
134
|
+
utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
|
|
135
|
+
)
|
|
136
|
+
if row < shape[0]:
|
|
137
|
+
cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
|
|
138
|
+
cute.arch.cp_async_commit_group()
|
|
139
|
+
cute.arch.cp_async_wait_group(0)
|
|
140
|
+
# Fill OOB values with -inf
|
|
141
|
+
if cutlass.const_expr(not is_even_N):
|
|
142
|
+
utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
|
|
143
|
+
cute.autovec_copy(tXsX, tXrX)
|
|
144
|
+
x = tXrX.load().to(cute.Float32)
|
|
145
|
+
|
|
146
|
+
target_logit = cute.Float32.zero
|
|
147
|
+
if row < shape[0] and tXcX[0][1] == 0:
|
|
148
|
+
target_logit = cute.Float32(mX[row, target])
|
|
149
|
+
|
|
150
|
+
threads_per_row = tv_layout.shape[0][0]
|
|
151
|
+
if cutlass.const_expr(not self.online_softmax):
|
|
152
|
+
max_x = utils.row_reduce(
|
|
153
|
+
x,
|
|
154
|
+
cute.ReductionOp.MAX,
|
|
155
|
+
threads_per_row,
|
|
156
|
+
reduction_buffer[None, None, 0],
|
|
157
|
+
mbar_ptr + 0 if self.cluster_n > 1 else None,
|
|
158
|
+
init_val=-cutlass.Float32.inf,
|
|
159
|
+
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
160
|
+
)
|
|
161
|
+
if cutlass.const_expr(self.reload_from == "smem"):
|
|
162
|
+
cute.autovec_copy(tXsX, tXrX)
|
|
163
|
+
x = tXrX.load().to(cute.Float32)
|
|
164
|
+
log2_e = math.log2(math.e)
|
|
165
|
+
# exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
|
|
166
|
+
# a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
|
|
167
|
+
# exp_x = utils.exp2f((x - max_x) * log2_e)
|
|
168
|
+
# This would use ffma instead of fadd then fmul
|
|
169
|
+
exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
|
|
170
|
+
denom = utils.row_reduce(
|
|
171
|
+
exp_x,
|
|
172
|
+
cute.ReductionOp.ADD,
|
|
173
|
+
threads_per_row,
|
|
174
|
+
reduction_buffer[None, None, 1],
|
|
175
|
+
mbar_ptr + 1 if self.cluster_n > 1 else None,
|
|
176
|
+
init_val=0.0,
|
|
177
|
+
)
|
|
178
|
+
else:
|
|
179
|
+
max_x, denom, _ = utils.online_softmax_reduce(
|
|
180
|
+
x,
|
|
181
|
+
threads_per_row,
|
|
182
|
+
reduction_buffer[None, None, 0],
|
|
183
|
+
mbar_ptr,
|
|
184
|
+
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
if (
|
|
188
|
+
tXcX[0][1] == 0
|
|
189
|
+
and row < shape[0]
|
|
190
|
+
and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
|
|
191
|
+
):
|
|
192
|
+
ln_2 = math.log(2.0)
|
|
193
|
+
lse = max_x + utils.log2f(denom) * ln_2
|
|
194
|
+
loss_val = lse - target_logit
|
|
195
|
+
mLoss[row] = loss_val.to(mLoss.element_type)
|
|
196
|
+
if cutlass.const_expr(mLSE is not None):
|
|
197
|
+
mLSE[row] = lse
|
|
176
198
|
|
|
177
199
|
|
|
178
200
|
def cross_entropy(
|
|
@@ -194,27 +216,36 @@ def cross_entropy(
|
|
|
194
216
|
assert x.shape[0] == target.shape[0], "Batch dimensions must match"
|
|
195
217
|
assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
|
|
196
218
|
assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
|
|
197
|
-
assert target.dtype
|
|
219
|
+
assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
|
|
198
220
|
M, N = x.shape
|
|
199
221
|
device = x.device
|
|
200
222
|
loss = torch.empty(M, device=device, dtype=torch.float32)
|
|
201
223
|
lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
|
|
202
224
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
203
225
|
convert_from_dlpack = lambda tensor: (
|
|
204
|
-
from_dlpack(tensor.detach(), assumed_align=16)
|
|
205
|
-
|
|
226
|
+
from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
227
|
+
mode=0, stride_order=(0, 1)
|
|
228
|
+
)
|
|
206
229
|
)
|
|
207
|
-
x_tensor
|
|
230
|
+
x_tensor = convert_from_dlpack(x)
|
|
208
231
|
loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
|
|
209
|
-
lse_tensor =
|
|
232
|
+
lse_tensor = (
|
|
233
|
+
from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
|
|
234
|
+
if lse is not None
|
|
235
|
+
else None
|
|
236
|
+
)
|
|
210
237
|
target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
|
|
211
238
|
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
212
|
-
|
|
239
|
+
|
|
240
|
+
compile_key = (dtype, N, lse is not None)
|
|
213
241
|
if compile_key not in cross_entropy.compile_cache:
|
|
242
|
+
cross_entropy_op = CrossEntropy(dtype, N)
|
|
214
243
|
cross_entropy.compile_cache[compile_key] = cute.compile(
|
|
215
|
-
|
|
244
|
+
cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
|
|
216
245
|
)
|
|
217
|
-
cross_entropy.compile_cache[compile_key](
|
|
246
|
+
cross_entropy.compile_cache[compile_key](
|
|
247
|
+
x_tensor, target_tensor, loss_tensor, lse_tensor, stream
|
|
248
|
+
)
|
|
218
249
|
return loss if not return_lse else (loss, lse)
|
|
219
250
|
|
|
220
251
|
|
quack/reduction_base.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from typing import Type, Tuple, Optional
|
|
5
|
+
|
|
6
|
+
import cutlass
|
|
7
|
+
import cutlass.cute as cute
|
|
8
|
+
|
|
9
|
+
import quack.utils as utils
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
torch2cute_dtype_map = {
|
|
13
|
+
torch.float16: cutlass.Float16,
|
|
14
|
+
torch.bfloat16: cutlass.BFloat16,
|
|
15
|
+
torch.float32: cutlass.Float32,
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ReductionBase:
|
|
20
|
+
def __init__(
|
|
21
|
+
self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
|
|
22
|
+
):
|
|
23
|
+
self.dtype = dtype
|
|
24
|
+
self.N = N
|
|
25
|
+
self.stage = stage
|
|
26
|
+
self.reduction_dtype = reduction_dtype
|
|
27
|
+
|
|
28
|
+
def _calculate_threads_per_row(self):
|
|
29
|
+
raise NotImplementedError()
|
|
30
|
+
|
|
31
|
+
def _set_cluster_n(self):
|
|
32
|
+
self.cluster_n = 1
|
|
33
|
+
|
|
34
|
+
def _get_num_threads(self):
|
|
35
|
+
return 128 if self.N <= 16384 else 256
|
|
36
|
+
|
|
37
|
+
def _get_tv_layout(self):
|
|
38
|
+
copy_bits = 128
|
|
39
|
+
vecsize = copy_bits // self.dtype.width
|
|
40
|
+
assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
|
|
41
|
+
num_threads = self._get_num_threads()
|
|
42
|
+
num_warps = num_threads // cute.arch.WARP_SIZE
|
|
43
|
+
assert num_threads % cute.arch.WARP_SIZE == 0
|
|
44
|
+
|
|
45
|
+
threads_per_row = self._calculate_threads_per_row()
|
|
46
|
+
num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
|
|
47
|
+
cols_per_block = num_threads // threads_per_row
|
|
48
|
+
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
|
|
49
|
+
tv_layout = cute.make_layout(
|
|
50
|
+
((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
|
|
51
|
+
stride=(
|
|
52
|
+
(vecsize * cols_per_block, 1),
|
|
53
|
+
(cols_per_block, cols_per_block * vecsize * threads_per_row),
|
|
54
|
+
),
|
|
55
|
+
)
|
|
56
|
+
return tiler_mn, tv_layout
|
|
57
|
+
|
|
58
|
+
def _smem_size_in_bytes(self, tiler_mn, num_warps):
|
|
59
|
+
return (
|
|
60
|
+
cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
|
|
61
|
+
+ self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
|
|
62
|
+
+ self.stage * (cutlass.Int64.width // 8)
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
|
|
66
|
+
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
67
|
+
warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
|
|
68
|
+
return cute.make_ordered_layout(
|
|
69
|
+
(num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
|
|
70
|
+
order=(1, 0, 2),
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
def _allocate_reduction_buffer_and_mbar(
|
|
74
|
+
self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
|
|
75
|
+
) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
|
|
76
|
+
reduction_buffer = smem.allocate_tensor(
|
|
77
|
+
self.reduction_dtype,
|
|
78
|
+
self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
|
|
79
|
+
byte_alignment=4,
|
|
80
|
+
)
|
|
81
|
+
if cutlass.const_expr(self.cluster_n > 1):
|
|
82
|
+
mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
|
|
83
|
+
else:
|
|
84
|
+
mbar_ptr = None
|
|
85
|
+
return reduction_buffer, mbar_ptr
|
|
86
|
+
|
|
87
|
+
@cute.jit
|
|
88
|
+
def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
|
|
89
|
+
if cutlass.const_expr(self.cluster_n > 1):
|
|
90
|
+
if tidx < self.stage:
|
|
91
|
+
cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
|
|
92
|
+
cute.arch.mbarrier_init_fence()
|
|
93
|
+
if tidx < self.stage:
|
|
94
|
+
cute.arch.mbarrier_init_tx_bytes(
|
|
95
|
+
mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
|
|
96
|
+
)
|
|
97
|
+
# Cluster arrive after barrier init
|
|
98
|
+
cute.arch.cluster_arrive_relaxed()
|