quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
- quack_kernels-0.2.4.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/cross_entropy.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
3
|
import math
|
|
4
|
+
from functools import partial
|
|
4
5
|
from typing import Optional, Type, Literal
|
|
5
6
|
|
|
6
7
|
import torch
|
|
@@ -10,10 +11,12 @@ import cuda.bindings.driver as cuda
|
|
|
10
11
|
|
|
11
12
|
import cutlass
|
|
12
13
|
import cutlass.cute as cute
|
|
13
|
-
from cutlass import Int32, Float32, Boolean, const_expr
|
|
14
|
-
from cutlass.cute.runtime import from_dlpack
|
|
14
|
+
from cutlass import Int32, Int64, Float32, Boolean, const_expr
|
|
15
15
|
|
|
16
16
|
import quack.utils as utils
|
|
17
|
+
import quack.copy_utils as copy_utils
|
|
18
|
+
import quack.layout_utils as layout_utils
|
|
19
|
+
from quack.compile_utils import make_fake_tensor as fake_tensor
|
|
17
20
|
from quack.reduce import row_reduce, online_softmax_reduce
|
|
18
21
|
from quack.reduction_base import ReductionBase
|
|
19
22
|
from quack.cute_dsl_utils import torch2cute_dtype_map
|
|
@@ -26,46 +29,29 @@ class CrossEntropy(ReductionBase):
|
|
|
26
29
|
dtype,
|
|
27
30
|
N,
|
|
28
31
|
stage=2 if not online_softmax else 1,
|
|
29
|
-
reduction_dtype=Float32 if not online_softmax else
|
|
32
|
+
reduction_dtype=Float32 if not online_softmax else Int64,
|
|
30
33
|
)
|
|
31
34
|
self.online_softmax = online_softmax
|
|
32
35
|
self.reload_from = None if N <= 16384 or online_softmax else "smem"
|
|
33
36
|
|
|
34
|
-
def
|
|
37
|
+
def _threads_per_row(self):
|
|
35
38
|
N = self.N
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
16
|
|
41
|
-
if N <= 128
|
|
42
|
-
else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
|
|
43
|
-
)
|
|
44
|
-
)
|
|
39
|
+
for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
|
|
40
|
+
if N <= limit:
|
|
41
|
+
return threads
|
|
42
|
+
return 256
|
|
45
43
|
|
|
46
44
|
def _set_cluster_n(self):
|
|
47
45
|
N = self.N
|
|
48
46
|
if const_expr(self.dtype.width == 16):
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
)
|
|
58
|
-
else: # fp32
|
|
59
|
-
cluster_n = (
|
|
60
|
-
1
|
|
61
|
-
if N <= 16 * 1024
|
|
62
|
-
else (
|
|
63
|
-
2
|
|
64
|
-
if N <= 64 * 1024
|
|
65
|
-
else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
|
|
66
|
-
)
|
|
67
|
-
)
|
|
68
|
-
self.cluster_n = cluster_n
|
|
47
|
+
thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
|
|
48
|
+
else:
|
|
49
|
+
thresholds = [(16 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
|
|
50
|
+
for limit, cluster in thresholds:
|
|
51
|
+
if N <= limit:
|
|
52
|
+
self.cluster_n = cluster
|
|
53
|
+
return
|
|
54
|
+
self.cluster_n = 16
|
|
69
55
|
|
|
70
56
|
@cute.jit
|
|
71
57
|
def __call__(
|
|
@@ -82,19 +68,30 @@ class CrossEntropy(ReductionBase):
|
|
|
82
68
|
assert mX.element_type == self.dtype
|
|
83
69
|
if const_expr(mTargetLogit is None):
|
|
84
70
|
mTargetLogit = mX
|
|
71
|
+
if const_expr(mdX is not None):
|
|
72
|
+
assert mdX.element_type == self.dtype
|
|
85
73
|
self._set_cluster_n()
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
74
|
+
largest_dtype_width = const_expr(mX.element_type.width)
|
|
75
|
+
if const_expr(mdX is not None):
|
|
76
|
+
largest_dtype_width = const_expr(max(largest_dtype_width, mdX.element_type.width))
|
|
77
|
+
vecsize = math.gcd(self.N, 128 // largest_dtype_width)
|
|
78
|
+
tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
|
|
79
|
+
num_threads = tiled_copy.size
|
|
91
80
|
self.kernel(
|
|
92
|
-
mX,
|
|
81
|
+
mX,
|
|
82
|
+
mTarget,
|
|
83
|
+
mTargetLogit,
|
|
84
|
+
mLoss,
|
|
85
|
+
mLSE,
|
|
86
|
+
mdX,
|
|
87
|
+
ignore_index,
|
|
88
|
+
tiler_mn,
|
|
89
|
+
tiled_copy,
|
|
90
|
+
threads_per_row,
|
|
93
91
|
).launch(
|
|
94
92
|
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
95
93
|
block=[num_threads, 1, 1],
|
|
96
|
-
cluster=
|
|
97
|
-
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
94
|
+
cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
|
|
98
95
|
stream=stream,
|
|
99
96
|
)
|
|
100
97
|
|
|
@@ -108,47 +105,40 @@ class CrossEntropy(ReductionBase):
|
|
|
108
105
|
mLSE: Optional[cute.Tensor], # (M,)
|
|
109
106
|
mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
|
|
110
107
|
ignore_index: Int32, # Index to ignore in loss computation
|
|
111
|
-
tv_layout: cute.Layout,
|
|
112
108
|
tiler_mn: cute.Shape,
|
|
109
|
+
tiled_copy: cute.TiledCopy,
|
|
110
|
+
threads_per_row: cutlass.Constexpr[int],
|
|
113
111
|
):
|
|
114
112
|
tidx, _, _ = cute.arch.thread_idx()
|
|
115
113
|
bidx, _, _ = cute.arch.block_idx()
|
|
116
|
-
if const_expr(self.cluster_n
|
|
117
|
-
|
|
118
|
-
else:
|
|
119
|
-
cluster_y = const_expr(0)
|
|
114
|
+
cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
|
|
115
|
+
tv_layout = tiled_copy.layout_tv_tiled
|
|
120
116
|
|
|
121
|
-
shape
|
|
117
|
+
shape = mX.shape
|
|
122
118
|
idX = cute.make_identity_tensor(shape)
|
|
123
119
|
# slice for CTAs
|
|
124
|
-
|
|
125
|
-
mX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
|
|
126
|
-
gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
|
|
127
|
-
cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
|
|
120
|
+
gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
|
|
128
121
|
|
|
129
122
|
smem = cutlass.utils.SmemAllocator()
|
|
130
123
|
sX = smem.allocate_tensor(
|
|
131
|
-
mX.element_type,
|
|
132
|
-
cute.make_ordered_layout(tiler_mn, order=(1, 0)),
|
|
133
|
-
byte_alignment=16,
|
|
124
|
+
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
|
|
134
125
|
)
|
|
135
126
|
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
|
136
127
|
|
|
137
|
-
|
|
138
|
-
num_copy_elems_X = tv_layout.shape[1][0]
|
|
139
|
-
num_copy_bits_X = mX.element_type.width * num_copy_elems_X
|
|
140
|
-
copy_atom_load_X = cute.make_copy_atom(
|
|
141
|
-
cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
|
|
142
|
-
)
|
|
143
|
-
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
|
|
128
|
+
thr_copy = tiled_copy.get_slice(tidx)
|
|
144
129
|
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
130
|
+
tXgX = thr_copy.partition_S(gX)
|
|
131
|
+
tXsX = thr_copy.partition_D(sX)
|
|
132
|
+
tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
|
|
149
133
|
tXrX = cute.make_fragment_like(tXgX)
|
|
150
134
|
|
|
151
|
-
|
|
135
|
+
is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
|
|
136
|
+
tXpX = (
|
|
137
|
+
None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
|
|
138
|
+
)
|
|
139
|
+
copy = partial(copy_utils.copy, pred=tXpX)
|
|
140
|
+
|
|
141
|
+
num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
|
|
152
142
|
self._initialize_cluster(tidx, mbar_ptr, num_warps)
|
|
153
143
|
|
|
154
144
|
row = tXcX[0][0]
|
|
@@ -156,14 +146,8 @@ class CrossEntropy(ReductionBase):
|
|
|
156
146
|
if row < shape[0]:
|
|
157
147
|
target = Int32(mTarget[row])
|
|
158
148
|
|
|
159
|
-
is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
|
|
160
|
-
tXpX = (
|
|
161
|
-
utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
162
|
-
if const_expr(not is_even_N)
|
|
163
|
-
else None
|
|
164
|
-
)
|
|
165
149
|
if row < shape[0]:
|
|
166
|
-
|
|
150
|
+
copy(tXgX, tXsX, is_async=True)
|
|
167
151
|
cute.arch.cp_async_commit_group()
|
|
168
152
|
cute.arch.cp_async_wait_group(0)
|
|
169
153
|
# Fill OOB values with -inf
|
|
@@ -177,14 +161,11 @@ class CrossEntropy(ReductionBase):
|
|
|
177
161
|
if row < shape[0] and tXcX[0][1] == 0 and not should_ignore:
|
|
178
162
|
# Only load target logit if not ignoring this index
|
|
179
163
|
if const_expr(cute.rank(mTargetLogit.shape) == 2):
|
|
180
|
-
|
|
181
|
-
mTargetLogit_off = utils.domain_offset_i64((row, 0), mTargetLogit)
|
|
182
|
-
target_logit = Float32(mTargetLogit_off[0, target])
|
|
164
|
+
target_logit = Float32(mTargetLogit[row, target])
|
|
183
165
|
else:
|
|
184
166
|
assert cute.rank(mTargetLogit.shape) == 1
|
|
185
167
|
target_logit = Float32(mTargetLogit[row])
|
|
186
168
|
|
|
187
|
-
threads_per_row = tv_layout.shape[0][0]
|
|
188
169
|
if const_expr(not self.online_softmax):
|
|
189
170
|
max_x = row_reduce(
|
|
190
171
|
x,
|
|
@@ -237,21 +218,16 @@ class CrossEntropy(ReductionBase):
|
|
|
237
218
|
# Compute probabilities: exp(x) / sum(exp(x))
|
|
238
219
|
# If ignored, gradient should be zero
|
|
239
220
|
denom_inv = (
|
|
240
|
-
1.0 / denom
|
|
221
|
+
# 1.0 / denom
|
|
222
|
+
cute.arch.rcp_approx(denom)
|
|
241
223
|
if not (denom == 0.0 or denom != denom or should_ignore)
|
|
242
224
|
else Float32.zero
|
|
243
225
|
)
|
|
244
226
|
probs = exp_x * denom_inv
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
# Setup copy atom for storing gradient
|
|
248
|
-
copy_atom_store = cute.make_copy_atom(
|
|
249
|
-
cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_X
|
|
250
|
-
)
|
|
251
|
-
thr_copy_dX = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
|
|
252
|
-
tXgdX = thr_copy_dX.partition_D(gdX)
|
|
227
|
+
gdX = cute.local_tile(mdX, tiler_mn, (bidx, cluster_y))
|
|
228
|
+
tXgdX = thr_copy.partition_D(gdX)
|
|
253
229
|
tXrdX = cute.make_fragment_like(tXgdX)
|
|
254
|
-
tXcFull =
|
|
230
|
+
tXcFull = thr_copy.partition_S(cX)
|
|
255
231
|
# Compute gradient: probs for all classes, (probs - 1) for target class
|
|
256
232
|
# If ignored, gradient is already zero
|
|
257
233
|
tXrdX_f32 = cute.make_fragment_like(tXrX, Float32)
|
|
@@ -260,13 +236,8 @@ class CrossEntropy(ReductionBase):
|
|
|
260
236
|
for i in cutlass.range(cute.size(tXrX), unroll_full=True):
|
|
261
237
|
tXrdX_f32[i] = tXrdX_f32[i] if tXcFull[i][1] != target else tXrdX_f32[i] - 1.0
|
|
262
238
|
tXrdX.store(tXrdX_f32.load().to(tXrdX.element_type))
|
|
263
|
-
tXpdX = (
|
|
264
|
-
utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
|
|
265
|
-
if not is_even_N
|
|
266
|
-
else None
|
|
267
|
-
)
|
|
268
239
|
if row < shape[0]:
|
|
269
|
-
|
|
240
|
+
copy(tXrdX, tXgdX)
|
|
270
241
|
|
|
271
242
|
|
|
272
243
|
@torch.library.custom_op("quack::cross_entropy_fwd_out", mutates_args={"loss", "lse", "dx"})
|
|
@@ -296,77 +267,61 @@ def cross_entropy_fwd_out(
|
|
|
296
267
|
"""
|
|
297
268
|
assert x.dim() == 2, "Input must be 2D"
|
|
298
269
|
assert target.dim() == 1, "Target must be 1D"
|
|
299
|
-
assert x.shape[0] == target.shape[0], "Batch dimensions must match"
|
|
300
270
|
assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
|
|
301
271
|
assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
|
|
302
272
|
assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
|
|
303
273
|
if target_logit is not None:
|
|
304
|
-
assert target_logit.shape[0] == x.shape[0]
|
|
305
274
|
assert target_logit.is_cuda, "Target logits must be on CUDA device"
|
|
306
275
|
assert target_logit.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
|
307
276
|
if dx is not None:
|
|
308
|
-
assert dx.shape == x.shape, "dx must have same shape as x"
|
|
309
277
|
assert dx.is_cuda, "dx must be on CUDA device"
|
|
310
|
-
assert dx.dtype == x.dtype, "dx must have same dtype as x"
|
|
311
278
|
N = x.size(1)
|
|
312
279
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
)
|
|
280
|
+
target_dtype = torch2cute_dtype_map[target.dtype]
|
|
281
|
+
target_logit_dtype = (
|
|
282
|
+
torch2cute_dtype_map[target_logit.dtype] if target_logit is not None else None
|
|
317
283
|
)
|
|
318
|
-
x_tensor = convert_from_dlpack(x)
|
|
319
|
-
loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_layout_dynamic()
|
|
320
|
-
lse_tensor = (
|
|
321
|
-
from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
|
|
322
|
-
if lse is not None
|
|
323
|
-
else None
|
|
324
|
-
)
|
|
325
|
-
target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
|
|
326
|
-
target_logit_tensor = (
|
|
327
|
-
from_dlpack(target_logit.detach(), assumed_align=4).mark_layout_dynamic(
|
|
328
|
-
leading_dim=target_logit.ndim - 1
|
|
329
|
-
)
|
|
330
|
-
if target_logit is not None
|
|
331
|
-
else None
|
|
332
|
-
)
|
|
333
|
-
dx_tensor = convert_from_dlpack(dx) if dx is not None else None
|
|
334
|
-
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
335
|
-
|
|
336
284
|
compile_key = (
|
|
337
285
|
dtype,
|
|
286
|
+
target_dtype,
|
|
287
|
+
target_logit_dtype,
|
|
338
288
|
N,
|
|
339
|
-
|
|
340
|
-
lse.dtype if lse is not None else None,
|
|
289
|
+
lse is not None,
|
|
341
290
|
dx is not None,
|
|
342
|
-
loss.stride(),
|
|
343
|
-
lse.stride() if lse is not None else None,
|
|
344
|
-
target.stride(),
|
|
345
|
-
target_logit.stride(-1) if target_logit is not None else None,
|
|
346
291
|
)
|
|
347
292
|
if compile_key not in cross_entropy_fwd_out.compile_cache:
|
|
293
|
+
batch_sym = cute.sym_int()
|
|
294
|
+
div = math.gcd(128 // dtype.width, N)
|
|
295
|
+
x_cute = fake_tensor(dtype, (batch_sym, N), div)
|
|
296
|
+
dx_cute = fake_tensor(dtype, (batch_sym, N), div) if dx is not None else None
|
|
297
|
+
target_cute = fake_tensor(target_dtype, (batch_sym,))
|
|
298
|
+
if target_logit is not None:
|
|
299
|
+
if target_logit.ndim == 2:
|
|
300
|
+
target_logit_cute = fake_tensor(
|
|
301
|
+
target_logit_dtype, (batch_sym, cute.sym_int()), div
|
|
302
|
+
)
|
|
303
|
+
else:
|
|
304
|
+
target_logit_cute = fake_tensor(target_logit_dtype, (batch_sym,))
|
|
305
|
+
else:
|
|
306
|
+
target_logit_cute = None
|
|
307
|
+
loss_cute = fake_tensor(Float32, (batch_sym,))
|
|
308
|
+
lse_cute = fake_tensor(Float32, (batch_sym,)) if lse is not None else None
|
|
348
309
|
# If there's dx, it's faster to not use online softmax since we want the exp(x - max)
|
|
349
310
|
cross_entropy_op = CrossEntropy(dtype, N, online_softmax=dx is None)
|
|
350
311
|
cross_entropy_fwd_out.compile_cache[compile_key] = cute.compile(
|
|
351
312
|
cross_entropy_op,
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
Int32(
|
|
359
|
-
|
|
313
|
+
x_cute,
|
|
314
|
+
target_cute,
|
|
315
|
+
target_logit_cute,
|
|
316
|
+
loss_cute,
|
|
317
|
+
lse_cute,
|
|
318
|
+
dx_cute,
|
|
319
|
+
Int32(0), # ignore_index, just for compilation
|
|
320
|
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
|
321
|
+
options="--enable-tvm-ffi",
|
|
360
322
|
)
|
|
361
323
|
cross_entropy_fwd_out.compile_cache[compile_key](
|
|
362
|
-
|
|
363
|
-
target_tensor,
|
|
364
|
-
target_logit_tensor,
|
|
365
|
-
loss_tensor,
|
|
366
|
-
lse_tensor,
|
|
367
|
-
dx_tensor,
|
|
368
|
-
Int32(ignore_index),
|
|
369
|
-
stream,
|
|
324
|
+
x, target, target_logit, loss, lse, dx, Int32(ignore_index)
|
|
370
325
|
)
|
|
371
326
|
|
|
372
327
|
|
|
@@ -404,35 +359,25 @@ class CrossEntropyBackward:
|
|
|
404
359
|
self.N = N
|
|
405
360
|
self.vecsize = 128 // dtype.width
|
|
406
361
|
|
|
407
|
-
def
|
|
362
|
+
def _threads_per_row(self):
|
|
408
363
|
N = min(self.N, 16384) # We split by blocks of 16k
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
16
|
|
414
|
-
if N <= 128
|
|
415
|
-
else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
|
|
416
|
-
)
|
|
417
|
-
)
|
|
364
|
+
for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
|
|
365
|
+
if N <= limit:
|
|
366
|
+
return threads
|
|
367
|
+
return 256
|
|
418
368
|
|
|
419
|
-
def
|
|
420
|
-
vecsize = num_copy_bits // self.dtype.width
|
|
369
|
+
def _get_tiled_copy(self, vecsize: int):
|
|
421
370
|
assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
|
|
422
371
|
N = min(self.N, 16384)
|
|
423
372
|
num_threads = 128 if N <= 16384 else 256
|
|
424
|
-
threads_per_row = self.
|
|
373
|
+
threads_per_row = self._threads_per_row()
|
|
425
374
|
cols_per_block = num_threads // threads_per_row
|
|
426
375
|
num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
|
|
427
376
|
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
stride=(
|
|
431
|
-
(vecsize * cols_per_block, 1),
|
|
432
|
-
(cols_per_block, cols_per_block * vecsize * threads_per_row),
|
|
433
|
-
),
|
|
377
|
+
tiled_copy = copy_utils.tiled_copy_2d(
|
|
378
|
+
self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize
|
|
434
379
|
)
|
|
435
|
-
return tiler_mn,
|
|
380
|
+
return tiled_copy, tiler_mn, threads_per_row
|
|
436
381
|
|
|
437
382
|
@cute.jit
|
|
438
383
|
def __call__(
|
|
@@ -448,21 +393,24 @@ class CrossEntropyBackward:
|
|
|
448
393
|
assert mX.element_type == self.dtype
|
|
449
394
|
assert mdX.element_type == self.dtype
|
|
450
395
|
# e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
|
|
451
|
-
|
|
452
|
-
tiler_mn,
|
|
453
|
-
num_threads =
|
|
396
|
+
vecsize = math.gcd(self.N, 128 // self.dtype.width)
|
|
397
|
+
tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
|
|
398
|
+
num_threads = tiled_copy.size
|
|
454
399
|
# (M,) -> (M, N) with stride 0 in the N dimension
|
|
455
400
|
mDLoss, mTarget, mLSE = [
|
|
456
|
-
|
|
457
|
-
X.iterator, cute.append(X.layout, cute.make_layout((self.N,), stride=(0,)))
|
|
458
|
-
)
|
|
459
|
-
for X in (mDLoss, mTarget, mLSE)
|
|
401
|
+
layout_utils.expand(X, dim=1, size=self.N) for X in (mDLoss, mTarget, mLSE)
|
|
460
402
|
]
|
|
461
|
-
smem_size = cute.size_in_bytes(
|
|
462
|
-
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
|
|
463
|
-
)
|
|
464
403
|
self.kernel(
|
|
465
|
-
mX,
|
|
404
|
+
mX,
|
|
405
|
+
mTarget,
|
|
406
|
+
mDLoss,
|
|
407
|
+
mdX,
|
|
408
|
+
mLSE,
|
|
409
|
+
ignore_index,
|
|
410
|
+
mX.shape,
|
|
411
|
+
tiler_mn,
|
|
412
|
+
tiled_copy,
|
|
413
|
+
threads_per_row,
|
|
466
414
|
).launch(
|
|
467
415
|
grid=[
|
|
468
416
|
cute.ceil_div(mX.shape[0], tiler_mn[0]),
|
|
@@ -470,7 +418,6 @@ class CrossEntropyBackward:
|
|
|
470
418
|
1,
|
|
471
419
|
],
|
|
472
420
|
block=[num_threads, 1, 1],
|
|
473
|
-
smem=smem_size,
|
|
474
421
|
stream=stream,
|
|
475
422
|
)
|
|
476
423
|
|
|
@@ -484,52 +431,39 @@ class CrossEntropyBackward:
|
|
|
484
431
|
mLSE: cute.Tensor, # (M,)
|
|
485
432
|
ignore_index: Int32, # Index to ignore in gradient computation
|
|
486
433
|
shape: cute.Shape,
|
|
487
|
-
tv_layout: cute.Layout,
|
|
488
434
|
tiler_mn: cute.Shape,
|
|
435
|
+
tiled_copy: cute.TiledCopy,
|
|
436
|
+
threads_per_row: cutlass.Constexpr[int],
|
|
489
437
|
):
|
|
490
438
|
tidx, _, _ = cute.arch.thread_idx()
|
|
491
439
|
bidx, bidy, _ = cute.arch.block_idx()
|
|
492
440
|
|
|
493
441
|
smem = cutlass.utils.SmemAllocator()
|
|
494
442
|
sX = smem.allocate_tensor(
|
|
495
|
-
mX.element_type,
|
|
496
|
-
cute.make_ordered_layout(tiler_mn, order=(1, 0)),
|
|
497
|
-
byte_alignment=16,
|
|
443
|
+
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
|
|
498
444
|
)
|
|
499
445
|
|
|
500
446
|
idX = cute.make_identity_tensor(shape)
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
)
|
|
511
|
-
copy_atom_store_dX = cute.make_copy_atom(
|
|
512
|
-
cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=num_copy_bits_X
|
|
513
|
-
)
|
|
514
|
-
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
|
|
515
|
-
thr_copy_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
|
|
516
|
-
|
|
517
|
-
#### Partition to get thread view
|
|
518
|
-
tXgX = thr_copy_X.partition_S(gX)
|
|
519
|
-
tXsX = thr_copy_X.partition_S(sX)
|
|
520
|
-
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
521
|
-
tXcFull = thr_copy_X.partition_S(cX)
|
|
522
|
-
tXgdX = thr_copy_dX.partition_D(gdX)
|
|
523
|
-
# allocate fragments for gmem->rmem
|
|
447
|
+
gX, gdX, cX = [cute.local_tile(mT, tiler_mn, (bidx, bidy)) for mT in (mX, mdX, idX)]
|
|
448
|
+
|
|
449
|
+
thr_copy = tiled_copy.get_slice(tidx)
|
|
450
|
+
|
|
451
|
+
tXgX = thr_copy.partition_S(gX)
|
|
452
|
+
tXsX = thr_copy.partition_D(sX)
|
|
453
|
+
tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
|
|
454
|
+
tXcFull = thr_copy.partition_S(cX)
|
|
455
|
+
tXgdX = thr_copy.partition_D(gdX)
|
|
524
456
|
tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)]
|
|
525
457
|
|
|
526
458
|
is_even_N = const_expr(shape[1] % tiler_mn[1] == 0)
|
|
527
|
-
row = tXcX[0][0]
|
|
528
459
|
tXpX = (
|
|
529
|
-
|
|
460
|
+
None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
|
|
530
461
|
)
|
|
462
|
+
copy = partial(copy_utils.copy, pred=tXpX)
|
|
463
|
+
|
|
464
|
+
row = tXcX[0][0]
|
|
531
465
|
if row < shape[0]:
|
|
532
|
-
|
|
466
|
+
copy(tXgX, tXsX, is_async=True)
|
|
533
467
|
cute.arch.cp_async_commit_group()
|
|
534
468
|
cute.arch.cp_async_wait_group(0)
|
|
535
469
|
if const_expr(not is_even_N):
|
|
@@ -544,26 +478,22 @@ class CrossEntropyBackward:
|
|
|
544
478
|
target = Int32(mTarget[row])
|
|
545
479
|
should_ignore = Boolean(target == ignore_index)
|
|
546
480
|
# Set dloss to 0 if this index should be ignored
|
|
547
|
-
|
|
481
|
+
if not should_ignore:
|
|
482
|
+
dloss = Float32(mDLoss[row])
|
|
548
483
|
lse = Float32(mLSE[row])
|
|
549
484
|
|
|
550
485
|
log2_e = math.log2(math.e)
|
|
551
486
|
probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
|
|
552
487
|
prob_shifted = probs - 1.0
|
|
553
|
-
mask = cute.make_fragment_like(tXrX,
|
|
488
|
+
mask = cute.make_fragment_like(tXrX, Boolean)
|
|
554
489
|
for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
|
|
555
490
|
mask[i] = tXcFull[i][1] == target
|
|
556
491
|
grad = cute.where(mask.load(), prob_shifted, probs)
|
|
557
492
|
grad = grad * dloss
|
|
558
493
|
|
|
559
494
|
tXrdX.store(grad.to(tXrdX.element_type))
|
|
560
|
-
tXpdX = (
|
|
561
|
-
utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
|
|
562
|
-
if not is_even_N
|
|
563
|
-
else None
|
|
564
|
-
)
|
|
565
495
|
if row < shape[0]:
|
|
566
|
-
|
|
496
|
+
copy(tXrdX, tXgdX)
|
|
567
497
|
|
|
568
498
|
|
|
569
499
|
def _cross_entropy_backward(
|
|
@@ -598,34 +528,28 @@ def _cross_entropy_backward(
|
|
|
598
528
|
|
|
599
529
|
N = x.size(1)
|
|
600
530
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
604
|
-
mode=0, stride_order=(0, 1)
|
|
605
|
-
)
|
|
606
|
-
)
|
|
607
|
-
x_tensor = convert_from_dlpack(x)
|
|
608
|
-
dx_tensor = convert_from_dlpack(dx)
|
|
609
|
-
dloss_tensor = from_dlpack(dloss.detach(), assumed_align=4).mark_layout_dynamic()
|
|
610
|
-
lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
|
|
611
|
-
target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
|
|
612
|
-
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
613
|
-
|
|
614
|
-
compile_key = (dtype, N, target.dtype, dloss.stride(), lse.stride(), target.stride())
|
|
531
|
+
target_dtype = torch2cute_dtype_map[target.dtype]
|
|
532
|
+
compile_key = (dtype, target_dtype, N)
|
|
615
533
|
if compile_key not in _cross_entropy_backward.compile_cache:
|
|
534
|
+
batch_sym = cute.sym_int()
|
|
535
|
+
div = math.gcd(128 // dtype.width, N)
|
|
536
|
+
x_cute, dx_cute = [fake_tensor(dtype, (batch_sym, N), div)] * 2
|
|
537
|
+
target_cute = fake_tensor(target_dtype, (batch_sym,))
|
|
538
|
+
dloss_cute, lse_cute = [fake_tensor(Float32, (batch_sym,))] * 2
|
|
616
539
|
cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
|
|
617
540
|
_cross_entropy_backward.compile_cache[compile_key] = cute.compile(
|
|
618
541
|
cross_entropy_backward_op,
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
Int32(
|
|
625
|
-
|
|
542
|
+
x_cute,
|
|
543
|
+
target_cute,
|
|
544
|
+
dloss_cute,
|
|
545
|
+
dx_cute,
|
|
546
|
+
lse_cute,
|
|
547
|
+
Int32(0), # ignore_index, just for compilation
|
|
548
|
+
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
|
|
549
|
+
options="--enable-tvm-ffi",
|
|
626
550
|
)
|
|
627
551
|
_cross_entropy_backward.compile_cache[compile_key](
|
|
628
|
-
|
|
552
|
+
x, target, dloss, dx, lse, Int32(ignore_index)
|
|
629
553
|
)
|
|
630
554
|
|
|
631
555
|
|