quack-kernels 0.1.2__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack_kernels-0.1.4/PKG-INFO +11 -0
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/README.md +14 -1
- quack_kernels-0.1.4/pyproject.toml +33 -0
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/quack/__init__.py +7 -1
- quack_kernels-0.1.4/quack/cross_entropy.py +255 -0
- quack_kernels-0.1.4/quack/reduction_base.py +98 -0
- quack_kernels-0.1.4/quack/rmsnorm.py +285 -0
- quack_kernels-0.1.4/quack/softmax.py +456 -0
- quack_kernels-0.1.4/quack/utils.py +407 -0
- quack_kernels-0.1.4/quack_kernels.egg-info/PKG-INFO +11 -0
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/quack_kernels.egg-info/SOURCES.txt +1 -0
- quack_kernels-0.1.4/quack_kernels.egg-info/requires.txt +6 -0
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/tests/test_cross_entropy.py +3 -3
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/tests/test_softmax.py +75 -16
- quack_kernels-0.1.2/PKG-INFO +0 -8
- quack_kernels-0.1.2/pyproject.toml +0 -18
- quack_kernels-0.1.2/quack/cross_entropy.py +0 -221
- quack_kernels-0.1.2/quack/rmsnorm.py +0 -254
- quack_kernels-0.1.2/quack/softmax.py +0 -195
- quack_kernels-0.1.2/quack/utils.py +0 -246
- quack_kernels-0.1.2/quack_kernels.egg-info/PKG-INFO +0 -8
- quack_kernels-0.1.2/quack_kernels.egg-info/requires.txt +0 -2
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/LICENSE +0 -0
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/quack_kernels.egg-info/top_level.txt +0 -0
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/setup.cfg +0 -0
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/setup.py +0 -0
- {quack_kernels-0.1.2 → quack_kernels-0.1.4}/tests/test_rmsnorm.py +0 -0
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: quack-kernels
|
|
3
|
+
Version: 0.1.4
|
|
4
|
+
Requires-Python: >=3.9
|
|
5
|
+
License-File: LICENSE
|
|
6
|
+
Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
|
|
7
|
+
Requires-Dist: torch
|
|
8
|
+
Provides-Extra: dev
|
|
9
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
10
|
+
Requires-Dist: ruff; extra == "dev"
|
|
11
|
+
Dynamic: license-file
|
|
@@ -17,12 +17,25 @@ pip install quack-kernels
|
|
|
17
17
|
## Kernels 🐥
|
|
18
18
|
|
|
19
19
|
- 🦆 RMSNorm forward
|
|
20
|
-
- 🦆 Softmax forward
|
|
20
|
+
- 🦆 Softmax forward and backward
|
|
21
21
|
- 🦆 Cross entropy forward
|
|
22
22
|
|
|
23
|
+
Upcoming:
|
|
24
|
+
- 🦆 Cross entropy backward
|
|
25
|
+
- 🦆 RMSNorm backward
|
|
26
|
+
- 🦆 Rotary forward + backward
|
|
23
27
|
|
|
24
28
|
## Usage
|
|
25
29
|
|
|
26
30
|
```
|
|
27
31
|
from quack import rmsnorm, softmax, cross_entropy
|
|
28
32
|
```
|
|
33
|
+
|
|
34
|
+
## Development
|
|
35
|
+
|
|
36
|
+
To set up the development environment:
|
|
37
|
+
|
|
38
|
+
```bash
|
|
39
|
+
pip install -e '.[dev]'
|
|
40
|
+
pre-commit install
|
|
41
|
+
```
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "quack-kernels"
|
|
7
|
+
dynamic = ["version"]
|
|
8
|
+
requires-python = ">=3.9"
|
|
9
|
+
dependencies = [
|
|
10
|
+
"nvidia-cutlass-dsl==4.1.0.dev0",
|
|
11
|
+
"torch",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[project.optional-dependencies]
|
|
15
|
+
dev = [
|
|
16
|
+
"pre-commit",
|
|
17
|
+
"ruff",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
[tool.setuptools.packages.find]
|
|
21
|
+
exclude = ["tests", "benchmarks"]
|
|
22
|
+
|
|
23
|
+
[tool.setuptools.dynamic]
|
|
24
|
+
version = {attr = "quack.__version__"}
|
|
25
|
+
|
|
26
|
+
[tool.ruff]
|
|
27
|
+
line-length = 100
|
|
28
|
+
|
|
29
|
+
[tool.ruff.lint]
|
|
30
|
+
ignore = [
|
|
31
|
+
"E731", # do not assign a lambda expression, use a def
|
|
32
|
+
"F841", # local variable is assigned to but never used
|
|
33
|
+
]
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
from typing import Optional, Type
|
|
4
|
+
|
|
5
|
+
import cuda.bindings.driver as cuda
|
|
6
|
+
|
|
7
|
+
import cutlass
|
|
8
|
+
import cutlass.cute as cute
|
|
9
|
+
from cutlass.cute.runtime import from_dlpack
|
|
10
|
+
|
|
11
|
+
import quack.utils as utils
|
|
12
|
+
from quack.reduction_base import ReductionBase, torch2cute_dtype_map
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CrossEntropy(ReductionBase):
    """Cross-entropy forward kernel over the rows of an (M, N) logits tensor.

    For every row the kernel computes ``lse = max(x) + ln(sum(exp(x - max(x))))`` and writes
    ``loss = lse - x[target]``; ``lse`` itself is optionally written as well.

    Two reduction strategies are selected at construction time:

    * ``online_softmax=True`` (default): one reduction stage that yields both the row max and
      the denominator via ``utils.online_softmax_reduce``.
    * ``online_softmax=False``: two reduction stages, a row max followed by a row sum.
    """

    def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
        """Configure the kernel for element type ``dtype`` and row length ``N``."""
        # 2 stages: 1 for max, 1 for sum
        # NOTE(review): the online path uses a 64-bit reduction slot; presumably the running
        # max and sum are packed together -- confirm in utils.online_softmax_reduce.
        super().__init__(
            dtype,
            N,
            stage=2 if not online_softmax else 1,
            reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
        )
        self.online_softmax = online_softmax
        # Only the two-pass path with N > 16384 re-reads the row from shared memory
        # between the max and the sum reductions.
        self.reload_from = None if N <= 16384 or online_softmax else "smem"

    def _calculate_threads_per_row(self):
        """Return the number of threads assigned to one row, growing with ``N`` (8 to 256)."""
        N = self.N
        return (
            8
            if N <= 64
            else (
                16
                if N <= 128
                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
            )
        )

    def _set_cluster_n(self):
        """Set ``self.cluster_n`` (1, 2, 4, 8 or 16) from ``N`` and the element width.

        16-bit dtypes and the remaining (fp32) case use different ``N`` thresholds.
        """
        N = self.N
        if cutlass.const_expr(self.dtype.width == 16):
            cluster_n = (
                1
                if N <= 16 * 1024
                else (
                    2
                    if N <= 32 * 1024
                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
                )
            )
        else:  # fp32
            cluster_n = (
                1
                if N <= 16 * 1024
                else (
                    2
                    if N <= 64 * 1024
                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
                )
            )
        self.cluster_n = cluster_n

    @cute.jit
    def __call__(
        self,
        mX: cute.Tensor,
        mTarget: cute.Tensor,
        mLoss: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        stream: cuda.CUstream,
    ):
        """Launch the kernel on ``stream``.

        Args:
            mX: Logits; its element type must equal ``self.dtype``.
            mTarget: Per-row target class indices.
            mLoss: Output per-row loss.
            mLSE: Optional output per-row log-sum-exp; pass ``None`` to skip it.
            stream: CUDA stream the launch is enqueued on.
        """
        assert mX.element_type == self.dtype
        self._set_cluster_n()
        tiler_mn, tv_layout = self._get_tv_layout()
        num_threads = cute.size(tv_layout, mode=[0])
        num_warps = num_threads // cute.arch.WARP_SIZE
        # Grid x walks row tiles of height tiler_mn[0]; grid y has one CTA per cluster member.
        self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
            block=[num_threads, 1, 1],
            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mX: cute.Tensor,  # (M, N)
        mTarget: cute.Tensor,  # (M,)
        mLoss: cute.Tensor,  # (M,)
        mLSE: Optional[cute.Tensor],  # (M,)
        tv_layout: cute.Layout,
        tiler_mn: cute.Shape,
    ):
        """Device code: load one tile of rows, reduce each row, write loss (and LSE)."""
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        # The y block index selects this CTA's column tile when a cluster is used.
        if cutlass.const_expr(self.cluster_n > 1):
            cluster_y = cute.arch.block_idx()[1]
        else:
            cluster_y = cutlass.const_expr(0)

        shape: cute.Shape = mX.shape
        # Identity tensor over the same shape: tiled alongside mX so that each thread can
        # recover the (row, column) coordinates of the elements it owns.
        idX = cute.make_identity_tensor(shape)
        # slice for CTAs
        gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]

        smem = cutlass.utils.SmemAllocator()
        sX = smem.allocate_tensor(
            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
        )
        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

        # declare the atoms which will be used later for memory copy
        copy_atom_load_X = cute.make_copy_atom(
            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
        )
        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)

        #### Thread View
        tXgX = thr_copy_X.partition_S(gX)
        tXsX = thr_copy_X.partition_D(sX)
        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
        tXrX = cute.make_fragment_like(tXgX)

        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
        self._initialize_cluster(tidx, mbar_ptr, num_warps)

        # Row index of this thread's first element.
        row = tXcX[0][0]
        target = cute.Int32.zero
        # Only the in-bounds thread whose first element sits at column 0 reads the target.
        if row < shape[0] and tXcX[0][1] == 0:
            target = cute.Int32(mTarget[row])

        # A column predicate is only built when N does not exactly fill the tile(s).
        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
        tXpX = (
            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
            if cutlass.const_expr(not is_even_N)
            else None
        )
        # Rows past M are not loaded; their results are never stored (see the final guard).
        if row < shape[0]:
            cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
        cute.arch.cp_async_commit_group()
        cute.arch.cp_async_wait_group(0)
        # Fill OOB values with -inf
        if cutlass.const_expr(not is_even_N):
            utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
        cute.autovec_copy(tXsX, tXrX)
        # All arithmetic below is carried out in fp32 regardless of the input dtype.
        x = tXrX.load().to(cute.Float32)

        target_logit = cute.Float32.zero
        # Same thread that read the target also fetches the logit at the target column.
        if row < shape[0] and tXcX[0][1] == 0:
            target_logit = cute.Float32(mX[row, target])

        threads_per_row = tv_layout.shape[0][0]
        if cutlass.const_expr(not self.online_softmax):
            # Two-pass path, stage 0: row maximum.
            max_x = utils.row_reduce(
                x,
                cute.ReductionOp.MAX,
                threads_per_row,
                reduction_buffer[None, None, 0],
                mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
                init_val=-cutlass.Float32.inf,
                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
            )
            if cutlass.const_expr(self.reload_from == "smem"):
                cute.autovec_copy(tXsX, tXrX)
                x = tXrX.load().to(cute.Float32)
            log2_e = math.log2(math.e)
            # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
            # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
            # exp_x = utils.exp2f((x - max_x) * log2_e)
            # This would use ffma instead of fadd then fmul
            exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
            # Two-pass path, stage 1: row sum of exp(x - max).
            denom = utils.row_reduce(
                exp_x,
                cute.ReductionOp.ADD,
                threads_per_row,
                reduction_buffer[None, None, 1],
                mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
                init_val=0.0,
            )
        else:
            # Single-stage path: max and denominator come out of one reduction.
            max_x, denom, _ = utils.online_softmax_reduce(
                x,
                threads_per_row,
                reduction_buffer[None, None, 0],
                mbar_ptr,
                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
            )

        # One writer per row: the column-0 thread, and with clusters only cluster block 0.
        if (
            tXcX[0][1] == 0
            and row < shape[0]
            and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
        ):
            ln_2 = math.log(2.0)
            # ln(denom) expressed as log2(denom) * ln(2).
            lse = max_x + utils.log2f(denom) * ln_2
            loss_val = lse - target_logit
            mLoss[row] = loss_val.to(mLoss.element_type)
            if cutlass.const_expr(mLSE is not None):
                mLSE[row] = lse
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def cross_entropy(
    x: torch.Tensor,
    target: torch.Tensor,
    return_lse: bool = False,
) -> torch.Tensor:
    """Cross entropy forward pass.

    Compiled kernels are memoized in ``cross_entropy.compile_cache``; the batch size M is
    marked dynamic, so one compiled kernel serves every M for a given configuration.

    Args:
        x: Input logits tensor of shape (M, N); float16, bfloat16 or float32, on CUDA.
        target: Target class indices tensor of shape (M,); int32 or int64, on CUDA.
        return_lse: If True, also return the per-row log-sum-exp of the logits.

    Returns:
        Cross entropy loss tensor of shape (M,) in float32. When ``return_lse`` is True,
        a tuple ``(loss, lse)`` is returned instead, where ``lse`` is a float32 tensor of
        shape (M,).
    """
    assert x.dim() == 2, "Input must be 2D"
    assert target.dim() == 1, "Target must be 1D"
    assert x.shape[0] == target.shape[0], "Batch dimensions must match"
    assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
    assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
    M, N = x.shape
    device = x.device
    loss = torch.empty(M, device=device, dtype=torch.float32)
    lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
    dtype = torch2cute_dtype_map[x.dtype]

    # Only the batch dimension (mode 0) is dynamic; N is baked into the compiled kernel.
    x_tensor = from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
        mode=0, stride_order=(0, 1)
    )
    loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
    lse_tensor = (
        from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
        if lse is not None
        else None
    )
    # Promise no more alignment than one element: an int32 target (e.g. a sliced view)
    # is only guaranteed to be 4-byte aligned, int64 is 8-byte aligned.
    target_tensor = from_dlpack(
        target.detach(), assumed_align=target.element_size()
    ).mark_compact_shape_dynamic(mode=0)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # The target dtype is part of the key: the compiled kernel is specialized on the
    # element type of every tensor argument, so int32 and int64 targets must not share
    # a cache entry.
    compile_key = (dtype, N, target.dtype, lse is not None)
    compiled = cross_entropy.compile_cache.get(compile_key)
    if compiled is None:
        cross_entropy_op = CrossEntropy(dtype, N)
        compiled = cute.compile(
            cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
        )
        cross_entropy.compile_cache[compile_key] = compiled
    compiled(x_tensor, target_tensor, loss_tensor, lse_tensor, stream)
    return loss if not return_lse else (loss, lse)


