quack-kernels 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff shows the content changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- quack/__init__.py +7 -1
- quack/cross_entropy.py +201 -167
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +212 -181
- quack/softmax.py +417 -156
- quack/utils.py +206 -45
- quack_kernels-0.1.4.dist-info/METADATA +11 -0
- quack_kernels-0.1.4.dist-info/RECORD +11 -0
- quack_kernels-0.1.2.dist-info/METADATA +0 -8
- quack_kernels-0.1.2.dist-info/RECORD +0 -10
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/top_level.txt +0 -0
quack/softmax.py
CHANGED
@@ -1,169 +1,191 @@
 import math
 import torch
-import
-from typing import Callable
+from typing import Type

 import cuda.bindings.driver as cuda

 import cutlass
 import cutlass.cute as cute
 from cutlass.cute.runtime import from_dlpack
-import cutlass.torch as cutlass_torch

 import quack.utils as utils
+from quack.reduction_base import ReductionBase, torch2cute_dtype_map


-
-def
-    mX: cute.Tensor,
-    mO: cute.Tensor,
-    tv_layout: cute.Layout,
-    tiler_mn: cute.Shape,
-    cluster_n: cutlass.Constexpr = 1,
-):
-    tidx, _, _ = cute.arch.thread_idx()
-    bidx, cluster_y, _ = cute.arch.block_idx()
-
-    shape = mX.shape
-    idX = cute.make_identity_tensor(shape)
-    # slice for CTAs
-    gX, gO, cX = [
-        cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
-        for mT in (mX, mO, idX)
-    ]
-
-    smem = cutlass.utils.SmemAllocator()
-    sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
-    num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-    warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-    reduction_buffer_layout = cute.make_ordered_layout(
+class Softmax(ReductionBase):
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
         # 2 stages: 1 for max, 1 for sum
-        (
-
-
-
-
-
-
-    else:
-        mbar_ptr = None
-
-    # declare the atoms which will be used later for memory copy
-    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128)
-    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
-
-    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
-
-    tXgX = thr_copy_X.partition_S(gX)
-    tXsX = thr_copy_X.partition_D(sX)
-    tXgO = thr_copy_O.partition_D(gO)
-    tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
-    # allocate fragments for gmem->rmem
-    tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
-
-    if cluster_n > 1:
-        if tidx < 2:
-            cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
-        cute.arch.mbarrier_init_fence()
-        if tidx < 2:
-            cute.arch.mbarrier_init_tx_bytes(mbar_ptr + tidx, num_warps * cluster_n * cutlass.Float32.width // 8)
-        # Cluster arrive after barrier init
-        cute.arch.cluster_arrive_relaxed()
-
-    is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * cluster_n)
-    tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
-    if tXcX[0][0] < shape[0]:
-        cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
-    cute.arch.cp_async_commit_group()
-    cute.arch.cp_async_wait_group(0)
-
-    cute.autovec_copy(tXsX, tXrX)
-    x = tXrX.load().to(cute.Float32)
-    # Fill OOB values with -inf
-    if cutlass.const_expr(not is_even_N):
-        tXrX_fp32 = cute.make_fragment_like(tXrX, cutlass.Float32)
-        tXrX_fp32.store(x)
-        for rest_v in range(tXpX.shape[0]):
-            for rest_k in range(tXpX.shape[2]):
-                if not tXpX[rest_v, 0, rest_k]:
-                    tXrX_fp32[(None, rest_v), None, rest_k].fill(-cutlass.Float32.inf)
-        x = tXrX_fp32.load()
-    threads_per_row = tv_layout.shape[0][0]
-    max_x = utils.row_reduce(
-        x,
-        cute.ReductionOp.MAX,
-        threads_per_row,
-        reduction_buffer[None, None, 0],
-        mbar_ptr + 0 if cluster_n > 1 else None,
-        init_val=-cutlass.Float32.inf,
-        hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
-    )
-    log2_e = math.log2(math.e)
-    exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-    denom = utils.row_reduce(
-        exp_x,
-        cute.ReductionOp.ADD,
-        threads_per_row,
-        reduction_buffer[None, None, 1],
-        mbar_ptr + 1 if cluster_n > 1 else None,
-        init_val=0.0,
-    )
-    inv = 1.0 / denom
-    y = exp_x * inv
-    tXrO.store(y.to(tXrO.element_type))
-    tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
-    if tXcX[0][0] < shape[0]:
-        cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
-
-
-@cute.jit
-def softmax_interface(
-    mX: cute.Tensor,
-    mO: cute.Tensor,
-    stream: cuda.CUstream,
-    N: cutlass.Constexpr,
-    copy_bits: cutlass.Constexpr = 128
-):
-    vecsize = copy_bits // mX.element_type.width
-    assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-    num_threads = 128 if N <= 16384 else 256
-    num_warps = num_threads // cute.arch.WARP_SIZE
-    assert num_threads % cute.arch.WARP_SIZE == 0
-    threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-    if cutlass.const_expr(mX.element_type.width == 16):
-        cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-    else:  # fp32
-        cluster_n = 1 if N <= 32 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)))
-
-    num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-    cols_per_block = num_threads // threads_per_row
-    tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)  # This rounds up N
-    tv_layout = cute.make_layout(
-        ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-        stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-    )
+        super().__init__(
+            dtype,
+            N,
+            stage=2 if not online_softmax else 1,
+            reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+        )
+        self.online_softmax = online_softmax

