quack-kernels 0.1.3__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {quack_kernels-0.1.3/quack_kernels.egg-info → quack_kernels-0.1.5}/PKG-INFO +2 -2
- quack_kernels-0.1.5/README.md +48 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/pyproject.toml +1 -1
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/quack/__init__.py +1 -1
- quack_kernels-0.1.5/quack/cross_entropy.py +542 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/quack/reduction_base.py +3 -6
- quack_kernels-0.1.5/quack/rmsnorm.py +662 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/quack/softmax.py +25 -17
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/quack/utils.py +17 -29
- {quack_kernels-0.1.3 → quack_kernels-0.1.5/quack_kernels.egg-info}/PKG-INFO +2 -2
- quack_kernels-0.1.5/quack_kernels.egg-info/requires.txt +6 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/tests/test_cross_entropy.py +54 -10
- quack_kernels-0.1.3/README.md +0 -37
- quack_kernels-0.1.3/quack/cross_entropy.py +0 -252
- quack_kernels-0.1.3/quack/rmsnorm.py +0 -284
- quack_kernels-0.1.3/quack_kernels.egg-info/requires.txt +0 -6
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/LICENSE +0 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/quack_kernels.egg-info/SOURCES.txt +0 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/quack_kernels.egg-info/top_level.txt +0 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/setup.cfg +0 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/setup.py +0 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/tests/test_rmsnorm.py +0 -0
- {quack_kernels-0.1.3 → quack_kernels-0.1.5}/tests/test_softmax.py +0 -0
PKG-INFO
@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.1.3
+Version: 0.1.5
 Requires-Python: >=3.9
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl==4.0.
+Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
 Requires-Dist: torch
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
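Beyond the version bump, the only dependency change is the CuTe-DSL pin moving to a pre-release build. A small illustrative snippet (not part of the package) for checking which versions of the pinned packages are actually installed:

```python
# Illustrative only: print the installed versions of the packages named in the
# updated requirements (quack-kernels, nvidia-cutlass-dsl, torch).
from importlib.metadata import PackageNotFoundError, version

for pkg in ("quack-kernels", "nvidia-cutlass-dsl", "torch"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```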
quack_kernels-0.1.5/README.md
@@ -0,0 +1,48 @@
+# 🦆 QuACK: A Quirky Assortment of CuTe Kernels 🦆
+
+Kernels are written in the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html).
+
+## Installation
+
+``` bash
+pip install quack-kernels
+```
+
+## Requirements
+
+- H100 or B200 GPU
+- CUDA toolkit 12.9+
+- Python 3.12
+
+## Kernels 🐥
+
+- 🦆 RMSNorm forward
+- 🦆 Softmax forward + backward
+- 🦆 Cross entropy forward + backward
+
+Upcoming:
+- 🦆 RMSNorm backward
+- 🦆 Rotary forward + backward
+
+## Usage
+
+```
+from quack import rmsnorm, softmax, cross_entropy
+```
+
+## Caveats 🦆⚠️
+
+**Tensor Size Limitation**: We currently only support tensors ≤ 4GB due to CuTe-DSL using int32 for indexing.
+
+🦆 **Workaround**: For larger tensors, split your input tensors into chunks of
+size ≤ 4GB each. We will implement this automatic chunking in the pytorch part
+of the code in the near future, but if you need it in the meantime, we welcome contributions!
+
+## Development
+
+To set up the development environment:
+
+```bash
+pip install -e '.[dev]'
+pre-commit install
+```
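The Caveats section of the new README suggests splitting oversized inputs into row chunks of at most 4GB. A minimal sketch of that workaround, assuming only the per-row `cross_entropy` API from the Usage section; the helper name and chunking logic here are illustrative, not part of the package:

```python
# Hypothetical helper: run quack's cross_entropy on row chunks so each chunk of
# logits stays under ~4GB, then concatenate the per-row losses.
import torch
from quack import cross_entropy

def chunked_cross_entropy(x: torch.Tensor, target: torch.Tensor,
                          max_bytes: int = 4 * 1024**3) -> torch.Tensor:
    bytes_per_row = x.shape[1] * x.element_size()
    rows_per_chunk = max(1, max_bytes // bytes_per_row)
    losses = []
    for start in range(0, x.shape[0], rows_per_chunk):
        rows = slice(start, min(start + rows_per_chunk, x.shape[0]))
        losses.append(cross_entropy(x[rows], target[rows]))
    return torch.cat(losses)
```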
quack_kernels-0.1.5/quack/cross_entropy.py
@@ -0,0 +1,542 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+import math
+import torch
+from typing import Optional, Type
+
+import cuda.bindings.driver as cuda
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.runtime import from_dlpack
+
+import quack.utils as utils
+from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+class CrossEntropy(ReductionBase):
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
+        # 2 stages: 1 for max, 1 for sum
+        super().__init__(
+            dtype,
+            N,
+            stage=2 if not online_softmax else 1,
+            reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+        )
+        self.online_softmax = online_softmax
+        self.reload_from = None if N <= 16384 or online_softmax else "smem"
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _set_cluster_n(self):
+        N = self.N
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mTarget: cute.Tensor,
+        mLoss: cute.Tensor,
+        mLSE: Optional[cute.Tensor],
+        stream: cuda.CUstream,
+    ):
+        assert mX.element_type == self.dtype
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
+            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,  # (M, N)
+        mTarget: cute.Tensor,  # (M,)
+        mLoss: cute.Tensor,  # (M,)
+        mLSE: Optional[cute.Tensor],  # (M,)
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
+
+        shape: cute.Shape = mX.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+        )
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+
+        #### Thread View
+        tXgX = thr_copy_X.partition_S(gX)
+        tXsX = thr_copy_X.partition_D(sX)
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+        tXrX = cute.make_fragment_like(tXgX)
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        row = tXcX[0][0]
+        target = cute.Int32.zero
+        if row < shape[0] and tXcX[0][1] == 0:
+            target = cute.Int32(mTarget[row])
+
+        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tXpX = (
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+            if cutlass.const_expr(not is_even_N)
+            else None
+        )
+        if row < shape[0]:
+            cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+        cute.arch.cp_async_wait_group(0)
+        # Fill OOB values with -inf
+        if cutlass.const_expr(not is_even_N):
+            utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+
+        target_logit = cute.Float32.zero
+        if row < shape[0] and tXcX[0][1] == 0:
+            target_logit = cute.Float32(mX[row, target])
+
+        threads_per_row = tv_layout.shape[0][0]
+        if cutlass.const_expr(not self.online_softmax):
+            max_x = utils.row_reduce(
+                x,
+                cute.ReductionOp.MAX,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                init_val=-cutlass.Float32.inf,
+                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            )
+            if cutlass.const_expr(self.reload_from == "smem"):
+                cute.autovec_copy(tXsX, tXrX)
+                x = tXrX.load().to(cute.Float32)
+            log2_e = math.log2(math.e)
+            # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+            # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
+            # exp_x = utils.exp2f((x - max_x) * log2_e)
+            # This would use ffma instead of fadd then fmul
+            exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+            denom = utils.row_reduce(
+                exp_x,
+                cute.ReductionOp.ADD,
+                threads_per_row,
+                reduction_buffer[None, None, 1],
+                mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+                init_val=0.0,
+            )
+        else:
+            max_x, denom, _ = utils.online_softmax_reduce(
+                x,
+                threads_per_row,
+                reduction_buffer[None, None, 0],
+                mbar_ptr,
+                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            )
+
+        if (
+            tXcX[0][1] == 0
+            and row < shape[0]
+            and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+        ):
+            ln_2 = math.log(2.0)
+            lse = max_x + utils.log2f(denom) * ln_2
+            loss_val = lse - target_logit
+            mLoss[row] = loss_val.to(mLoss.element_type)
+            if cutlass.const_expr(mLSE is not None):
+                mLSE[row] = lse
+
+
+def _cross_entropy(
+    x: torch.Tensor,
+    target: torch.Tensor,
+    return_lse: bool = False,
+) -> torch.Tensor:
+    """Cross entropy forward pass.
+
+    Args:
+        x: Input logits tensor of shape (M, N)
+        target: Target class indices tensor of shape (M,)
+
+    Returns:
+        Cross entropy loss tensor of shape (M,)
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert target.dim() == 1, "Target must be 1D"
+    assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+    assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+    assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+    M, N = x.shape
+    device = x.device
+    loss = torch.empty(M, device=device, dtype=torch.float32)
+    lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
+    dtype = torch2cute_dtype_map[x.dtype]
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+    x_tensor = convert_from_dlpack(x)
+    loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+    lse_tensor = (
+        from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+        if lse is not None
+        else None
+    )
+    target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
+    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    compile_key = (dtype, N, lse is not None)
+    if compile_key not in _cross_entropy.compile_cache:
+        cross_entropy_op = CrossEntropy(dtype, N)
+        _cross_entropy.compile_cache[compile_key] = cute.compile(
+            cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+        )
+    _cross_entropy.compile_cache[compile_key](
+        x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+    )
+    return loss if not return_lse else (loss, lse)
+
+
+_cross_entropy.compile_cache = {}
+
+
+class CrossEntropyBackward:
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+        self.dtype = dtype
+        self.N = N
+        self.vecsize = 128 // dtype.width
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _get_tv_layout(self):
+        N = self.N
+        vecsize = self.vecsize
+        num_threads = 128 if N <= 16384 else 256
+        threads_per_row = self._calculate_threads_per_row()
+        cols_per_block = num_threads // threads_per_row
+        num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+        tv_layout = cute.make_layout(
+            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+            stride=(
+                (vecsize * cols_per_block, 1),
+                (cols_per_block, cols_per_block * vecsize * threads_per_row),
+            ),
+        )
+        return tiler_mn, tv_layout
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mTarget: cute.Tensor,
+        mDLoss: cute.Tensor,
+        mdX: cute.Tensor,
+        mLSE: cute.Tensor,
+        stream: cuda.CUstream,
+    ):
+        assert mX.element_type == self.dtype
+        assert mdX.element_type == self.dtype
+
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+
+        mDLoss = cute.make_tensor(
+            mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+        mTarget = cute.make_tensor(
+            mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+        mLSE = cute.make_tensor(
+            mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+
+        smem_size = cute.size_in_bytes(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
+        )
+
+        self.kernel(
+            mX,
+            mTarget,
+            mDLoss,
+            mdX,
+            mLSE,
+            mX.shape,
+            tv_layout,
+            tiler_mn,
+        ).launch(
+            grid=[
+                cute.ceil_div(mX.shape[0], tiler_mn[0]),
+                cute.ceil_div(mX.shape[1], tiler_mn[1]),
+                1,
+            ],
+            block=[num_threads, 1, 1],
+            smem=smem_size,
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,  # (M, N)
+        mTarget: cute.Tensor,  # (M,)
+        mDLoss: cute.Tensor,  # (M,)
+        mdX: cute.Tensor,  # (M, N)
+        mLSE: cute.Tensor,  # (M,)
+        shape: cute.Shape,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, bidy, _ = cute.arch.block_idx()
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+
+        idX = cute.make_identity_tensor(shape)
+
+        gX, gdX, cX, gTarget, gDLoss, gLse = [
+            cute.local_tile(mT, tiler_mn, (bidx, bidy))
+            for mT in (mX, mdX, idX, mTarget, mDLoss, mLSE)
+        ]
+
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_X_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store_O = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_X_async = cute.make_tiled_copy(
+            copy_atom_load_X_async, tv_layout, tiler_mn
+        ).get_slice(tidx)
+        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+        #### Thread View
+        tXgX = thr_copy_X_async.partition_S(gX)
+        tXsX = thr_copy_X_async.partition_S(sX)
+
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+        tXcFull = thr_copy_X.partition_S(cX)  # improve
+
+        tXgO = thr_copy_O.partition_D(gdX)
+
+        # allocate fragments for gmem->rmem
+        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+        is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
+        row = tXcX[0][0]
+
+        tXpX = (
+            utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
+            if not is_even_N
+            else None
+        )
+
+        if row < shape[0]:
+            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+        cute.arch.cp_async_wait_group(0)
+        if cutlass.const_expr(not is_even_N):
+            utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+
+        label = cute.Int32.zero
+        dloss = cute.Float32.zero
+        lse = cute.Float32.zero
+        if row < shape[0]:
+            label = cute.Int32(mTarget[row])
+            dloss = cute.Float32(mDLoss[row])
+            lse = cute.Float32(mLSE[row])
+
+        log2_e = math.log2(math.e)
+        probs = utils.exp2f((x - lse) * log2_e)
+        prob_shifted = probs - 1.0
+
+        mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+        for i in cutlass.range_constexpr(cute.size(tXcFull)):
+            mask[i] = tXcFull[i][1] == label
+
+        mask = mask.load()
+        grad = cute.where(mask, prob_shifted, probs)
+        grad = grad * dloss
+
+        tXrO.store(grad.to(tXrO.element_type))
+        tOpO = (
+            utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+        )
+        if row < shape[0]:
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+def _cross_entropy_backward(
+    x: torch.Tensor,
+    target: torch.Tensor,
+    dloss: torch.Tensor,
+    lse: torch.Tensor,
+    inplace_backward: bool = False,
+) -> torch.Tensor:
+    """Cross entropy backward pass.
+    Args:
+        x: Input logits tensor of shape (M, N)
+        target: Target class indices tensor of shape (M,)
+        dloss: Upstream gradients tensor of shape (M,)
+        lse: Log-sum-exp values tensor of shape (M,)
+    Returns:
+        Input gradients tensor of shape (M, N)
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert target.dim() == 1, "Target must be 1D"
+    assert dloss.dim() == 1, "dloss must be 1D"
+    assert lse.dim() == 1, "lse must be 1D"
+    assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+    assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
+    assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
+    assert (
+        x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
+    ), "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+    assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+
+    M, N = x.shape
+    dx = torch.empty_like(x) if not inplace_backward else x
+    dtype = torch2cute_dtype_map[x.dtype]
+
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+    x_tensor = convert_from_dlpack(x)
+    dx_tensor = convert_from_dlpack(dx)
+    dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
+        mode=0
+    )
+    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    compile_key = (dtype, N)
+    if compile_key not in _cross_entropy_backward.compile_cache:
+        cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
+        _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
+            cross_entropy_backward_op,
+            x_tensor,
+            target_tensor,
+            dloss_tensor,
+            dx_tensor,
+            lse_tensor,
+            stream,
+        )
+    _cross_entropy_backward.compile_cache[compile_key](
+        x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
+    )
+    return dx
+
+
+_cross_entropy_backward.compile_cache = {}
+
+
+class CrossEntropyFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, target, inplace_backward=False):
+        loss, lse = _cross_entropy(x, target, return_lse=True)
+        ctx.save_for_backward(x, target, lse)
+        ctx.inplace_backward = inplace_backward
+        return loss
+
+    @staticmethod
+    def backward(ctx, dloss):
+        x, target, lse = ctx.saved_tensors
+        dx = _cross_entropy_backward(x, target, dloss, lse, inplace_backward=ctx.inplace_backward)
+        return dx, None, None
+
+
+def cross_entropy(
+    x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+) -> torch.Tensor:
+    """Cross entropy loss with automatic differentiation support.
+
+    Args:
+        x: Input logits tensor of shape (M, N)
+        target: Target class indices tensor of shape (M,)
+
+    Returns:
+        Cross entropy loss tensor of shape (M,)
+    """
+    return CrossEntropyFunction.apply(x, target, inplace_backward)
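The new forward kernel writes `loss = lse - x[row, target]` with `lse = max_x + log(denom)`, and the backward kernel produces `dx = dloss * (softmax(x) - one_hot(target))`, so the op should agree with PyTorch's reference cross entropy. A hedged sanity check along those lines (the shapes and the use of `F.cross_entropy` here are assumptions, roughly mirroring what tests/test_cross_entropy.py exercises):

```python
# Compare quack's fused cross entropy (forward + backward) against
# torch.nn.functional.cross_entropy with reduction="none".
import torch
import torch.nn.functional as F
from quack import cross_entropy

M, N = 256, 8192
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16, requires_grad=True)
x_ref = x.detach().float().requires_grad_()
target = torch.randint(0, N, (M,), device="cuda")

loss = cross_entropy(x, target)          # per-row loss, shape (M,), float32
loss_ref = F.cross_entropy(x_ref, target, reduction="none")
print("forward max abs diff:", (loss - loss_ref).abs().max().item())

dloss = torch.randn_like(loss_ref)
loss.backward(dloss)
loss_ref.backward(dloss)
print("backward max abs diff:", (x.grad.float() - x_ref.grad).abs().max().item())
```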
{quack_kernels-0.1.3 → quack_kernels-0.1.5}/quack/reduction_base.py
@@ -6,8 +6,6 @@ from typing import Type, Tuple, Optional
 import cutlass
 import cutlass.cute as cute
 
-import quack.utils as utils
-
 
 torch2cute_dtype_map = {
     torch.float16: cutlass.Float16,
@@ -39,7 +37,6 @@ class ReductionBase:
         vecsize = copy_bits // self.dtype.width
         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
         num_threads = self._get_num_threads()
-        num_warps = num_threads // cute.arch.WARP_SIZE
         assert num_threads % cute.arch.WARP_SIZE == 0
 
         threads_per_row = self._calculate_threads_per_row()
@@ -64,7 +61,7 @@ class ReductionBase:
 
     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-        warps_per_row =
+        warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
         return cute.make_ordered_layout(
             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
             order=(1, 0, 2),
@@ -88,10 +85,10 @@ class ReductionBase:
     def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
         if cutlass.const_expr(self.cluster_n > 1):
            if tidx < self.stage:
-                cute.arch.
+                cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
             cute.arch.mbarrier_init_fence()
             if tidx < self.stage:
-                cute.arch.
+                cute.arch.mbarrier_arrive_and_expect_tx(
                     mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
                 )
             # Cluster arrive after barrier init
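The `_get_reduction_buffer_layout` hunk above guards `warps_per_row` with `max(..., 1)` so that rows handled by fewer than a full warp still map to at least one warp in the reduction buffer. A rough worked example of the resulting buffer shape, in plain Python with assumed example sizes (the numbers are illustrative, not taken from the package):

```python
# Assume 128 threads per row, 2 rows per CTA, a 2-CTA cluster, and one
# reduction stage (the online-softmax path uses stage=1).
WARP_SIZE = 32
threads_per_row = 128
cols_per_block = 2
cluster_n = 2
stage = 1

num_warps = (threads_per_row * cols_per_block) // WARP_SIZE   # 8
warps_per_row = max(threads_per_row // WARP_SIZE, 1)          # 4
buffer_shape = (num_warps // warps_per_row, (warps_per_row, cluster_n), stage)
print(buffer_shape)                                           # (2, (4, 2), 1)
```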