# Maps (logits dtype, N, target dtype, lse requested) -> compiled kernel.
cross_entropy.compile_cache = {}
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from typing import Type, Tuple, Optional
|
|
5
|
+
|
|
6
|
+
import cutlass
|
|
7
|
+
import cutlass.cute as cute
|
|
8
|
+
|
|
9
|
+
import quack.utils as utils
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Translates the torch floating-point dtypes supported by the kernels into the
# corresponding cutlass numeric types; a dtype missing here raises KeyError on lookup.
torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ReductionBase:
    """Shared configuration and shared-memory plumbing for row-reduction kernels.

    Subclasses supply ``_calculate_threads_per_row`` and may override ``_set_cluster_n``;
    ``self.cluster_n`` must be set (via ``_set_cluster_n``) before the layout, shared-memory
    and cluster helpers below are used.
    """

    def __init__(
        self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
    ):
        """Store the element type, row length, reduction stage count and reduction dtype."""
        self.dtype, self.N = dtype, N
        self.stage, self.reduction_dtype = stage, reduction_dtype

    def _calculate_threads_per_row(self):
        """Return how many threads cooperate on one row; subclasses must implement this."""
        raise NotImplementedError()

    def _set_cluster_n(self):
        """Default to no clustering (a single CTA along N)."""
        self.cluster_n = 1

    def _get_num_threads(self):
        """Threads per CTA: 256 once N exceeds 16384, otherwise 128."""
        return 256 if self.N > 16384 else 128

    def _get_tv_layout(self):
        """Build the CTA tile shape and the thread-value layout for 128-bit vector copies.

        Returns:
            ``(tiler_mn, tv_layout)`` where ``tiler_mn`` is the (rows, columns) tile handled
            by one CTA and ``tv_layout`` has shape
            ``((threads_per_row, rows_per_block), (vecsize, vectors_per_thread))``.
        """
        bits_per_copy = 128
        vecsize = bits_per_copy // self.dtype.width
        assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
        num_threads = self._get_num_threads()
        assert num_threads % cute.arch.WARP_SIZE == 0

        threads_per_row = self._calculate_threads_per_row()
        # Vectors each thread handles along N, after splitting N across the cluster.
        vectors_per_thread = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
        rows_per_block = num_threads // threads_per_row
        tile_width = vecsize * vectors_per_thread * threads_per_row
        tiler_mn = (rows_per_block, tile_width)

        thread_shape = (threads_per_row, rows_per_block)
        value_shape = (vecsize, vectors_per_thread)
        thread_stride = (vecsize * rows_per_block, 1)
        value_stride = (rows_per_block, rows_per_block * vecsize * threads_per_row)
        tv_layout = cute.make_layout(
            (thread_shape, value_shape),
            stride=(thread_stride, value_stride),
        )
        return tiler_mn, tv_layout

    def _smem_size_in_bytes(self, tiler_mn, num_warps):
        """Total shared memory requested at launch: input tile + reduction slots + mbarriers."""
        tile_bytes = cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
        slot_bytes = self.reduction_dtype.width // 8
        reduction_bytes = self.stage * num_warps * self.cluster_n * slot_bytes
        # One 64-bit barrier word per stage.
        mbar_bytes = self.stage * (cutlass.Int64.width // 8)
        return tile_bytes + reduction_bytes + mbar_bytes

    def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
        """Layout of the reduction scratch buffer.

        Shape is ``(rows_of_warps, (warps_per_row, cluster_n), stage)``, i.e. one slot per
        warp per cluster member per stage.
        """
        warp_size = cute.arch.WARP_SIZE
        num_warps = cute.size(tv_layout, mode=[0]) // warp_size
        threads_per_row = tv_layout.shape[0][0]
        # Rows narrower than a warp still occupy one warp slot.
        warps_per_row = utils.max_constexpr(threads_per_row // warp_size, 1)
        buffer_shape = (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage)
        return cute.make_ordered_layout(buffer_shape, order=(1, 0, 2))

    def _allocate_reduction_buffer_and_mbar(
        self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
    ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
        """Allocate the reduction buffer and, when clustered, one 64-bit barrier per stage.

        Returns:
            ``(reduction_buffer, mbar_ptr)``; ``mbar_ptr`` is ``None`` when ``cluster_n == 1``.
        """
        buffer_layout = self._get_reduction_buffer_layout(tv_layout, self.cluster_n)
        reduction_buffer = smem.allocate_tensor(
            self.reduction_dtype,
            buffer_layout,
            byte_alignment=4,
        )
        if cutlass.const_expr(self.cluster_n > 1):
            mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
        else:
            mbar_ptr = None
        return reduction_buffer, mbar_ptr

    @cute.jit
    def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
        """Set up the per-stage barriers; does nothing unless ``cluster_n > 1``.

        The first ``self.stage`` threads each initialize one barrier and then register the
        number of bytes expected on it: one reduction slot per warp per cluster member.
        """
        if cutlass.const_expr(self.cluster_n > 1):
            # Thread i owns barrier i.
            if tidx < self.stage:
                cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
            cute.arch.mbarrier_init_fence()
            if tidx < self.stage:
                cute.arch.mbarrier_arrive_and_expect_tx(
                    mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
                )
            # Cluster arrive after barrier init
            cute.arch.cluster_arrive_relaxed()
|