quack-kernels 0.1.2-py3-none-any.whl → 0.1.4-py3-none-any.whl
- quack/__init__.py +7 -1
- quack/cross_entropy.py +201 -167
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +212 -181
- quack/softmax.py +417 -156
- quack/utils.py +206 -45
- quack_kernels-0.1.4.dist-info/METADATA +11 -0
- quack_kernels-0.1.4.dist-info/RECORD +11 -0
- quack_kernels-0.1.2.dist-info/METADATA +0 -8
- quack_kernels-0.1.2.dist-info/RECORD +0 -10
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/cross_entropy.py
CHANGED
@@ -1,7 +1,6 @@
 import math
 import torch
-import
-from typing import Callable, Union, Optional
+from typing import Optional, Type
 
 import cuda.bindings.driver as cuda
 
@@ -10,169 +9,195 @@ import cutlass.cute as cute
 from cutlass.cute.runtime import from_dlpack
 
 import quack.utils as utils
+from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
 
-
-def
-    mX: cute.Tensor,  # (M, N)
-    mTarget: cute.Tensor,  # (M,)
-    mLoss: cute.Tensor,  # (M,)
-    mLSE: Optional[cute.Tensor],  # (M,)
-    tv_layout: cute.Layout,
-    tiler_mn: cute.Shape,
-    cluster_n: cutlass.Constexpr = 1,
-):
-    tidx, _, _ = cute.arch.thread_idx()
-    bidx, cluster_y, _ = cute.arch.block_idx()
-
-    shape: cute.Shape = mX.shape
-    idX = cute.make_identity_tensor(shape)
-    # slice for CTAs
-    gX, cX = [
-        cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
-        for mT in (mX, idX)
-    ]
-
-    smem = cutlass.utils.SmemAllocator()
-    sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
-    num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-    warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-    reduction_buffer_layout = cute.make_ordered_layout(
+class CrossEntropy(ReductionBase):
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
         # 2 stages: 1 for max, 1 for sum
-        (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    if cluster_n > 1:
-        if tidx < 2:
-            cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
-        cute.arch.mbarrier_init_fence()
-        if tidx < 2:
-            cute.arch.mbarrier_init_tx_bytes(mbar_ptr + tidx, num_warps * cluster_n * cutlass.Float32.width // 8)
-        # Cluster arrive after barrier init
-        cute.arch.cluster_arrive_relaxed()
-
-    row = tXcX[0][0]
-    target = cute.Int32.zero
-    if row < shape[0] and tXcX[0][1] == 0:
-        target = cute.Int32(mTarget[row])
-
-    is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * cluster_n)
-    tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
-    if row < shape[0]:
-        cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
-    cute.arch.cp_async_commit_group()
-    cute.arch.cp_async_wait_group(0)
-    cute.autovec_copy(tXsX, tXrX)
-    x = tXrX.load().to(cute.Float32)
-    # Fill OOB values with -inf
-    if cutlass.const_expr(not is_even_N):
-        tXrX_fp32 = cute.make_fragment_like(tXrX, cutlass.Float32)
-        tXrX_fp32.store(x)
-        for rest_v in range(tXpX.shape[0]):
-            for rest_k in range(tXpX.shape[2]):
-                if not tXpX[rest_v, 0, rest_k]:
-                    tXrX_fp32[(None, rest_v), None, rest_k].fill(-cutlass.Float32.inf)
-        x = tXrX_fp32.load()
-
-    target_logit = cute.Float32.zero
-    if row < shape[0] and tXcX[0][1] == 0:
-        target_logit = cute.Float32(mX[row, target])
-
-    threads_per_row = tv_layout.shape[0][0]
-    max_x = utils.row_reduce(
-        x,
-        cute.ReductionOp.MAX,
-        threads_per_row,
-        reduction_buffer[None, None, 0],
-        mbar_ptr + 0 if cluster_n > 1 else None,
-        init_val=-cutlass.Float32.inf,
-        hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
-    )
-    log2_e = math.log2(math.e)
-    # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-    exp_x = utils.exp2f((x - max_x) * log2_e)  # a bit faster, idk why
-    denom = utils.row_reduce(
-        exp_x,
-        cute.ReductionOp.ADD,
-        threads_per_row,
-        reduction_buffer[None, None, 1],
-        mbar_ptr + 1 if cluster_n > 1 else None,
-        init_val=0.0,
-    )
-
-    if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
-        ln_2 = math.log(2.0)
-        lse = max_x + utils.log2f(denom) * ln_2
-        loss_val = lse - target_logit
-        mLoss[row] = loss_val.to(mLoss.element_type)
-        if cutlass.const_expr(mLSE is not None):
-            mLSE[row] = lse
-
-
-@cute.jit
-def cross_entropy_interface(
-    mX: cute.Tensor,
-    mTarget: cute.Tensor,
-    mLoss: cute.Tensor,
-    mLSE: Optional[cute.Tensor],
-    stream: cuda.CUstream,
-    N: cutlass.Constexpr,
-    copy_bits: cutlass.Constexpr = 128
-):
-    vecsize = copy_bits // mX.element_type.width
-    assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-    num_threads = 128 if N <= 16384 else 256
-
-    num_warps = num_threads // cute.arch.WARP_SIZE
-    assert num_threads % cute.arch.WARP_SIZE == 0
-    threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-
-    if cutlass.const_expr(mX.element_type.width == 16):
-        cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-    else:  # fp32
-        cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else 8))
-
-    num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-    cols_per_block = num_threads // threads_per_row
-    tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)  # This rounds up N
-    tv_layout = cute.make_layout(
-        ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-        stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-    )
+        super().__init__(
+            dtype,
+            N,
+            stage=2 if not online_softmax else 1,
+            reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+        )
+        self.online_softmax = online_softmax
+        self.reload_from = None if N <= 16384 or online_softmax else "smem"
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
 
-
-
-
-
-
-
-
-
-
+    def _set_cluster_n(self):
+        N = self.N
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mTarget: cute.Tensor,
+        mLoss: cute.Tensor,
+        mLSE: Optional[cute.Tensor],
+        stream: cuda.CUstream,
+    ):
+        assert mX.element_type == self.dtype
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
+            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
 
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,  # (M, N)
+        mTarget: cute.Tensor,  # (M,)
+        mLoss: cute.Tensor,  # (M,)
+        mLSE: Optional[cute.Tensor],  # (M,)
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
+
+        shape: cute.Shape = mX.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
-
-
-
-
-
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+        )
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+
+        #### Thread View
+        tXgX = thr_copy_X.partition_S(gX)
+        tXsX = thr_copy_X.partition_D(sX)
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+        tXrX = cute.make_fragment_like(tXgX)
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        row = tXcX[0][0]
+        target = cute.Int32.zero
+        if row < shape[0] and tXcX[0][1] == 0:
+            target = cute.Int32(mTarget[row])
+
+        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tXpX = (
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+            if cutlass.const_expr(not is_even_N)
+            else None
+        )
+        if row < shape[0]:
+            cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+        cute.arch.cp_async_wait_group(0)
+        # Fill OOB values with -inf
+        if cutlass.const_expr(not is_even_N):
+            utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+
+        target_logit = cute.Float32.zero
+        if row < shape[0] and tXcX[0][1] == 0:
+            target_logit = cute.Float32(mX[row, target])
+
+        threads_per_row = tv_layout.shape[0][0]
+        if cutlass.const_expr(not self.online_softmax):
+            max_x = utils.row_reduce(
+                x,
+                cute.ReductionOp.MAX,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                init_val=-cutlass.Float32.inf,
+                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            )
+            if cutlass.const_expr(self.reload_from == "smem"):
+                cute.autovec_copy(tXsX, tXrX)
+                x = tXrX.load().to(cute.Float32)
+            log2_e = math.log2(math.e)
+            # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+            # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
+            # exp_x = utils.exp2f((x - max_x) * log2_e)
+            # This would use ffma instead of fadd then fmul
+            exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+            denom = utils.row_reduce(
+                exp_x,
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 1],
+                mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+                init_val=0.0,
+            )
+        else:
+            max_x, denom, _ = utils.online_softmax_reduce(
+                x,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr,
+                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            )
+
+        if (
+            tXcX[0][1] == 0
+            and row < shape[0]
+            and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+        ):
+            ln_2 = math.log(2.0)
+            lse = max_x + utils.log2f(denom) * ln_2
+            loss_val = lse - target_logit
+            mLoss[row] = loss_val.to(mLoss.element_type)
+            if cutlass.const_expr(mLSE is not None):
+                mLSE[row] = lse
 
 
 def cross_entropy(
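The rewritten kernel still computes the loss through the logsumexp identity: per row, loss = lse - x[target], with lse = max(x) + log(sum(exp(x - max(x)))). The new online_softmax path fuses the max and sum reductions into a single pass (hence stage=1 and an Int64 reduction buffer, presumably packing the running (max, sum) pair). A quick PyTorch sketch of the identity itself, for reference only and not part of the package:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 1024)
    target = torch.randint(0, 1024, (4,))

    # lse = max + log(sum(exp(x - max))), the same stabilized form the kernel uses
    max_x = x.max(dim=-1).values
    lse = max_x + torch.log(torch.exp(x - max_x[:, None]).sum(dim=-1))
    loss = lse - x[torch.arange(4), target]

    assert torch.allclose(loss, F.cross_entropy(x, target, reduction="none"), atol=1e-5)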
@@ -194,27 +219,36 @@ def cross_entropy(
     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
     assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
-    assert target.dtype
+    assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
     M, N = x.shape
     device = x.device
     loss = torch.empty(M, device=device, dtype=torch.float32)
     lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda tensor: (
-        from_dlpack(tensor.detach(), assumed_align=16)
-
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
     )
-    x_tensor
+    x_tensor = convert_from_dlpack(x)
     loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
-    lse_tensor =
+    lse_tensor = (
+        from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+        if lse is not None
+        else None
+    )
     target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
+
+    compile_key = (dtype, N, lse is not None)
     if compile_key not in cross_entropy.compile_cache:
+        cross_entropy_op = CrossEntropy(dtype, N)
         cross_entropy.compile_cache[compile_key] = cute.compile(
-
+            cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
         )
-    cross_entropy.compile_cache[compile_key](
+    cross_entropy.compile_cache[compile_key](
+        x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+    )
     return loss if not return_lse else (loss, lse)
 
 
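The public wrapper keeps the same entry point: it validates dtypes, converts tensors via DLPack, and caches one compiled kernel per (dtype, N, return_lse) key. A hedged usage sketch based only on the asserts and signature visible in the hunk above:

    import torch
    from quack.cross_entropy import cross_entropy

    x = torch.randn(8192, 32768, device="cuda", dtype=torch.bfloat16)
    target = torch.randint(0, 32768, (8192,), device="cuda")  # int64, as allowed by the assert

    loss = cross_entropy(x, target)                        # (M,) float32 per-row losses
    loss, lse = cross_entropy(x, target, return_lse=True)  # row-wise logsumexp as well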
quack/reduction_base.py
ADDED
@@ -0,0 +1,98 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+import torch
+from typing import Type, Tuple, Optional
+
+import cutlass
+import cutlass.cute as cute
+
+import quack.utils as utils
+
+
+torch2cute_dtype_map = {
+    torch.float16: cutlass.Float16,
+    torch.bfloat16: cutlass.BFloat16,
+    torch.float32: cutlass.Float32,
+}
+
+
+class ReductionBase:
+    def __init__(
+        self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
+    ):
+        self.dtype = dtype
+        self.N = N
+        self.stage = stage
+        self.reduction_dtype = reduction_dtype
+
+    def _calculate_threads_per_row(self):
+        raise NotImplementedError()
+
+    def _set_cluster_n(self):
+        self.cluster_n = 1
+
+    def _get_num_threads(self):
+        return 128 if self.N <= 16384 else 256
+
+    def _get_tv_layout(self):
+        copy_bits = 128
+        vecsize = copy_bits // self.dtype.width
+        assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+        num_threads = self._get_num_threads()
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        assert num_threads % cute.arch.WARP_SIZE == 0
+
+        threads_per_row = self._calculate_threads_per_row()
+        num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
+        cols_per_block = num_threads // threads_per_row
+        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+        tv_layout = cute.make_layout(
+            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+            stride=(
+                (vecsize * cols_per_block, 1),
+                (cols_per_block, cols_per_block * vecsize * threads_per_row),
+            ),
+        )
+        return tiler_mn, tv_layout
+
+    def _smem_size_in_bytes(self, tiler_mn, num_warps):
+        return (
+            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+            + self.stage * (cutlass.Int64.width // 8)
+        )
+
+    def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+        return cute.make_ordered_layout(
+            (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+            order=(1, 0, 2),
+        )
+
+    def _allocate_reduction_buffer_and_mbar(
+        self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+    ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
+        reduction_buffer = smem.allocate_tensor(
+            self.reduction_dtype,
+            self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
+            byte_alignment=4,
+        )
+        if cutlass.const_expr(self.cluster_n > 1):
+            mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+        else:
+            mbar_ptr = None
+        return reduction_buffer, mbar_ptr
+
+    @cute.jit
+    def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
+        if cutlass.const_expr(self.cluster_n > 1):
+            if tidx < self.stage:
+                cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+            cute.arch.mbarrier_init_fence()
+            if tidx < self.stage:
+                cute.arch.mbarrier_arrive_and_expect_tx(
+                    mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
+                )
+            # Cluster arrive after barrier init
+            cute.arch.cluster_arrive_relaxed()
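The new ReductionBase factors out what every row-reduction kernel in the package shares: the tile/TV-layout computation, shared-memory sizing, the inter-warp reduction scratch buffer, and cluster mbarrier setup. A subclass mainly supplies its threads-per-row heuristic and, optionally, a cluster policy. A minimal sketch of that contract (RowMax is hypothetical and illustrative, not part of the package):

    import cutlass
    from quack.reduction_base import ReductionBase

    class RowMax(ReductionBase):
        def __init__(self, dtype, N: int):
            # One reduction stage: only a running max is exchanged between warps.
            super().__init__(dtype, N, stage=1)

        def _calculate_threads_per_row(self):
            # Any N-dependent heuristic, in the style of CrossEntropy._calculate_threads_per_row.
            return 32 if self.N <= 4096 else 128

    # A host-side __call__ would then mirror CrossEntropy.__call__ above:
    # self._set_cluster_n(); tiler_mn, tv_layout = self._get_tv_layout(); and a
    # @cute.kernel that uses self._allocate_reduction_buffer_and_mbar(...) and
    # self._initialize_cluster(...) before calling utils.row_reduce(...).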