quack-kernels 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -1
- quack/cross_entropy.py +199 -168
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +208 -195
- quack/softmax.py +409 -163
- quack/utils.py +249 -35
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/METADATA +4 -1
- quack_kernels-0.1.3.dist-info/RECORD +11 -0
- quack_kernels-0.1.1.dist-info/RECORD +0 -10
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/cross_entropy.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import math
|
|
2
2
|
import torch
|
|
3
|
-
import
|
|
4
|
-
from typing import Callable, Union
|
|
3
|
+
from typing import Optional, Type
|
|
5
4
|
|
|
6
5
|
import cuda.bindings.driver as cuda
|
|
7
6
|
|
|
@@ -10,177 +9,198 @@ import cutlass.cute as cute
|
|
|
10
9
|
from cutlass.cute.runtime import from_dlpack
|
|
11
10
|
|
|
12
11
|
import quack.utils as utils
|
|
12
|
+
from quack.reduction_base import ReductionBase, torch2cute_dtype_map
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
|
|
16
|
-
def
|
|
17
|
-
mX: cute.Tensor, # (M, N)
|
|
18
|
-
mTarget: cute.Tensor, # (M,)
|
|
19
|
-
mLoss: cute.Tensor, # (M,)
|
|
20
|
-
tv_layout: cute.Layout,
|
|
21
|
-
tiler_mn: cute.Shape,
|
|
22
|
-
cluster_n: cutlass.Constexpr = 1,
|
|
23
|
-
):
|
|
24
|
-
tidx, _, _ = cute.arch.thread_idx()
|
|
25
|
-
bidx, cluster_y, _ = cute.arch.block_idx()
|
|
26
|
-
gdim, _, _ = cute.arch.grid_dim()
|
|
27
|
-
|
|
28
|
-
shape: cute.Shape = mX.shape
|
|
29
|
-
idX = cute.make_identity_tensor(mX.shape)
|
|
30
|
-
gX, cX = [cute.zipped_divide(mT, tiler_mn) for mT in (mX, idX)]
|
|
31
|
-
blkX, blkCrd = [gT[(None, None), bidx if cluster_n == 1 else (bidx, cluster_y)] for gT in (gX, cX)]
|
|
32
|
-
|
|
33
|
-
# declare the atoms which will be used later for memory copy
|
|
34
|
-
copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
|
|
35
|
-
copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
|
|
36
|
-
copy_atom_scalar = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=gX.element_type.width)
|
|
37
|
-
|
|
38
|
-
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
|
|
39
|
-
thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
|
|
40
|
-
|
|
41
|
-
smem = cutlass.utils.SmemAllocator()
|
|
42
|
-
|
|
43
|
-
# Don't use blkX.layout here, because the stride is N, not N_rounded
|
|
44
|
-
sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
|
|
45
|
-
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
46
|
-
warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
|
|
47
|
-
|
|
48
|
-
reduction_buffer_layout = cute.make_ordered_layout(
|
|
15
|
+
class CrossEntropy(ReductionBase):
|
|
16
|
+
def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
|
|
49
17
|
# 2 stages: 1 for max, 1 for sum
|
|
50
|
-
(
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
# 1 mbar for max reduction, 1 mbar for sum reduction
|
|
56
|
-
mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=2)
|
|
57
|
-
else:
|
|
58
|
-
mbar_ptr = None
|
|
59
|
-
|
|
60
|
-
#### Thread View
|
|
61
|
-
tXgX = thr_copy_X_async.partition_S(blkX)
|
|
62
|
-
tXsX = thr_copy_X_async.partition_S(sX)
|
|
63
|
-
|
|
64
|
-
tXcX = thr_copy_X.partition_S(blkCrd)[(0, None), None, None]
|
|
65
|
-
|
|
66
|
-
# allocate fragments for gmem->rmem
|
|
67
|
-
tXrX = cute.make_fragment_like(tXgX) # only logits fragment needed
|
|
68
|
-
|
|
69
|
-
if cluster_n > 1:
|
|
70
|
-
if tidx < 2:
|
|
71
|
-
cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
|
|
72
|
-
cute.arch.mbarrier_init_fence()
|
|
73
|
-
if tidx < 2:
|
|
74
|
-
cute.arch.mbarrier_init_tx_bytes(mbar_ptr + tidx, num_warps * cluster_n * cutlass.Float32.width // 8)
|
|
75
|
-
# Cluster arrive after barrier init
|
|
76
|
-
cute.arch.cluster_arrive_relaxed()
|
|
77
|
-
|
|
78
|
-
row = tXcX[0][0]
|
|
79
|
-
target = cute.Int32.zero
|
|
80
|
-
if row < shape[0] and tXcX[0][1] == 0:
|
|
81
|
-
target = cute.Int32(mTarget[row])
|
|
82
|
-
|
|
83
|
-
tXpX = cute.make_fragment_like(tXgX[(0, None), None, None], cutlass.Boolean)
|
|
84
|
-
for i in range(cute.size(tXpX)):
|
|
85
|
-
tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
|
|
86
|
-
if row < shape[0]:
|
|
87
|
-
cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
|
|
88
|
-
cute.arch.cp_async_commit_group()
|
|
89
|
-
cute.arch.cp_async_wait_group(0)
|
|
90
|
-
cute.autovec_copy(tXsX, tXrX)
|
|
91
|
-
x = tXrX.load().to(cute.Float32)
|
|
92
|
-
|
|
93
|
-
target_logit = cute.Float32.zero
|
|
94
|
-
if row < shape[0] and tXcX[0][1] == 0:
|
|
95
|
-
target_logit = cute.Float32(mX[row, target])
|
|
96
|
-
|
|
97
|
-
max_x = utils.warp_reduce(
|
|
98
|
-
x.reduce(cute.ReductionOp.MAX, init_val=float('-inf'), reduction_profile=0),
|
|
99
|
-
cute.arch.fmax,
|
|
100
|
-
width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
|
|
101
|
-
)
|
|
102
|
-
if cutlass.const_expr(cluster_n > 1):
|
|
103
|
-
cute.arch.cluster_wait()
|
|
104
|
-
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
|
|
105
|
-
max_mbar_ptr = mbar_ptr + 0 if cluster_n > 1 else None
|
|
106
|
-
max_x = utils.block_or_cluster_reduce(
|
|
107
|
-
max_x, cute.arch.fmax, reduction_buffer[None, None, 0], max_mbar_ptr, init_val=-cutlass.Float32.inf
|
|
18
|
+
super().__init__(
|
|
19
|
+
dtype,
|
|
20
|
+
N,
|
|
21
|
+
stage=2 if not online_softmax else 1,
|
|
22
|
+
reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
|
|
108
23
|
)
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
24
|
+
self.online_softmax = online_softmax
|
|
25
|
+
self.reload_from = None if N <= 16384 or online_softmax else "smem"
|
|
26
|
+
|
|
27
|
+
def _calculate_threads_per_row(self):
|
|
28
|
+
N = self.N
|
|
29
|
+
return (
|
|
30
|
+
8
|
|
31
|
+
if N <= 64
|
|
32
|
+
else (
|
|
33
|
+
16
|
|
34
|
+
if N <= 128
|
|
35
|
+
else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
|
|
36
|
+
)
|
|
121
37
|
)
|
|
122
38
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
39
|
+
def _set_cluster_n(self):
|
|
40
|
+
N = self.N
|
|
41
|
+
if cutlass.const_expr(self.dtype.width == 16):
|
|
42
|
+
cluster_n = (
|
|
43
|
+
1
|
|
44
|
+
if N <= 16 * 1024
|
|
45
|
+
else (
|
|
46
|
+
2
|
|
47
|
+
if N <= 32 * 1024
|
|
48
|
+
else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
|
|
49
|
+
)
|
|
50
|
+
)
|
|
51
|
+
else: # fp32
|
|
52
|
+
cluster_n = (
|
|
53
|
+
1
|
|
54
|
+
if N <= 16 * 1024
|
|
55
|
+
else (
|
|
56
|
+
2
|
|
57
|
+
if N <= 64 * 1024
|
|
58
|
+
else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
|
|
59
|
+
)
|
|
60
|
+
)
|
|
61
|
+
self.cluster_n = cluster_n
|
|
62
|
+
|
|
63
|
+
@cute.jit
|
|
64
|
+
def __call__(
|
|
65
|
+
self,
|
|
66
|
+
mX: cute.Tensor,
|
|
67
|
+
mTarget: cute.Tensor,
|
|
68
|
+
mLoss: cute.Tensor,
|
|
69
|
+
mLSE: Optional[cute.Tensor],
|
|
70
|
+
stream: cuda.CUstream,
|
|
71
|
+
):
|
|
72
|
+
assert mX.element_type == self.dtype
|
|
73
|
+
self._set_cluster_n()
|
|
74
|
+
tiler_mn, tv_layout = self._get_tv_layout()
|
|
75
|
+
num_threads = cute.size(tv_layout, mode=[0])
|
|
76
|
+
num_warps = num_threads // cute.arch.WARP_SIZE
|
|
77
|
+
self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
|
|
78
|
+
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
79
|
+
block=[num_threads, 1, 1],
|
|
80
|
+
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
|
|
81
|
+
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
82
|
+
stream=stream,
|
|
83
|
+
)
|
|
162
84
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
85
|
+
@cute.kernel
|
|
86
|
+
def kernel(
|
|
87
|
+
self,
|
|
88
|
+
mX: cute.Tensor, # (M, N)
|
|
89
|
+
mTarget: cute.Tensor, # (M,)
|
|
90
|
+
mLoss: cute.Tensor, # (M,)
|
|
91
|
+
mLSE: Optional[cute.Tensor], # (M,)
|
|
92
|
+
tv_layout: cute.Layout,
|
|
93
|
+
tiler_mn: cute.Shape,
|
|
94
|
+
):
|
|
95
|
+
tidx, _, _ = cute.arch.thread_idx()
|
|
96
|
+
bidx, cluster_y, _ = cute.arch.block_idx()
|
|
97
|
+
|
|
98
|
+
shape: cute.Shape = mX.shape
|
|
99
|
+
idX = cute.make_identity_tensor(shape)
|
|
100
|
+
# slice for CTAs
|
|
101
|
+
gX, cX = [
|
|
102
|
+
cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
|
|
103
|
+
for mT in (mX, idX)
|
|
104
|
+
]
|
|
105
|
+
|
|
106
|
+
smem = cutlass.utils.SmemAllocator()
|
|
107
|
+
sX = smem.allocate_tensor(
|
|
108
|
+
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
|
|
109
|
+
)
|
|
110
|
+
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
|
111
|
+
|
|
112
|
+
# declare the atoms which will be used later for memory copy
|
|
113
|
+
copy_atom_load_X = cute.make_copy_atom(
|
|
114
|
+
cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
|
|
115
|
+
)
|
|
116
|
+
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
|
|
117
|
+
|
|
118
|
+
#### Thread View
|
|
119
|
+
tXgX = thr_copy_X.partition_S(gX)
|
|
120
|
+
tXsX = thr_copy_X.partition_D(sX)
|
|
121
|
+
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
122
|
+
tXrX = cute.make_fragment_like(tXgX)
|
|
123
|
+
|
|
124
|
+
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
125
|
+
self._initialize_cluster(tidx, mbar_ptr, num_warps)
|
|
172
126
|
|
|
127
|
+
row = tXcX[0][0]
|
|
128
|
+
target = cute.Int32.zero
|
|
129
|
+
if row < shape[0] and tXcX[0][1] == 0:
|
|
130
|
+
target = cute.Int32(mTarget[row])
|
|
173
131
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
132
|
+
is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
|
|
133
|
+
tXpX = (
|
|
134
|
+
utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
|
|
135
|
+
)
|
|
136
|
+
if row < shape[0]:
|
|
137
|
+
cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
|
|
138
|
+
cute.arch.cp_async_commit_group()
|
|
139
|
+
cute.arch.cp_async_wait_group(0)
|
|
140
|
+
# Fill OOB values with -inf
|
|
141
|
+
if cutlass.const_expr(not is_even_N):
|
|
142
|
+
utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
|
|
143
|
+
cute.autovec_copy(tXsX, tXrX)
|
|
144
|
+
x = tXrX.load().to(cute.Float32)
|
|
145
|
+
|
|
146
|
+
target_logit = cute.Float32.zero
|
|
147
|
+
if row < shape[0] and tXcX[0][1] == 0:
|
|
148
|
+
target_logit = cute.Float32(mX[row, target])
|
|
149
|
+
|
|
150
|
+
threads_per_row = tv_layout.shape[0][0]
|
|
151
|
+
if cutlass.const_expr(not self.online_softmax):
|
|
152
|
+
max_x = utils.row_reduce(
|
|
153
|
+
x,
|
|
154
|
+
cute.ReductionOp.MAX,
|
|
155
|
+
threads_per_row,
|
|
156
|
+
reduction_buffer[None, None, 0],
|
|
157
|
+
mbar_ptr + 0 if self.cluster_n > 1 else None,
|
|
158
|
+
init_val=-cutlass.Float32.inf,
|
|
159
|
+
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
160
|
+
)
|
|
161
|
+
if cutlass.const_expr(self.reload_from == "smem"):
|
|
162
|
+
cute.autovec_copy(tXsX, tXrX)
|
|
163
|
+
x = tXrX.load().to(cute.Float32)
|
|
164
|
+
log2_e = math.log2(math.e)
|
|
165
|
+
# exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
|
|
166
|
+
# a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
|
|
167
|
+
# exp_x = utils.exp2f((x - max_x) * log2_e)
|
|
168
|
+
# This would use ffma instead of fadd then fmul
|
|
169
|
+
exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
|
|
170
|
+
denom = utils.row_reduce(
|
|
171
|
+
exp_x,
|
|
172
|
+
cute.ReductionOp.ADD,
|
|
173
|
+
threads_per_row,
|
|
174
|
+
reduction_buffer[None, None, 1],
|
|
175
|
+
mbar_ptr + 1 if self.cluster_n > 1 else None,
|
|
176
|
+
init_val=0.0,
|
|
177
|
+
)
|
|
178
|
+
else:
|
|
179
|
+
max_x, denom, _ = utils.online_softmax_reduce(
|
|
180
|
+
x,
|
|
181
|
+
threads_per_row,
|
|
182
|
+
reduction_buffer[None, None, 0],
|
|
183
|
+
mbar_ptr,
|
|
184
|
+
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
if (
|
|
188
|
+
tXcX[0][1] == 0
|
|
189
|
+
and row < shape[0]
|
|
190
|
+
and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
|
|
191
|
+
):
|
|
192
|
+
ln_2 = math.log(2.0)
|
|
193
|
+
lse = max_x + utils.log2f(denom) * ln_2
|
|
194
|
+
loss_val = lse - target_logit
|
|
195
|
+
mLoss[row] = loss_val.to(mLoss.element_type)
|
|
196
|
+
if cutlass.const_expr(mLSE is not None):
|
|
197
|
+
mLSE[row] = lse
|
|
179
198
|
|
|
180
199
|
|
|
181
200
|
def cross_entropy(
|
|
182
201
|
x: torch.Tensor,
|
|
183
202
|
target: torch.Tensor,
|
|
203
|
+
return_lse: bool = False,
|
|
184
204
|
) -> torch.Tensor:
|
|
185
205
|
"""Cross entropy forward pass.
|
|
186
206
|
|
|
@@ -196,26 +216,37 @@ def cross_entropy(
|
|
|
196
216
|
assert x.shape[0] == target.shape[0], "Batch dimensions must match"
|
|
197
217
|
assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
|
|
198
218
|
assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
|
|
199
|
-
assert target.dtype
|
|
219
|
+
assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
|
|
200
220
|
M, N = x.shape
|
|
201
221
|
device = x.device
|
|
202
|
-
loss = torch.empty(M, device=device, dtype=
|
|
222
|
+
loss = torch.empty(M, device=device, dtype=torch.float32)
|
|
223
|
+
lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
|
|
203
224
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
204
225
|
convert_from_dlpack = lambda tensor: (
|
|
205
|
-
from_dlpack(tensor.detach(), assumed_align=16)
|
|
206
|
-
|
|
226
|
+
from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
227
|
+
mode=0, stride_order=(0, 1)
|
|
228
|
+
)
|
|
207
229
|
)
|
|
208
|
-
x_tensor
|
|
230
|
+
x_tensor = convert_from_dlpack(x)
|
|
209
231
|
loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
|
|
232
|
+
lse_tensor = (
|
|
233
|
+
from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
|
|
234
|
+
if lse is not None
|
|
235
|
+
else None
|
|
236
|
+
)
|
|
210
237
|
target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
|
|
211
238
|
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
212
|
-
|
|
239
|
+
|
|
240
|
+
compile_key = (dtype, N, lse is not None)
|
|
213
241
|
if compile_key not in cross_entropy.compile_cache:
|
|
242
|
+
cross_entropy_op = CrossEntropy(dtype, N)
|
|
214
243
|
cross_entropy.compile_cache[compile_key] = cute.compile(
|
|
215
|
-
|
|
244
|
+
cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
|
|
216
245
|
)
|
|
217
|
-
cross_entropy.compile_cache[compile_key](
|
|
218
|
-
|
|
246
|
+
cross_entropy.compile_cache[compile_key](
|
|
247
|
+
x_tensor, target_tensor, loss_tensor, lse_tensor, stream
|
|
248
|
+
)
|
|
249
|
+
return loss if not return_lse else (loss, lse)
|
|
219
250
|
|
|
220
251
|
|
|
221
252
|
cross_entropy.compile_cache = {}
|
quack/reduction_base.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from typing import Type, Tuple, Optional
|
|
5
|
+
|
|
6
|
+
import cutlass
|
|
7
|
+
import cutlass.cute as cute
|
|
8
|
+
|
|
9
|
+
import quack.utils as utils
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
torch2cute_dtype_map = {
|
|
13
|
+
torch.float16: cutlass.Float16,
|
|
14
|
+
torch.bfloat16: cutlass.BFloat16,
|
|
15
|
+
torch.float32: cutlass.Float32,
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ReductionBase:
|
|
20
|
+
def __init__(
|
|
21
|
+
self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
|
|
22
|
+
):
|
|
23
|
+
self.dtype = dtype
|
|
24
|
+
self.N = N
|
|
25
|
+
self.stage = stage
|
|
26
|
+
self.reduction_dtype = reduction_dtype
|
|
27
|
+
|
|
28
|
+
def _calculate_threads_per_row(self):
|
|
29
|
+
raise NotImplementedError()
|
|
30
|
+
|
|
31
|
+
def _set_cluster_n(self):
|
|
32
|
+
self.cluster_n = 1
|
|
33
|
+
|
|
34
|
+
def _get_num_threads(self):
|
|
35
|
+
return 128 if self.N <= 16384 else 256
|
|
36
|
+
|
|
37
|
+
def _get_tv_layout(self):
|
|
38
|
+
copy_bits = 128
|
|
39
|
+
vecsize = copy_bits // self.dtype.width
|
|
40
|
+
assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
|
|
41
|
+
num_threads = self._get_num_threads()
|
|
42
|
+
num_warps = num_threads // cute.arch.WARP_SIZE
|
|
43
|
+
assert num_threads % cute.arch.WARP_SIZE == 0
|
|
44
|
+
|
|
45
|
+
threads_per_row = self._calculate_threads_per_row()
|
|
46
|
+
num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
|
|
47
|
+
cols_per_block = num_threads // threads_per_row
|
|
48
|
+
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
|
|
49
|
+
tv_layout = cute.make_layout(
|
|
50
|
+
((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
|
|
51
|
+
stride=(
|
|
52
|
+
(vecsize * cols_per_block, 1),
|
|
53
|
+
(cols_per_block, cols_per_block * vecsize * threads_per_row),
|
|
54
|
+
),
|
|
55
|
+
)
|
|
56
|
+
return tiler_mn, tv_layout
|
|
57
|
+
|
|
58
|
+
def _smem_size_in_bytes(self, tiler_mn, num_warps):
|
|
59
|
+
return (
|
|
60
|
+
cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
|
|
61
|
+
+ self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
|
|
62
|
+
+ self.stage * (cutlass.Int64.width // 8)
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
|
|
66
|
+
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
67
|
+
warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
|
|
68
|
+
return cute.make_ordered_layout(
|
|
69
|
+
(num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
|
|
70
|
+
order=(1, 0, 2),
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
def _allocate_reduction_buffer_and_mbar(
|
|
74
|
+
self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
|
|
75
|
+
) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
|
|
76
|
+
reduction_buffer = smem.allocate_tensor(
|
|
77
|
+
self.reduction_dtype,
|
|
78
|
+
self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
|
|
79
|
+
byte_alignment=4,
|
|
80
|
+
)
|
|
81
|
+
if cutlass.const_expr(self.cluster_n > 1):
|
|
82
|
+
mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
|
|
83
|
+
else:
|
|
84
|
+
mbar_ptr = None
|
|
85
|
+
return reduction_buffer, mbar_ptr
|
|
86
|
+
|
|
87
|
+
@cute.jit
|
|
88
|
+
def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
|
|
89
|
+
if cutlass.const_expr(self.cluster_n > 1):
|
|
90
|
+
if tidx < self.stage:
|
|
91
|
+
cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
|
|
92
|
+
cute.arch.mbarrier_init_fence()
|
|
93
|
+
if tidx < self.stage:
|
|
94
|
+
cute.arch.mbarrier_init_tx_bytes(
|
|
95
|
+
mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
|
|
96
|
+
)
|
|
97
|
+
# Cluster arrive after barrier init
|
|
98
|
+
cute.arch.cluster_arrive_relaxed()
|