quack-kernels 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- quack/__init__.py +7 -1
- quack/cross_entropy.py +199 -168
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +208 -195
- quack/softmax.py +409 -163
- quack/utils.py +249 -35
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/METADATA +4 -1
- quack_kernels-0.1.3.dist-info/RECORD +11 -0
- quack_kernels-0.1.1.dist-info/RECORD +0 -10
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/top_level.txt +0 -0
quack/softmax.py
CHANGED
@@ -1,176 +1,186 @@
 import math
 import torch
-import operator
-from typing import Callable
+from typing import Type

 import cuda.bindings.driver as cuda

 import cutlass
 import cutlass.cute as cute
 from cutlass.cute.runtime import from_dlpack
-import cutlass.torch as cutlass_torch

 import quack.utils as utils
+from quack.reduction_base import ReductionBase, torch2cute_dtype_map


-
-def
-    gX: cute.Tensor,
-    gO: cute.Tensor,
-    cX: cute.Tensor,  # coordinate tensor
-    shape: cute.Shape,
-    tv_layout: cute.Layout,
-    tiler_mn: cute.Shape,
-    cluster_n: cutlass.Constexpr = 1,
-):
-    tidx, _, _ = cute.arch.thread_idx()
-    bidx, cluster_y, _ = cute.arch.block_idx()
-    gdim, _, _ = cute.arch.grid_dim()
-
-    # slice for CTAs
-    # logical id -> address
-    blkX, blkOut, blkCrd = [gT[(None, None), bidx if cluster_n == 1 else (bidx, cluster_y)] for gT in (gX, gO, cX)]
-
-    # declare the atoms which will be used later for memory copy
-    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
-
-    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
-
-    smem = cutlass.utils.SmemAllocator()
-    # Don't use blkX.layout here, because the stride is N, not N_rounded
-    sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
-    num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-    warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-
-    reduction_buffer_layout = cute.make_ordered_layout(
+class Softmax(ReductionBase):
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
         # 2 stages: 1 for max, 1 for sum
-        (
-
-
-
-
-
-
-    else:
-        mbar_ptr = None
-
-    tXgX = thr_copy_X_async.partition_S(blkX)
-    tXsX = thr_copy_X_async.partition_S(sX)
-    tXgO = thr_copy_O.partition_D(blkOut)
-    tXcX = thr_copy_X.partition_S(blkCrd)[(0, None), None, None]
-
-    # allocate fragments for gmem->rmem
-    tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
-
-    if cluster_n > 1:
-        if tidx < 2:
-            cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
-        cute.arch.mbarrier_init_fence()
-        if tidx < 2:
-            cute.arch.mbarrier_init_tx_bytes(mbar_ptr + tidx, num_warps * cluster_n * cutlass.Float32.width // 8)
-        # Cluster arrive after barrier init
-        cute.arch.cluster_arrive_relaxed()
-
-    tXpX = cute.make_fragment_like(tXgX[(0, None), None, None], cutlass.Boolean)
-    for i in range(cute.size(tXpX)):
-        tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
-
-    if tXcX[0][0] < shape[0]:
-        cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
-    cute.arch.cp_async_commit_group()
-    cute.arch.cp_async_wait_group(0)
-
-    cute.autovec_copy(tXsX, tXrX)
-    x = tXrX.load().to(cute.Float32)
-    max_x = utils.warp_reduce(
-        x.reduce(cute.ReductionOp.MAX, init_val=float('-inf'), reduction_profile=0),
-        cute.arch.fmax,
-        width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
-    )
-    if cutlass.const_expr(cluster_n > 1):
-        cute.arch.cluster_wait()
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        max_mbar_ptr = mbar_ptr + 0 if cluster_n > 1 else None
-        max_x = utils.block_or_cluster_reduce(
-            max_x, cute.arch.fmax, reduction_buffer[None, None, 0], max_mbar_ptr, init_val=-cutlass.Float32.inf
-        )
-    log2_e = math.log2(math.e)
-    exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-    denom = utils.warp_reduce(
-        exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
-        operator.add,
-        width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
-    )
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        sum_mbar_ptr = mbar_ptr + 1 if cluster_n > 1 else None
-        denom = utils.block_or_cluster_reduce(
-            denom, operator.add, reduction_buffer[None, None, 1], sum_mbar_ptr, init_val=0.0
-        )
-    inv = 1.0 / denom
-    y = exp_x * inv
-
-    tXrO.store(y.to(tXrO.element_type))
-    tOcX = thr_copy_O.partition_S(blkCrd)[(0, None), None, None]
-    tOpO = cute.make_fragment_like(tXgO[(0, None), None, None], cutlass.Boolean)
-    for i in range(cute.size(tOpO)):
-        tOpO[i] = cute.elem_less(tOcX[i][1], shape[1])
-    if tXcX[0][0] < shape[0]:
-        cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
-
-
-@cute.jit
-def softmax_interface(
-    mX: cute.Tensor,
-    mOut: cute.Tensor,
-    stream: cuda.CUstream,
-    N: cutlass.Constexpr,
-    copy_bits: cutlass.Constexpr = 128
-):
-    vecsize = copy_bits // mX.element_type.width
-    assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-    num_threads = 128 if N <= 16384 else 256
-    num_warps = num_threads // cute.arch.WARP_SIZE
-    assert num_threads % cute.arch.WARP_SIZE == 0
-    threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-    if cutlass.const_expr(mX.element_type.width == 16):
-        cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-    else:  # fp32
-        cluster_n = 1 if N <= 32 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)))
-
-    num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-    cols_per_block = num_threads // threads_per_row
-    tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)  # This rounds up N
-    tv_layout = cute.make_layout(
-        ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-        stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-    )
+        super().__init__(
+            dtype,
+            N,
+            stage=2 if not online_softmax else 1,
+            reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+        )
+        self.online_softmax = online_softmax

-
-
-
-
-
-
-
-
-
-
-
-    )
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )

+    def _set_cluster_n(self):
+        N = self.N
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 32 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n

-
-
-
-
-
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mO: cute.Tensor,
+        stream: cuda.CUstream,
+    ):
+        assert mX.element_type == self.dtype
+        assert mO.element_type == self.dtype
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        self.kernel(mX, mO, tv_layout, tiler_mn).launch(
+            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )

+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,
+        mO: cute.Tensor,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, cluster_y, _ = cute.arch.block_idx()

-
+        shape = mX.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        gX, gO, cX = [
+            cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+            for mT in (mX, mO, idX)
+        ]
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store_O = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+        tXgX = thr_copy_X.partition_S(gX)
+        tXsX = thr_copy_X.partition_D(sX)
+        tXgO = thr_copy_O.partition_D(gO)
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+        # allocate fragments for gmem->rmem
+        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tXpX = (
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+        )
+        if tXcX[0][0] < shape[0]:
+            cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+        cute.arch.cp_async_wait_group(0)
+        # Fill OOB values with -inf
+        if cutlass.const_expr(not is_even_N):
+            utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+        threads_per_row = tv_layout.shape[0][0]
+        if cutlass.const_expr(not self.online_softmax):
+            max_x = utils.row_reduce(
+                x,
+                cute.ReductionOp.MAX,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr + 0 if self.cluster_n > 1 else None,
+                init_val=-cutlass.Float32.inf,
+                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            )
+            log2_e = math.log2(math.e)
+            exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+            denom = utils.row_reduce(
+                exp_x,
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 1],
+                mbar_ptr + 1 if self.cluster_n > 1 else None,
+                init_val=0.0,
+            )
+        else:
+            max_x, denom, exp_x = utils.online_softmax_reduce(
+                x,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr,
+                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+                return_exp_x=True,
+            )
+        y = exp_x * (1.0 / denom)
+        tXrO.store(y.to(tXrO.element_type))
+        tOpO = (
+            utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+        )
+        if tXcX[0][0] < shape[0]:
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
     """Softmax forward pass.
     Args:
         x: Input tensor of shape (M, N)
@@ -181,22 +191,258 @@ def softmax(x: torch.Tensor) -> torch.Tensor:
     assert x.is_cuda, "Tensor must be on CUDA device"
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     M, N = x.shape
-    device = x.device
     out = torch.empty_like(x)
     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda tensor: (
-        from_dlpack(tensor.detach(), assumed_align=16)
-
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
     )
     x_tensor, out_tensor = [convert_from_dlpack(tensor) for tensor in (x, out)]
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
     compile_key = (dtype, N)
-    if compile_key not in
-
-
+    if compile_key not in _softmax_fwd.compile_cache:
+        softmax_op = Softmax(dtype, N)
+        _softmax_fwd.compile_cache[compile_key] = cute.compile(
+            softmax_op, x_tensor, out_tensor, current_stream
         )
-
+    _softmax_fwd.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
     return out


-
+_softmax_fwd.compile_cache = {}
+
+
+class SoftmaxBackward(ReductionBase):
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+        # 1 stage for computing dot product
+        super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 8192 else 256)))
+            )
+        )
+
+    def _set_cluster_n(self):
+        N = self.N
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    def _get_num_threads(self):
+        return 128 if self.N <= 8192 else 256
+
+    def _smem_size_in_bytes(self, tiler_mn, num_warps):
+        return (
+            # Multiply by 2 since we need space for Y and dY
+            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+            + self.stage * (cutlass.Int64.width // 8)
+        )
+
+    @cute.jit
+    def __call__(
+        self,
+        mdY: cute.Tensor,
+        mY: cute.Tensor,
+        mdX: cute.Tensor,
+        stream: cuda.CUstream,
+    ):
+        assert mdY.element_type == self.dtype
+        assert mY.element_type == self.dtype
+        assert mdX.element_type == self.dtype
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        self.kernel(mdY, mY, mdX, tv_layout, tiler_mn).launch(
+            grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mdY: cute.Tensor,
+        mY: cute.Tensor,
+        mdX: cute.Tensor,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, cluster_y, _ = cute.arch.block_idx()
+
+        shape = mdY.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        gdY, gY, gdX, cX = [
+            cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+            for mT in (mdY, mY, mdX, idX)
+        ]
+
+        smem = cutlass.utils.SmemAllocator()
+        sdY = smem.allocate_tensor(
+            mdY.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        sY = smem.allocate_tensor(
+            mY.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mdY.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
+
+        tdYgdY = thr_copy_load.partition_S(gdY)
+        tdYsdY = thr_copy_load.partition_D(sdY)
+        tYgY = thr_copy_load.partition_S(gY)
+        tYsY = thr_copy_load.partition_D(sY)
+        tdXgdX = thr_copy_store.partition_D(gdX)
+        tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
+
+        # allocate fragments for gmem->rmem
+        tdYrdY, tYrY, tdXrdX = [cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)]
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tdYpdY = (
+            utils.predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
+            if not is_even_N
+            else None
+        )
+
+        if tXcX[0][0] < shape[0]:
+            cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tdYpdY)
+            cute.copy(copy_atom_load, tYgY, tYsY, pred=tdYpdY)
+        cute.arch.cp_async_commit_group()
+        cute.arch.cp_async_wait_group(0)
+
+        cute.autovec_copy(tdYsdY, tdYrdY)
+        cute.autovec_copy(tYsY, tYrY)
+        dy = tdYrdY.load().to(cute.Float32)
+        y = tYrY.load().to(cute.Float32)
+
+        # Compute dot product: dot = Σⱼ dy_j × y_j
+        threads_per_row = tv_layout.shape[0][0]
+        dot = utils.row_reduce(
+            dy * y,
+            cute.ReductionOp.ADD,
+            threads_per_row,
+            reduction_buffer[None, None, 0],
+            mbar_ptr if self.cluster_n > 1 else None,
+            init_val=0.0,
+            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+        )
+
+        # Compute gradient: dx_i = y_i × (dy_i - dot)
+        dx = y * (dy - dot)
+        tdXrdX.store(dx.to(tdXrdX.element_type))
+        tdXpdX = (
+            utils.predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
+            if not is_even_N
+            else None
+        )
+        if tXcX[0][0] < shape[0]:
+            cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)
+
+
+def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Softmax backward pass.
+    Args:
+        dy: Upstream gradients tensor of shape (M, N)
+        y: Softmax output tensor of shape (M, N)
+    Returns:
+        Input gradients tensor of same shape as dy and y
+    """
+    assert dy.dim() == 2, "dy must be 2D"
+    assert y.dim() == 2, "y must be 2D"
+    assert dy.shape == y.shape, "dy and y must have same shape"
+    assert dy.is_cuda and y.is_cuda, "Tensors must be on CUDA device"
+    assert dy.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert y.dtype == dy.dtype, "dy and y must have same dtype"
+
+    M, N = dy.shape
+    dx = torch.empty_like(dy)
+    dtype = torch2cute_dtype_map[dy.dtype]
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+    dy_tensor, y_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (dy, y, dx)]
+    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    compile_key = (dtype, N)
+    if compile_key not in _softmax_backward.compile_cache:
+        softmax_backward_op = SoftmaxBackward(dtype, N)
+        _softmax_backward.compile_cache[compile_key] = cute.compile(
+            softmax_backward_op, dy_tensor, y_tensor, dx_tensor, current_stream
+        )
+    _softmax_backward.compile_cache[compile_key](dy_tensor, y_tensor, dx_tensor, current_stream)
+    return dx
+
+
+_softmax_backward.compile_cache = {}
+
+
+class SoftmaxFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        y = _softmax_fwd(x)
+        ctx.save_for_backward(y)
+        return y
+
+    @staticmethod
+    def backward(ctx, dy):
+        (y,) = ctx.saved_tensors
+        dx = _softmax_backward(dy, y)
+        return dx
+
+
+def softmax(x: torch.Tensor) -> torch.Tensor:
+    """Softmax forward pass with automatic differentiation support.
+
+    Args:
+        x: Input tensor of shape (M, N)
+
+    Returns:
+        Softmax output tensor of same shape as x
+    """
+    return SoftmaxFunction.apply(x)
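For reference, below is a minimal usage sketch of the autograd-enabled entry point introduced in this version. It assumes a CUDA-capable environment with the 0.1.3 wheel installed and that `softmax` is imported from `quack.softmax` as defined in the diff above; the comparison against `torch.softmax` is only an illustrative check and is not part of the package.

```python
# Minimal usage sketch (assumes a CUDA device and quack-kernels 0.1.3 installed).
import torch

from quack.softmax import softmax  # autograd-enabled wrapper added in this release

x = torch.randn(4, 8192, device="cuda", dtype=torch.bfloat16, requires_grad=True)

y = softmax(x)        # forward: compiles and caches a Softmax kernel keyed by (dtype, N)
y.sum().backward()    # backward: dispatches to the new SoftmaxBackward kernel via SoftmaxFunction

# Illustrative sanity check against the PyTorch reference implementation.
ref = torch.softmax(x.detach().float(), dim=-1).to(x.dtype)
print(torch.allclose(y.detach(), ref, atol=1e-2, rtol=1e-2))
```

Repeated calls with the same dtype and row length reuse the cached `cute.compile` artifact keyed by `(dtype, N)`, so only the first call pays the compilation cost.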