-
-
-
-
-
-
-
-
-
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )

+    def _set_cluster_n(self):
+        N = self.N
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 32 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n

-
-
-
-
-
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mO: cute.Tensor,
+        stream: cuda.CUstream,
+    ):
+        assert mX.element_type == self.dtype
+        assert mO.element_type == self.dtype
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        self.kernel(mX, mO, tv_layout, tiler_mn).launch(
+            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )

+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,
+        mO: cute.Tensor,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)

-
+        shape = mX.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store_O = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+        tXgX = thr_copy_X.partition_S(gX)
+        tXsX = thr_copy_X.partition_D(sX)
+        tXgO = thr_copy_O.partition_D(gO)
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+        # allocate fragments for gmem->rmem
+        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tXpX = (
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+            if cutlass.const_expr(not is_even_N)
+            else None
+        )
+        if tXcX[0][0] < shape[0]:
+            cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+        cute.arch.cp_async_wait_group(0)
+        # Fill OOB values with -inf
+        if cutlass.const_expr(not is_even_N):
+            utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+        threads_per_row = tv_layout.shape[0][0]
+        if cutlass.const_expr(not self.online_softmax):
+            max_x = utils.row_reduce(
+                x,
+                cute.ReductionOp.MAX,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                init_val=-cutlass.Float32.inf,
+                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            )
+            log2_e = math.log2(math.e)
+            exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+            denom = utils.row_reduce(
+                exp_x,
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 1],
+                mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+                init_val=0.0,
+            )
+        else:
+            max_x, denom, exp_x = utils.online_softmax_reduce(
+                x,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr,
+                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+                return_exp_x=True,
+            )
+        y = exp_x * (1.0 / denom)
+        tXrO.store(y.to(tXrO.element_type))
+        tOpO = (
+            utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+            if cutlass.const_expr(not is_even_N)
+            else None
+        )
+        if tXcX[0][0] < shape[0]:
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
     """Softmax forward pass.
     Args:
         x: Input tensor of shape (M, N)
@@ -174,22 +196,261 @@ def softmax(x: torch.Tensor) -> torch.Tensor:
     assert x.is_cuda, "Tensor must be on CUDA device"
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     M, N = x.shape
-    device = x.device
     out = torch.empty_like(x)
     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda tensor: (
-        from_dlpack(tensor.detach(), assumed_align=16)
-
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
     )
     x_tensor, out_tensor = [convert_from_dlpack(tensor) for tensor in (x, out)]
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
     compile_key = (dtype, N)
-    if compile_key not in
-
-
+    if compile_key not in _softmax_fwd.compile_cache:
+        softmax_op = Softmax(dtype, N)
+        _softmax_fwd.compile_cache[compile_key] = cute.compile(
+            softmax_op, x_tensor, out_tensor, current_stream
         )
-
+    _softmax_fwd.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
     return out


-
+_softmax_fwd.compile_cache = {}
+
+
+class SoftmaxBackward(ReductionBase):
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+        # 1 stage for computing dot product
+        super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 8192 else 256)))
+            )
+        )
+
+    def _set_cluster_n(self):
+        N = self.N
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    def _get_num_threads(self):
+        return 128 if self.N <= 8192 else 256
+
+    def _smem_size_in_bytes(self, tiler_mn, num_warps):
+        return (
+            # Multiply by 2 since we need space for Y and dY
+            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+            + self.stage * (cutlass.Int64.width // 8)
+        )
+
+    @cute.jit
+    def __call__(
+        self,
+        mdY: cute.Tensor,
+        mY: cute.Tensor,
+        mdX: cute.Tensor,
+        stream: cuda.CUstream,
+    ):
+        assert mdY.element_type == self.dtype
+        assert mY.element_type == self.dtype
+        assert mdX.element_type == self.dtype
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        self.kernel(mdY, mY, mdX, tv_layout, tiler_mn).launch(
+            grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mdY: cute.Tensor,
+        mY: cute.Tensor,
+        mdX: cute.Tensor,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
+
+        shape = mdY.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        gdY, gY, gdX, cX = [
+            cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
+        ]
+
+        smem = cutlass.utils.SmemAllocator()
+        sdY = smem.allocate_tensor(
+            mdY.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        sY = smem.allocate_tensor(
+            mY.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mdY.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
+
+        tdYgdY = thr_copy_load.partition_S(gdY)
+        tdYsdY = thr_copy_load.partition_D(sdY)
+        tYgY = thr_copy_load.partition_S(gY)
+        tYsY = thr_copy_load.partition_D(sY)
+        tdXgdX = thr_copy_store.partition_D(gdX)
+        tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
+
+        # allocate fragments for gmem->rmem
+        tdYrdY, tYrY, tdXrdX = [cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)]
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tdYpdY = (
+            utils.predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
+            if cutlass.const_expr(not is_even_N)
+            else None
+        )
+
+        if tXcX[0][0] < shape[0]:
+            cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tdYpdY)
+            cute.copy(copy_atom_load, tYgY, tYsY, pred=tdYpdY)
+        cute.arch.cp_async_commit_group()
+        cute.arch.cp_async_wait_group(0)
+
+        cute.autovec_copy(tdYsdY, tdYrdY)
+        cute.autovec_copy(tYsY, tYrY)
+        dy = tdYrdY.load().to(cute.Float32)
+        y = tYrY.load().to(cute.Float32)
+
+        # Compute dot product: dot = Σⱼ dy_j × y_j
+        threads_per_row = tv_layout.shape[0][0]
+        dot = utils.row_reduce(
+            dy * y,
+            cute.ReductionOp.ADD,
+            threads_per_row,
+            reduction_buffer[None, None, 0],
+            mbar_ptr if cutlass.const_expr(self.cluster_n > 1) else None,
+            init_val=0.0,
+            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+        )
+
+        # Compute gradient: dx_i = y_i × (dy_i - dot)
+        dx = y * (dy - dot)
+        tdXrdX.store(dx.to(tdXrdX.element_type))
+        tdXpdX = (
+            utils.predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
+            if cutlass.const_expr(not is_even_N)
+            else None
+        )
+        if tXcX[0][0] < shape[0]:
+            cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)
+
+
+def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Softmax backward pass.
+    Args:
+        dy: Upstream gradients tensor of shape (M, N)
+        y: Softmax output tensor of shape (M, N)
+    Returns:
+        Input gradients tensor of same shape as dy and y
+    """
+    assert dy.dim() == 2, "dy must be 2D"
+    assert y.dim() == 2, "y must be 2D"
+    assert dy.shape == y.shape, "dy and y must have same shape"
+    assert dy.is_cuda and y.is_cuda, "Tensors must be on CUDA device"
+    assert dy.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert y.dtype == dy.dtype, "dy and y must have same dtype"
+
+    M, N = dy.shape
+    dx = torch.empty_like(dy)
+    dtype = torch2cute_dtype_map[dy.dtype]
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+    dy_tensor, y_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (dy, y, dx)]
+    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    compile_key = (dtype, N)
+    if compile_key not in _softmax_backward.compile_cache:
+        softmax_backward_op = SoftmaxBackward(dtype, N)
+        _softmax_backward.compile_cache[compile_key] = cute.compile(
+            softmax_backward_op, dy_tensor, y_tensor, dx_tensor, current_stream
+        )
+    _softmax_backward.compile_cache[compile_key](dy_tensor, y_tensor, dx_tensor, current_stream)
+    return dx
+
+
+_softmax_backward.compile_cache = {}
+
+
+class SoftmaxFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        y = _softmax_fwd(x)
+        ctx.save_for_backward(y)
+        return y
+
+    @staticmethod
+    def backward(ctx, dy):
+        (y,) = ctx.saved_tensors
+        dx = _softmax_backward(dy, y)
+        return dx
+
+
+def softmax(x: torch.Tensor) -> torch.Tensor:
+    """Softmax forward pass with automatic differentiation support.
+
+    Args:
+        x: Input tensor of shape (M, N)
+
+    Returns:
+        Softmax output tensor of same shape as x
+    """
+    return SoftmaxFunction.apply(x)