quack-kernels 0.1.4__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {quack_kernels-0.1.4/quack_kernels.egg-info → quack_kernels-0.1.6}/PKG-INFO +1 -1
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/README.md +21 -3
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack/__init__.py +1 -1
- quack_kernels-0.1.6/quack/cross_entropy.py +546 -0
- quack_kernels-0.1.4/quack/rmsnorm.py → quack_kernels-0.1.6/quack/layernorm.py +93 -27
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack/reduction_base.py +1 -4
- quack_kernels-0.1.6/quack/rmsnorm.py +665 -0
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack/softmax.py +8 -3
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack/utils.py +19 -18
- {quack_kernels-0.1.4 → quack_kernels-0.1.6/quack_kernels.egg-info}/PKG-INFO +1 -1
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack_kernels.egg-info/SOURCES.txt +2 -0
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack_kernels.egg-info/top_level.txt +1 -0
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/tests/test_cross_entropy.py +21 -13
- quack_kernels-0.1.6/tests/test_layernorm.py +162 -0
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/tests/test_rmsnorm.py +36 -5
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/tests/test_softmax.py +2 -3
- quack_kernels-0.1.4/quack/cross_entropy.py +0 -255
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/LICENSE +0 -0
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/pyproject.toml +0 -0
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack_kernels.egg-info/requires.txt +0 -0
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/setup.cfg +0 -0
- {quack_kernels-0.1.4 → quack_kernels-0.1.6}/setup.py +0 -0
|
@@ -17,11 +17,11 @@ pip install quack-kernels
|
|
|
17
17
|
## Kernels 🐥
|
|
18
18
|
|
|
19
19
|
- 🦆 RMSNorm forward
|
|
20
|
-
- 🦆 Softmax forward
|
|
21
|
-
- 🦆 Cross entropy forward
|
|
20
|
+
- 🦆 Softmax forward + backward
|
|
21
|
+
- 🦆 Cross entropy forward + backward
|
|
22
|
+
- 🦆 Layernorm forward
|
|
22
23
|
|
|
23
24
|
Upcoming:
|
|
24
|
-
- 🦆 Cross entropy backward
|
|
25
25
|
- 🦆 RMSNorm backward
|
|
26
26
|
- 🦆 Rotary forward + backward
|
|
27
27
|
|
|
@@ -31,6 +31,24 @@ Upcoming:
|
|
|
31
31
|
from quack import rmsnorm, softmax, cross_entropy
|
|
32
32
|
```
|
|
33
33
|
|
|
34
|
+
## Documentation
|
|
35
|
+
|
|
36
|
+
[2025-07-10] We have a comprehensive
|
|
37
|
+
[blogpost](media/2025-07-10-membound-sol.md) on how to get memory-bound kernels
|
|
38
|
+
to speed-of-light, right in the comfort of Python thanks to the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html).
|
|
39
|
+
|
|
40
|
+
## Performance
|
|
41
|
+
|
|
42
|
+
<div align="center">
|
|
43
|
+
<figure>
|
|
44
|
+
<img
|
|
45
|
+
src="media/bf16_kernel_benchmarks_single_row.svg"
|
|
46
|
+
>
|
|
47
|
+
</figure>
|
|
48
|
+
</div>
|
|
49
|
+
|
|
50
|
+
See our [blogpost](media/2025-07-10-membound-sol.md) for the details.
|
|
51
|
+
|
|
34
52
|
## Development
|
|
35
53
|
|
|
36
54
|
To set up the development environment:
|
|
@@ -0,0 +1,546 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import torch
|
|
5
|
+
from typing import Optional, Type
|
|
6
|
+
|
|
7
|
+
import cuda.bindings.driver as cuda
|
|
8
|
+
|
|
9
|
+
import cutlass
|
|
10
|
+
import cutlass.cute as cute
|
|
11
|
+
from cutlass.cute.runtime import from_dlpack
|
|
12
|
+
|
|
13
|
+
import quack.utils as utils
|
|
14
|
+
from quack.reduction_base import ReductionBase, torch2cute_dtype_map
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CrossEntropy(ReductionBase):
    """Forward cross-entropy kernel: per-row loss and (optionally) log-sum-exp.

    For every row of an (M, N) logits tensor the kernel computes
    ``lse = max + log(sum(exp(x - max)))`` and ``loss = lse - x[row, target[row]]``.
    Tiling helpers such as ``_get_tv_layout``, ``_smem_size_in_bytes``,
    ``_allocate_reduction_buffer_and_mbar`` and ``_initialize_cluster`` are not
    defined here; presumably they are inherited from ``ReductionBase``.
    """

    def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
        """Configure the reduction for rows of length ``N`` with element type ``dtype``.

        Args:
            dtype: Element type of the logits tensor.
            N: Number of columns (classes) per row.
            online_softmax: When True, max and sum are obtained from a single
                ``online_softmax_reduce`` call (one stage, Int64 reduction
                buffer). When False, two separate row reductions are used
                (two stages, Float32 reduction buffer).
        """
        # 2 stages: 1 for max, 1 for sum
        super().__init__(
            dtype,
            N,
            stage=2 if not online_softmax else 1,
            reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
        )
        self.online_softmax = online_softmax
        # Only the two-pass path with N > 16384 re-reads x from shared memory
        # between the max and the sum reductions.
        self.reload_from = None if N <= 16384 or online_softmax else "smem"

    def _calculate_threads_per_row(self):
        """Return how many threads cooperate on one row (8 to 256, growing with N)."""
        N = self.N
        return (
            8
            if N <= 64
            else (
                16
                if N <= 128
                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
            )
        )

    def _set_cluster_n(self):
        """Set ``self.cluster_n`` (1, 2, 4, 8 or 16) from N and the element width.

        16-bit element types switch to larger clusters at smaller N than the
        other (fp32) branch does.
        """
        N = self.N
        if cutlass.const_expr(self.dtype.width == 16):
            cluster_n = (
                1
                if N <= 16 * 1024
                else (
                    2
                    if N <= 32 * 1024
                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
                )
            )
        else:  # fp32
            cluster_n = (
                1
                if N <= 16 * 1024
                else (
                    2
                    if N <= 64 * 1024
                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
                )
            )
        self.cluster_n = cluster_n

    # Host-side entry point: derives the tiling and launches `kernel` on `stream`.
    #   mX:      logits, element type must equal self.dtype
    #   mTarget: class index per row
    #   mLoss:   output loss per row
    #   mLSE:    optional output log-sum-exp per row (None to skip)
    @cute.jit
    def __call__(
        self,
        mX: cute.Tensor,
        mTarget: cute.Tensor,
        mLoss: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        stream: cuda.CUstream,
    ):
        assert mX.element_type == self.dtype
        self._set_cluster_n()
        tiler_mn, tv_layout = self._get_tv_layout()
        num_threads = cute.size(tv_layout, mode=[0])
        num_warps = num_threads // cute.arch.WARP_SIZE
        # Grid x covers row tiles; grid y is the cluster size. A cluster
        # dimension is only passed to the launch when cluster_n > 1.
        self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
            block=[num_threads, 1, 1],
            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
            stream=stream,
        )

    # Device kernel. Each block handles one tile of rows (and, when clusters are
    # used, one column tile selected by block index y).
    @cute.kernel
    def kernel(
        self,
        mX: cute.Tensor,  # (M, N)
        mTarget: cute.Tensor,  # (M,)
        mLoss: cute.Tensor,  # (M,)
        mLSE: Optional[cute.Tensor],  # (M,)
        tv_layout: cute.Layout,
        tiler_mn: cute.Shape,
    ):
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        if cutlass.const_expr(self.cluster_n > 1):
            cluster_y = cute.arch.block_idx()[1]
        else:
            cluster_y = cutlass.const_expr(0)

        shape: cute.Shape = mX.shape
        # Identity tensor: indexing it yields (row, col) coordinates, used for
        # bounds checks and for finding the thread that owns column 0.
        idX = cute.make_identity_tensor(shape)
        # slice for CTAs
        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
        mX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
        gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

        smem = cutlass.utils.SmemAllocator()
        sX = smem.allocate_tensor(
            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
        )
        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

        # declare the atoms which will be used later for memory copy
        copy_atom_load_X = cute.make_copy_atom(
            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
        )
        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)

        #### Thread View
        tXgX = thr_copy_X.partition_S(gX)
        tXsX = thr_copy_X.partition_D(sX)
        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
        tXrX = cute.make_fragment_like(tXgX)

        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
        self._initialize_cluster(tidx, mbar_ptr, num_warps)

        # Row index of this thread's first coordinate. Only the thread whose
        # first coordinate is column 0 of an in-bounds row reads the target.
        row = tXcX[0][0]
        target = cute.Int32.zero
        if row < shape[0] and tXcX[0][1] == 0:
            target = cute.Int32(mTarget[row])

        # Column predicate is only built when N does not exactly fill the
        # tile(s) spanned by the cluster.
        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
        tXpX = (
            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
            if cutlass.const_expr(not is_even_N)
            else None
        )
        if row < shape[0]:
            cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
        cute.arch.cp_async_commit_group()
        cute.arch.cp_async_wait_group(0)
        # Fill OOB values with -inf
        if cutlass.const_expr(not is_even_N):
            utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
        cute.autovec_copy(tXsX, tXrX)
        x = tXrX.load().to(cute.Float32)

        # The logit at the target column is read with a direct indexed load
        # from mX rather than taken from the tile held by this thread.
        target_logit = cute.Float32.zero
        if row < shape[0] and tXcX[0][1] == 0:
            # Use Int64 for indexing to deal with large tensors
            mX_off = utils.domain_offset_i64((row, 0), mX)
            target_logit = cute.Float32(mX_off[0, target])

        threads_per_row = tv_layout.shape[0][0]
        if cutlass.const_expr(not self.online_softmax):
            # Two-pass path: row max first (stage 0), then sum of exp (stage 1).
            max_x = utils.row_reduce(
                x,
                cute.ReductionOp.MAX,
                threads_per_row,
                reduction_buffer[None, None, 0],
                mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
                init_val=-cutlass.Float32.inf,
                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
            )
            if cutlass.const_expr(self.reload_from == "smem"):
                cute.autovec_copy(tXsX, tXrX)
                x = tXrX.load().to(cute.Float32)
            log2_e = math.log2(math.e)
            # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
            # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
            # exp_x = utils.exp2f((x - max_x) * log2_e)
            # This would use ffma instead of fadd then fmul
            exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
            denom = utils.row_reduce(
                exp_x,
                cute.ReductionOp.ADD,
                threads_per_row,
                reduction_buffer[None, None, 1],
                mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
                init_val=0.0,
            )
        else:
            # Single-pass path: max and sum come from one combined reduction;
            # the third returned value is unused here.
            max_x, denom, _ = utils.online_softmax_reduce(
                x,
                threads_per_row,
                reduction_buffer[None, None, 0],
                mbar_ptr,
                hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
            )

        # One writer per row: the column-0 thread of an in-bounds row, and only
        # in block 0 of the cluster when clusters are in use.
        if (
            tXcX[0][1] == 0
            and row < shape[0]
            and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
        ):
            ln_2 = math.log(2.0)
            # log(denom) expressed as log2(denom) * ln(2).
            lse = max_x + utils.log2f(denom) * ln_2
            loss_val = lse - target_logit
            mLoss[row] = loss_val.to(mLoss.element_type)
            if cutlass.const_expr(mLSE is not None):
                mLSE[row] = lse
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _cross_entropy(
    x: torch.Tensor,
    target: torch.Tensor,
    return_lse: bool = False,
) -> torch.Tensor:
    """Cross entropy forward pass.

    Args:
        x: Input logits tensor of shape (M, N); float16, bfloat16 or float32,
            on a CUDA device.
        target: Target class indices tensor of shape (M,); int32 or int64,
            on a CUDA device.
        return_lse: When True, also return the per-row log-sum-exp.

    Returns:
        The float32 loss tensor of shape (M,). When ``return_lse`` is True, a
        ``(loss, lse)`` tuple instead, where ``lse`` is a float32 tensor of
        shape (M,).

    Raises:
        AssertionError: If a shape, device or dtype precondition is violated.
    """
    assert x.dim() == 2, "Input must be 2D"
    assert target.dim() == 1, "Target must be 1D"
    assert x.shape[0] == target.shape[0], "Batch dimensions must match"
    assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
    assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
    M, N = x.shape
    device = x.device
    loss = torch.empty(M, device=device, dtype=torch.float32)
    lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
    dtype = torch2cute_dtype_map[x.dtype]

    # Wrap the torch tensors for CuTe; only the row count (mode 0) is dynamic.
    x_tensor = from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
        mode=0, stride_order=(0, 1)
    )
    loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
    lse_tensor = (
        from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
        if lse is not None
        else None
    )
    target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # The kernel is compiled against the concrete tensors passed to
    # cute.compile, so everything that changes their types must be part of the
    # key. target.dtype is included because both int32 and int64 targets are
    # accepted above; without it a kernel compiled for one index width would
    # be reused for the other.
    compile_key = (dtype, N, target.dtype, lse is not None)
    if compile_key not in _cross_entropy.compile_cache:
        cross_entropy_op = CrossEntropy(dtype, N)
        _cross_entropy.compile_cache[compile_key] = cute.compile(
            cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
        )
    _cross_entropy.compile_cache[compile_key](
        x_tensor, target_tensor, loss_tensor, lse_tensor, stream
    )
    return loss if not return_lse else (loss, lse)


# Cache of compiled kernels, keyed by (logits dtype, N, target dtype, has_lse).
_cross_entropy.compile_cache = {}
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class CrossEntropyBackward:
    """Backward cross-entropy kernel: gradient of the loss w.r.t. the logits.

    For each element the kernel writes
    ``dx = (exp(x - lse) - [col == target]) * dloss`` using the log-sum-exp
    saved by the forward pass.
    """

    def __init__(self, dtype: Type[cutlass.Numeric], N: int):
        """Store the element type and row length.

        Args:
            dtype: Element type of the logits and gradient tensors.
            N: Number of columns (classes) per row.
        """
        self.dtype = dtype
        self.N = N
        # Elements per 128-bit vector access.
        self.vecsize = 128 // dtype.width

    def _calculate_threads_per_row(self):
        """Return how many threads cooperate on one row (8 to 256, growing with N)."""
        N = self.N
        return (
            8
            if N <= 64
            else (
                16
                if N <= 128
                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
            )
        )

    def _get_tv_layout(self):
        """Build the tile shape and thread-value layout used by the kernel.

        Returns:
            ``(tiler_mn, tv_layout)``: the (rows, cols) tile handled by one
            block and the layout mapping (thread, value) to tile positions.
            The column extent is derived from ``min(N, 16384)``, so wider rows
            are covered by several blocks along grid dimension y.
        """
        N = self.N
        vecsize = self.vecsize
        num_threads = 128 if N <= 16384 else 256
        threads_per_row = self._calculate_threads_per_row()
        # Despite its name, this is used as the row extent of the tile
        # (tiler_mn[0]).
        cols_per_block = num_threads // threads_per_row
        num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
        tv_layout = cute.make_layout(
            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
            stride=(
                (vecsize * cols_per_block, 1),
                (cols_per_block, cols_per_block * vecsize * threads_per_row),
            ),
        )
        return tiler_mn, tv_layout

    # Host-side entry point: derives the tiling and launches `kernel` on `stream`.
    #   mX:      logits, element type must equal self.dtype
    #   mTarget: class index per row
    #   mDLoss:  upstream gradient per row
    #   mdX:     output gradient, element type must equal self.dtype
    #   mLSE:    log-sum-exp per row
    @cute.jit
    def __call__(
        self,
        mX: cute.Tensor,
        mTarget: cute.Tensor,
        mDLoss: cute.Tensor,
        mdX: cute.Tensor,
        mLSE: cute.Tensor,
        stream: cuda.CUstream,
    ):
        assert mX.element_type == self.dtype
        assert mdX.element_type == self.dtype

        tiler_mn, tv_layout = self._get_tv_layout()
        num_threads = cute.size(tv_layout, mode=[0])

        # Append a stride-0 mode of extent N to each per-row tensor, so every
        # column of a row maps to the same element.
        mDLoss = cute.make_tensor(
            mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
        )
        mTarget = cute.make_tensor(
            mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
        )
        mLSE = cute.make_tensor(
            mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
        )

        # Shared memory holds exactly one tile of logits.
        smem_size = cute.size_in_bytes(
            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
        )

        self.kernel(
            mX,
            mTarget,
            mDLoss,
            mdX,
            mLSE,
            mX.shape,
            tv_layout,
            tiler_mn,
        ).launch(
            grid=[
                cute.ceil_div(mX.shape[0], tiler_mn[0]),
                cute.ceil_div(mX.shape[1], tiler_mn[1]),
                1,
            ],
            block=[num_threads, 1, 1],
            smem=smem_size,
            stream=stream,
        )

    # Device kernel. Block (bidx, bidy) handles row tile bidx and column tile bidy.
    @cute.kernel
    def kernel(
        self,
        mX: cute.Tensor,  # (M, N)
        mTarget: cute.Tensor,  # (M,)
        mDLoss: cute.Tensor,  # (M,)
        mdX: cute.Tensor,  # (M, N)
        mLSE: cute.Tensor,  # (M,)
        shape: cute.Shape,
        tv_layout: cute.Layout,
        tiler_mn: cute.Shape,
    ):
        tidx, _, _ = cute.arch.thread_idx()
        bidx, bidy, _ = cute.arch.block_idx()

        smem = cutlass.utils.SmemAllocator()
        sX = smem.allocate_tensor(
            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
        )

        # Identity tensor: indexing it yields (row, col) coordinates.
        idX = cute.make_identity_tensor(shape)
        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
        mX, mdX = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX)]
        gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
        cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))

        copy_atom_load_X = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
        )
        copy_atom_load_X_async = cute.make_copy_atom(
            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
        )
        copy_atom_store_O = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
        )

        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
        thr_copy_X_async = cute.make_tiled_copy(
            copy_atom_load_X_async, tv_layout, tiler_mn
        ).get_slice(tidx)
        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)

        #### Thread View
        tXgX = thr_copy_X_async.partition_S(gX)
        # NOTE(review): the forward kernel partitions its shared-memory
        # destination with partition_D; here partition_S is used. Confirm
        # this is intentional.
        tXsX = thr_copy_X_async.partition_S(sX)

        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
        tXcFull = thr_copy_X.partition_S(cX)  # improve

        tXgO = thr_copy_O.partition_D(gdX)

        # allocate fragments for gmem->rmem
        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]

        # Column predicate is only built when N is not a multiple of the tile width.
        is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
        row = tXcX[0][0]

        tXpX = (
            utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
            if not is_even_N
            else None
        )

        if row < shape[0]:
            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
        cute.arch.cp_async_commit_group()
        cute.arch.cp_async_wait_group(0)
        # Out-of-bounds columns become -inf so their probability below is 0.
        if cutlass.const_expr(not is_even_N):
            utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)

        cute.autovec_copy(tXsX, tXrX)
        x = tXrX.load().to(cute.Float32)

        # Per-row scalars; left at zero for rows past the end of the tensor,
        # whose results are never stored.
        label = cute.Int32.zero
        dloss = cute.Float32.zero
        lse = cute.Float32.zero
        if row < shape[0]:
            label = cute.Int32(mTarget[row])
            dloss = cute.Float32(mDLoss[row])
            lse = cute.Float32(mLSE[row])

        # softmax probability: exp(x - lse), evaluated as exp2((x - lse) * log2(e)).
        log2_e = math.log2(math.e)
        probs = utils.exp2f((x - lse) * log2_e)
        prob_shifted = probs - 1.0

        # mask[i] is True where this thread's i-th element sits in the target column.
        mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
        for i in cutlass.range_constexpr(cute.size(tXcFull)):
            mask[i] = tXcFull[i][1] == label

        mask = mask.load()
        # Gradient is (p - 1) at the target column and p elsewhere, scaled by dloss.
        grad = cute.where(mask, prob_shifted, probs)
        grad = grad * dloss

        tXrO.store(grad.to(tXrO.element_type))
        tOpO = (
            utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
        )
        if row < shape[0]:
            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def _cross_entropy_backward(
    x: torch.Tensor,
    target: torch.Tensor,
    dloss: torch.Tensor,
    lse: torch.Tensor,
    inplace_backward: bool = False,
) -> torch.Tensor:
    """Cross entropy backward pass.

    Args:
        x: Input logits tensor of shape (M, N); float16, bfloat16 or float32,
            on a CUDA device.
        target: Target class indices tensor of shape (M,); int32 or int64.
        dloss: Upstream gradients tensor of shape (M,).
        lse: Log-sum-exp values tensor of shape (M,).
        inplace_backward: When True, the gradient is written into ``x``
            itself (overwriting the logits) and ``x`` is returned; otherwise
            a new tensor is allocated.

    Returns:
        Input gradients tensor of shape (M, N), with the dtype of ``x``.

    Raises:
        AssertionError: If a shape, device or dtype precondition is violated.
    """
    assert x.dim() == 2, "Input must be 2D"
    assert target.dim() == 1, "Target must be 1D"
    assert dloss.dim() == 1, "dloss must be 1D"
    assert lse.dim() == 1, "lse must be 1D"
    assert x.shape[0] == target.shape[0], "Batch dimensions must match"
    assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
    assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
    assert (
        x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
    ), "Tensors must be on CUDA device"
    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
    assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"

    N = x.shape[1]
    dx = torch.empty_like(x) if not inplace_backward else x
    dtype = torch2cute_dtype_map[x.dtype]

    # Autograd may hand over a non-compact gradient (e.g. a stride-0 expanded
    # view after .sum()/.mean()). The tensor is described to CuTe as compact
    # below, so materialise it first; this is a no-op for contiguous input.
    dloss = dloss.contiguous()

    def _as_cute_2d(tensor):
        # (M, N) tensor for CuTe; only the row count (mode 0) is dynamic.
        return from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
            mode=0, stride_order=(0, 1)
        )

    x_tensor = _as_cute_2d(x)
    dx_tensor = _as_cute_2d(dx)
    dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
    lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
    target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
        mode=0
    )
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # The kernel is compiled against the concrete tensors passed to
    # cute.compile. target.dtype is part of the key because both int32 and
    # int64 targets are accepted above; without it a kernel compiled for one
    # index width would be reused for the other.
    compile_key = (dtype, N, target.dtype)
    if compile_key not in _cross_entropy_backward.compile_cache:
        cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
        _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
            cross_entropy_backward_op,
            x_tensor,
            target_tensor,
            dloss_tensor,
            dx_tensor,
            lse_tensor,
            stream,
        )
    _cross_entropy_backward.compile_cache[compile_key](
        x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
    )
    return dx


# Cache of compiled kernels, keyed by (logits dtype, N, target dtype).
_cross_entropy_backward.compile_cache = {}
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
class CrossEntropyFunction(torch.autograd.Function):
    """Autograd bridge between the forward and backward cross-entropy kernels."""

    @staticmethod
    def forward(ctx, x, target, inplace_backward=False):
        """Compute the per-row loss and stash what the backward pass needs."""
        per_row_loss, row_lse = _cross_entropy(x, target, return_lse=True)
        ctx.inplace_backward = inplace_backward
        ctx.save_for_backward(x, target, row_lse)
        return per_row_loss

    @staticmethod
    def backward(ctx, dloss):
        """Return the gradient for ``x``; the other two inputs get no gradient."""
        logits, labels, row_lse = ctx.saved_tensors
        grad_logits = _cross_entropy_backward(
            logits,
            labels,
            dloss,
            row_lse,
            inplace_backward=ctx.inplace_backward,
        )
        return grad_logits, None, None
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def cross_entropy(
    x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
) -> torch.Tensor:
    """Differentiable cross entropy loss, one value per row.

    Args:
        x: Input logits tensor of shape (M, N)
        target: Target class indices tensor of shape (M,)
        inplace_backward: Forwarded to the backward pass; when True the
            gradient is written into ``x`` instead of a new tensor.

    Returns:
        Cross entropy loss tensor of shape (M,)
    """
    loss = CrossEntropyFunction.apply(x, target, inplace_backward)
    return loss
|