quack-kernels 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/cross_entropy.py +298 -7
- quack/layernorm.py +351 -0
- quack/reduction_base.py +1 -4
- quack/rmsnorm.py +389 -9
- quack/softmax.py +8 -3
- quack/utils.py +19 -18
- {quack_kernels-0.1.4.dist-info → quack_kernels-0.1.6.dist-info}/METADATA +1 -1
- quack_kernels-0.1.6.dist-info/RECORD +12 -0
- quack_kernels-0.1.4.dist-info/RECORD +0 -11
- {quack_kernels-0.1.4.dist-info → quack_kernels-0.1.6.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.4.dist-info → quack_kernels-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.4.dist-info → quack_kernels-0.1.6.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/cross_entropy.py
CHANGED
@@ -1,3 +1,5 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
 import math
 import torch
 from typing import Optional, Type
@@ -102,7 +104,10 @@ class CrossEntropy(ReductionBase):
         shape: cute.Shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+        gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
@@ -148,7 +153,9 @@ class CrossEntropy(ReductionBase):

         target_logit = cute.Float32.zero
         if row < shape[0] and tXcX[0][1] == 0:
-
+            # Use Int64 for indexing to deal with large tensors
+            mX_off = utils.domain_offset_i64((row, 0), mX)
+            target_logit = cute.Float32(mX_off[0, target])

         threads_per_row = tv_layout.shape[0][0]
         if cutlass.const_expr(not self.online_softmax):
@@ -200,7 +207,7 @@ class CrossEntropy(ReductionBase):
             mLSE[row] = lse


-def
+def _cross_entropy(
     x: torch.Tensor,
     target: torch.Tensor,
     return_lse: bool = False,
@@ -241,15 +248,299 @@ def cross_entropy(
     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

     compile_key = (dtype, N, lse is not None)
-    if compile_key not in
+    if compile_key not in _cross_entropy.compile_cache:
         cross_entropy_op = CrossEntropy(dtype, N)
-
+        _cross_entropy.compile_cache[compile_key] = cute.compile(
             cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
         )
-
+    _cross_entropy.compile_cache[compile_key](
         x_tensor, target_tensor, loss_tensor, lse_tensor, stream
     )
     return loss if not return_lse else (loss, lse)


-
+_cross_entropy.compile_cache = {}
+
+
+class CrossEntropyBackward:
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+        self.dtype = dtype
+        self.N = N
+        self.vecsize = 128 // dtype.width
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _get_tv_layout(self):
+        N = self.N
+        vecsize = self.vecsize
+        num_threads = 128 if N <= 16384 else 256
+        threads_per_row = self._calculate_threads_per_row()
+        cols_per_block = num_threads // threads_per_row
+        num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+        tv_layout = cute.make_layout(
+            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+            stride=(
+                (vecsize * cols_per_block, 1),
+                (cols_per_block, cols_per_block * vecsize * threads_per_row),
+            ),
+        )
+        return tiler_mn, tv_layout
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mTarget: cute.Tensor,
+        mDLoss: cute.Tensor,
+        mdX: cute.Tensor,
+        mLSE: cute.Tensor,
+        stream: cuda.CUstream,
+    ):
+        assert mX.element_type == self.dtype
+        assert mdX.element_type == self.dtype
+
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+
+        mDLoss = cute.make_tensor(
+            mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+        mTarget = cute.make_tensor(
+            mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+        mLSE = cute.make_tensor(
+            mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+
+        smem_size = cute.size_in_bytes(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
+        )
+
+        self.kernel(
+            mX,
+            mTarget,
+            mDLoss,
+            mdX,
+            mLSE,
+            mX.shape,
+            tv_layout,
+            tiler_mn,
+        ).launch(
+            grid=[
+                cute.ceil_div(mX.shape[0], tiler_mn[0]),
+                cute.ceil_div(mX.shape[1], tiler_mn[1]),
+                1,
+            ],
+            block=[num_threads, 1, 1],
+            smem=smem_size,
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,  # (M, N)
+        mTarget: cute.Tensor,  # (M,)
+        mDLoss: cute.Tensor,  # (M,)
+        mdX: cute.Tensor,  # (M, N)
+        mLSE: cute.Tensor,  # (M,)
+        shape: cute.Shape,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, bidy, _ = cute.arch.block_idx()
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+
+        idX = cute.make_identity_tensor(shape)
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mdX = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX)]
+        gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))
+
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_X_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store_O = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_X_async = cute.make_tiled_copy(
+            copy_atom_load_X_async, tv_layout, tiler_mn
+        ).get_slice(tidx)
+        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+        #### Thread View
+        tXgX = thr_copy_X_async.partition_S(gX)
+        tXsX = thr_copy_X_async.partition_S(sX)
+
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+        tXcFull = thr_copy_X.partition_S(cX)  # improve
+
+        tXgO = thr_copy_O.partition_D(gdX)
+
+        # allocate fragments for gmem->rmem
+        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+        is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
+        row = tXcX[0][0]
+
+        tXpX = (
+            utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
+            if not is_even_N
+            else None
+        )
+
+        if row < shape[0]:
+            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+            cute.arch.cp_async_commit_group()
+            cute.arch.cp_async_wait_group(0)
+            if cutlass.const_expr(not is_even_N):
+                utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+
+        label = cute.Int32.zero
+        dloss = cute.Float32.zero
+        lse = cute.Float32.zero
+        if row < shape[0]:
+            label = cute.Int32(mTarget[row])
+            dloss = cute.Float32(mDLoss[row])
+            lse = cute.Float32(mLSE[row])
+
+        log2_e = math.log2(math.e)
+        probs = utils.exp2f((x - lse) * log2_e)
+        prob_shifted = probs - 1.0
+
+        mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+        for i in cutlass.range_constexpr(cute.size(tXcFull)):
+            mask[i] = tXcFull[i][1] == label
+
+        mask = mask.load()
+        grad = cute.where(mask, prob_shifted, probs)
+        grad = grad * dloss
+
+        tXrO.store(grad.to(tXrO.element_type))
+        tOpO = (
+            utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+        )
+        if row < shape[0]:
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+def _cross_entropy_backward(
+    x: torch.Tensor,
+    target: torch.Tensor,
+    dloss: torch.Tensor,
+    lse: torch.Tensor,
+    inplace_backward: bool = False,
+) -> torch.Tensor:
+    """Cross entropy backward pass.
+    Args:
+        x: Input logits tensor of shape (M, N)
+        target: Target class indices tensor of shape (M,)
+        dloss: Upstream gradients tensor of shape (M,)
+        lse: Log-sum-exp values tensor of shape (M,)
+    Returns:
+        Input gradients tensor of shape (M, N)
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert target.dim() == 1, "Target must be 1D"
+    assert dloss.dim() == 1, "dloss must be 1D"
+    assert lse.dim() == 1, "lse must be 1D"
+    assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+    assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
+    assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
+    assert (
+        x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
+    ), "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+    assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+
+    M, N = x.shape
+    dx = torch.empty_like(x) if not inplace_backward else x
+    dtype = torch2cute_dtype_map[x.dtype]
+
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+    x_tensor = convert_from_dlpack(x)
+    dx_tensor = convert_from_dlpack(dx)
+    dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
+        mode=0
+    )
+    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    compile_key = (dtype, N)
+    if compile_key not in _cross_entropy_backward.compile_cache:
+        cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
+        _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
+            cross_entropy_backward_op,
+            x_tensor,
+            target_tensor,
+            dloss_tensor,
+            dx_tensor,
+            lse_tensor,
+            stream,
+        )
+    _cross_entropy_backward.compile_cache[compile_key](
+        x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
+    )
+    return dx
+
+
+_cross_entropy_backward.compile_cache = {}
+
+
+class CrossEntropyFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, target, inplace_backward=False):
+        loss, lse = _cross_entropy(x, target, return_lse=True)
+        ctx.save_for_backward(x, target, lse)
+        ctx.inplace_backward = inplace_backward
+        return loss
+
+    @staticmethod
+    def backward(ctx, dloss):
+        x, target, lse = ctx.saved_tensors
+        dx = _cross_entropy_backward(x, target, dloss, lse, inplace_backward=ctx.inplace_backward)
+        return dx, None, None
+
+
+def cross_entropy(
+    x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+) -> torch.Tensor:
+    """Cross entropy loss with automatic differentiation support.
+
+    Args:
+        x: Input logits tensor of shape (M, N)
+        target: Target class indices tensor of shape (M,)
+
+    Returns:
+        Cross entropy loss tensor of shape (M,)
+    """
+    return CrossEntropyFunction.apply(x, target, inplace_backward)
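The new backward kernel computes the usual softmax-minus-one-hot gradient, scaled by the upstream loss gradient and reusing the log-sum-exp saved by the forward pass. As a rough sanity check, here is a minimal PyTorch sketch of the same math; the helper name cross_entropy_bwd_ref, the shapes, and the tolerances are ours, not part of the package:

import torch
from quack.cross_entropy import cross_entropy

def cross_entropy_bwd_ref(x, target, dloss, lse):
    # dL/dx = (softmax(x) - onehot(target)) * dloss, with softmax(x) = exp(x - lse)
    grad = torch.exp(x.float() - lse.unsqueeze(1))
    grad[torch.arange(x.shape[0], device=x.device), target] -= 1.0
    return (grad * dloss.unsqueeze(1)).to(x.dtype)

x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 4096, (1024,), device="cuda")
loss = cross_entropy(x, target)   # autograd-enabled entry point added in 0.1.6
loss.sum().backward()             # dispatches to CrossEntropyBackward

lse = torch.logsumexp(x.float(), dim=-1)
ref = cross_entropy_bwd_ref(x.detach(), target, torch.ones(1024, device="cuda"), lse)
torch.testing.assert_close(x.grad, ref, atol=1e-2, rtol=1e-2)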
quack/layernorm.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import cuda.bindings.driver as cuda
|
|
8
|
+
|
|
9
|
+
import cutlass
|
|
10
|
+
import cutlass.cute as cute
|
|
11
|
+
from cutlass.cute.runtime import from_dlpack
|
|
12
|
+
import quack.utils as utils
|
|
13
|
+
from quack.reduction_base import ReductionBase, torch2cute_dtype_map
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LayerNorm(ReductionBase):
|
|
17
|
+
def __init__(self, dtype: cutlass.Numeric, N: int):
|
|
18
|
+
super().__init__(dtype, N, stage=2) # 2 stages for mean and var
|
|
19
|
+
self.reload_from = None if N <= 16384 else "smem"
|
|
20
|
+
self.delay_w_load = False
|
|
21
|
+
|
|
22
|
+
def _calculate_threads_per_row(self):
|
|
23
|
+
N = self.N
|
|
24
|
+
return (
|
|
25
|
+
8
|
|
26
|
+
if N <= 64
|
|
27
|
+
else (
|
|
28
|
+
16
|
|
29
|
+
if N <= 128
|
|
30
|
+
else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
|
|
31
|
+
)
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
def _set_cluster_n(self):
|
|
35
|
+
N = self.N
|
|
36
|
+
# cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
|
|
37
|
+
# Similarly cluster_n = 8 is faster for N=128k
|
|
38
|
+
if cutlass.const_expr(self.dtype.width == 16):
|
|
39
|
+
cluster_n = (
|
|
40
|
+
1
|
|
41
|
+
if N <= 16 * 1024
|
|
42
|
+
else (
|
|
43
|
+
2
|
|
44
|
+
if N <= 32 * 1024
|
|
45
|
+
else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
|
|
46
|
+
)
|
|
47
|
+
)
|
|
48
|
+
else: # fp32
|
|
49
|
+
cluster_n = (
|
|
50
|
+
1
|
|
51
|
+
if N <= 32 * 1024
|
|
52
|
+
else (
|
|
53
|
+
2
|
|
54
|
+
if N <= 64 * 1024
|
|
55
|
+
else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
|
|
56
|
+
)
|
|
57
|
+
)
|
|
58
|
+
self.cluster_n = cluster_n
|
|
59
|
+
|
|
60
|
+
@cute.jit
|
|
61
|
+
def __call__(
|
|
62
|
+
self,
|
|
63
|
+
mX: cute.Tensor,
|
|
64
|
+
mW: cute.Tensor,
|
|
65
|
+
mO: cute.Tensor,
|
|
66
|
+
mRstd: Optional[cute.Tensor],
|
|
67
|
+
mMean: Optional[cute.Tensor],
|
|
68
|
+
stream: cuda.CUstream,
|
|
69
|
+
eps: cutlass.Float32 = 1e-6,
|
|
70
|
+
):
|
|
71
|
+
assert mX.element_type == self.dtype
|
|
72
|
+
assert mO.element_type == self.dtype
|
|
73
|
+
self._set_cluster_n()
|
|
74
|
+
tiler_mn, tv_layout = self._get_tv_layout()
|
|
75
|
+
num_threads = cute.size(tv_layout, mode=[0])
|
|
76
|
+
num_warps = num_threads // cute.arch.WARP_SIZE
|
|
77
|
+
mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
|
|
78
|
+
mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
|
|
79
|
+
if cutlass.const_expr(mRstd is not None):
|
|
80
|
+
mRstd_expanded_layout = cute.append(
|
|
81
|
+
mRstd.layout, cute.make_layout((self.N,), stride=(0,))
|
|
82
|
+
)
|
|
83
|
+
mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
|
|
84
|
+
if cutlass.const_expr(mMean is not None):
|
|
85
|
+
mMean_expanded_layout = cute.append(
|
|
86
|
+
mMean.layout, cute.make_layout((self.N,), stride=(0,))
|
|
87
|
+
)
|
|
88
|
+
mMean = cute.make_tensor(mMean.iterator, mMean_expanded_layout)
|
|
89
|
+
self.kernel(mX, mW, mO, mRstd, mMean, eps, tv_layout, tiler_mn, self.reload_from).launch(
|
|
90
|
+
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
91
|
+
block=[num_threads, 1, 1],
|
|
92
|
+
cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
93
|
+
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
94
|
+
stream=stream,
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
@cute.kernel
|
|
98
|
+
def kernel(
|
|
99
|
+
self,
|
|
100
|
+
mX: cute.Tensor,
|
|
101
|
+
mW: cute.Tensor,
|
|
102
|
+
mO: cute.Tensor,
|
|
103
|
+
mRstd: Optional[cute.Tensor],
|
|
104
|
+
mMean: Optional[cute.Tensor],
|
|
105
|
+
eps: cute.Float32,
|
|
106
|
+
tv_layout: cute.Layout,
|
|
107
|
+
tiler_mn: cute.Shape,
|
|
108
|
+
reload_from: cutlass.Constexpr = None,
|
|
109
|
+
delay_w_load: cutlass.Constexpr = False,
|
|
110
|
+
):
|
|
111
|
+
tidx, _, _ = cute.arch.thread_idx()
|
|
112
|
+
bidx, _, _ = cute.arch.block_idx()
|
|
113
|
+
if cutlass.const_expr(self.cluster_n > 1):
|
|
114
|
+
cluster_y = cute.arch.block_idx()[1]
|
|
115
|
+
else:
|
|
116
|
+
cluster_y = cutlass.const_expr(0)
|
|
117
|
+
|
|
118
|
+
smem = cutlass.utils.SmemAllocator()
|
|
119
|
+
sX = smem.allocate_tensor(
|
|
120
|
+
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
|
|
121
|
+
)
|
|
122
|
+
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
|
123
|
+
|
|
124
|
+
shape = mX.shape
|
|
125
|
+
idX = cute.make_identity_tensor(shape)
|
|
126
|
+
# slice for CTAs
|
|
127
|
+
# We use domain_offset_i64 to deal with tensors larger than 2^31 elements
|
|
128
|
+
mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
|
|
129
|
+
gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
|
|
130
|
+
cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
|
|
131
|
+
gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
|
|
132
|
+
gRstd = (
|
|
133
|
+
cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
|
|
134
|
+
if cutlass.const_expr(mRstd is not None)
|
|
135
|
+
else None
|
|
136
|
+
)
|
|
137
|
+
gMean = (
|
|
138
|
+
cute.local_tile(mMean, tiler_mn, (bidx, cluster_y))
|
|
139
|
+
if cutlass.const_expr(mMean is not None)
|
|
140
|
+
else None
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# declare the atoms which will be used later for memory copy
|
|
144
|
+
copy_atom_load_X = cute.make_copy_atom(
|
|
145
|
+
cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
|
|
146
|
+
)
|
|
147
|
+
copy_atom_load_X_async = cute.make_copy_atom(
|
|
148
|
+
cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
|
|
149
|
+
)
|
|
150
|
+
copy_atom_load_W = cute.make_copy_atom(
|
|
151
|
+
cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
|
|
152
|
+
)
|
|
153
|
+
copy_atom_store_O = cute.make_copy_atom(
|
|
154
|
+
cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
|
|
158
|
+
tidx
|
|
159
|
+
)
|
|
160
|
+
thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
|
|
161
|
+
thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
|
|
162
|
+
|
|
163
|
+
tWgW = thr_copy_W.partition_S(gW)
|
|
164
|
+
tXgX = thr_copy_X.partition_S(gX)
|
|
165
|
+
tXsX = thr_copy_X.partition_D(sX)
|
|
166
|
+
tXgO = thr_copy_O.partition_D(gO)
|
|
167
|
+
tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
|
|
168
|
+
tXrMean = thr_copy_O.partition_D(gMean) if cutlass.const_expr(mMean is not None) else None
|
|
169
|
+
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
170
|
+
|
|
171
|
+
# allocate fragments for gmem->rmem
|
|
172
|
+
tWrW = cute.make_fragment_like(tWgW)
|
|
173
|
+
tXrW = thr_copy_X.retile(tWrW)
|
|
174
|
+
tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
|
|
175
|
+
|
|
176
|
+
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
|
|
177
|
+
self._initialize_cluster(tidx, mbar_ptr, num_warps)
|
|
178
|
+
|
|
179
|
+
tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
180
|
+
row = tXcX[0][0]
|
|
181
|
+
if row < shape[0]:
|
|
182
|
+
cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
|
|
183
|
+
cute.arch.cp_async_commit_group()
|
|
184
|
+
|
|
185
|
+
tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
|
|
186
|
+
if cutlass.const_expr(not delay_w_load):
|
|
187
|
+
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
|
|
188
|
+
|
|
189
|
+
cute.arch.cp_async_wait_group(0)
|
|
190
|
+
cute.autovec_copy(tXsX, tXrX)
|
|
191
|
+
x = tXrX.load().to(cute.Float32)
|
|
192
|
+
threads_per_row = tv_layout.shape[0][0]
|
|
193
|
+
sum_x = utils.row_reduce(
|
|
194
|
+
x,
|
|
195
|
+
cute.ReductionOp.ADD,
|
|
196
|
+
threads_per_row,
|
|
197
|
+
reduction_buffer[None, None, 0],
|
|
198
|
+
mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
199
|
+
init_val=0.0,
|
|
200
|
+
hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
201
|
+
)
|
|
202
|
+
mean = sum_x / shape[1]
|
|
203
|
+
if cutlass.const_expr(reload_from == "smem"):
|
|
204
|
+
cute.autovec_copy(tXsX, tXrX)
|
|
205
|
+
x = tXrX.load().to(cute.Float32)
|
|
206
|
+
elif cutlass.const_expr(reload_from == "gmem"):
|
|
207
|
+
cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
|
|
208
|
+
x = tXrX.load().to(cute.Float32)
|
|
209
|
+
|
|
210
|
+
sum_sq_x_sub_mean = utils.row_reduce(
|
|
211
|
+
(x - mean) * (x - mean),
|
|
212
|
+
cute.ReductionOp.ADD,
|
|
213
|
+
threads_per_row,
|
|
214
|
+
reduction_buffer[None, None, 1],
|
|
215
|
+
mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
|
|
216
|
+
init_val=0.0,
|
|
217
|
+
)
|
|
218
|
+
rstd = utils.rsqrt(sum_sq_x_sub_mean / shape[1] + eps)
|
|
219
|
+
if cutlass.const_expr(mRstd is not None):
|
|
220
|
+
# Only the thread corresponding to column 0 writes out the rstd to gmem
|
|
221
|
+
if (
|
|
222
|
+
tXcX[0][1] == 0
|
|
223
|
+
and row < shape[0]
|
|
224
|
+
and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
|
|
225
|
+
):
|
|
226
|
+
tXrRstd[0] = rstd
|
|
227
|
+
if cutlass.const_expr(mMean is not None):
|
|
228
|
+
# Only the thread corresponding to column 0 writes out the mean to gmem
|
|
229
|
+
if (
|
|
230
|
+
tXcX[0][1] == 0
|
|
231
|
+
and row < shape[0]
|
|
232
|
+
and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
|
|
233
|
+
):
|
|
234
|
+
tXrMean[0] = mean
|
|
235
|
+
if cutlass.const_expr(delay_w_load):
|
|
236
|
+
cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
|
|
237
|
+
if cutlass.const_expr(reload_from == "smem"):
|
|
238
|
+
cute.autovec_copy(tXsX, tXrX)
|
|
239
|
+
x = tXrX.load().to(cute.Float32)
|
|
240
|
+
elif cutlass.const_expr(reload_from == "gmem"):
|
|
241
|
+
cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
|
|
242
|
+
x = tXrX.load().to(cute.Float32)
|
|
243
|
+
x_hat = (x - mean) * rstd
|
|
244
|
+
w = tXrW.load().to(cute.Float32)
|
|
245
|
+
y = x_hat * w
|
|
246
|
+
tXrO.store(y.to(tXrO.element_type))
|
|
247
|
+
tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
|
|
248
|
+
if row < shape[0]:
|
|
249
|
+
cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def layernorm(
|
|
253
|
+
x: torch.Tensor,
|
|
254
|
+
weight: torch.Tensor,
|
|
255
|
+
eps: float = 1e-6,
|
|
256
|
+
return_rstd: bool = False,
|
|
257
|
+
return_mean: bool = False,
|
|
258
|
+
) -> torch.Tensor:
|
|
259
|
+
"""LayerNorm forward pass.
|
|
260
|
+
|
|
261
|
+
Args:
|
|
262
|
+
x: Input tensor of shape (M, N)
|
|
263
|
+
weight: Weight tensor of shape (N,)
|
|
264
|
+
eps: Small value for numerical stability
|
|
265
|
+
return_rstd: Whether to return the reciprocal standard deviation
|
|
266
|
+
return_mean: Whether to return the mean
|
|
267
|
+
|
|
268
|
+
Returns:
|
|
269
|
+
Normalized output tensor of same shape as x
|
|
270
|
+
If return_rstd is True, also returns rstd tensor of shape (M,)
|
|
271
|
+
If return_mean is True, also returns mean tensor of shape (M,)
|
|
272
|
+
"""
|
|
273
|
+
assert x.dim() == 2, "Input must be 2D"
|
|
274
|
+
assert weight.dim() == 1, "Weight must be 1D"
|
|
275
|
+
assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
|
|
276
|
+
assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
|
|
277
|
+
assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
|
|
278
|
+
assert weight.dtype == torch.float32, "Weight must be float32"
|
|
279
|
+
M, N = x.shape
|
|
280
|
+
device = x.device
|
|
281
|
+
out = torch.empty_like(x)
|
|
282
|
+
rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
|
|
283
|
+
mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
|
|
284
|
+
dtype = torch2cute_dtype_map[x.dtype]
|
|
285
|
+
convert_from_dlpack = lambda x: (
|
|
286
|
+
from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
287
|
+
mode=0, stride_order=(0, 1)
|
|
288
|
+
)
|
|
289
|
+
)
|
|
290
|
+
x_tensor, out_tensor = [
|
|
291
|
+
# utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
|
|
292
|
+
convert_from_dlpack(t)
|
|
293
|
+
for t in (x, out)
|
|
294
|
+
]
|
|
295
|
+
weight_tensor = utils.convert_from_dlpack(
|
|
296
|
+
weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
|
|
297
|
+
)
|
|
298
|
+
rstd_tensor = (
|
|
299
|
+
from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
300
|
+
if rstd is not None
|
|
301
|
+
else None
|
|
302
|
+
)
|
|
303
|
+
mean_tensor = (
|
|
304
|
+
from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
305
|
+
if mean is not None
|
|
306
|
+
else None
|
|
307
|
+
)
|
|
308
|
+
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
309
|
+
compile_key = (dtype, N, rstd is not None, mean is not None)
|
|
310
|
+
if compile_key not in layernorm.compile_cache:
|
|
311
|
+
rmsnorm_op = LayerNorm(dtype, N)
|
|
312
|
+
layernorm.compile_cache[compile_key] = cute.compile(
|
|
313
|
+
rmsnorm_op,
|
|
314
|
+
x_tensor,
|
|
315
|
+
weight_tensor,
|
|
316
|
+
out_tensor,
|
|
317
|
+
rstd_tensor,
|
|
318
|
+
mean_tensor,
|
|
319
|
+
current_stream,
|
|
320
|
+
)
|
|
321
|
+
layernorm.compile_cache[compile_key](
|
|
322
|
+
x_tensor, weight_tensor, out_tensor, rstd_tensor, mean_tensor, current_stream, eps
|
|
323
|
+
)
|
|
324
|
+
return (
|
|
325
|
+
(out, rstd, mean)
|
|
326
|
+
if return_mean and return_rstd
|
|
327
|
+
else (
|
|
328
|
+
(out, rstd)
|
|
329
|
+
if return_rstd and not return_mean
|
|
330
|
+
else ((out, mean) if return_mean and not return_rstd else (out))
|
|
331
|
+
)
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
layernorm.compile_cache = {}
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def layernorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
|
|
339
|
+
x_f32 = x.float()
|
|
340
|
+
return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def rstd_ref(x: torch.Tensor, eps: float = 1e-6):
|
|
344
|
+
x_f32 = x.float()
|
|
345
|
+
mean = x_f32.mean(dim=-1, keepdim=True)
|
|
346
|
+
var = ((x_f32 - mean) ** 2).mean(dim=-1)
|
|
347
|
+
return 1.0 / torch.sqrt(var + eps)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def mean_ref(x: torch.Tensor) -> torch.Tensor:
|
|
351
|
+
return x.float().mean(dim=-1)
|
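The new layernorm module ships its own PyTorch references (layernorm_ref, rstd_ref, mean_ref). A minimal usage sketch, assuming a CUDA device; the shapes and tolerances below are ours, chosen for illustration only:

import torch
from quack.layernorm import layernorm, layernorm_ref, rstd_ref, mean_ref

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
w = torch.randn(8192, device="cuda", dtype=torch.float32)  # weight must be fp32

out, rstd, mean = layernorm(x, w, eps=1e-6, return_rstd=True, return_mean=True)
torch.testing.assert_close(out, layernorm_ref(x, w), atol=1e-2, rtol=1e-2)
torch.testing.assert_close(rstd, rstd_ref(x), atol=1e-3, rtol=1e-3)
torch.testing.assert_close(mean, mean_ref(x), atol=1e-3, rtol=1e-3)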
quack/reduction_base.py
CHANGED
@@ -6,8 +6,6 @@ from typing import Type, Tuple, Optional
 import cutlass
 import cutlass.cute as cute

-import quack.utils as utils
-

 torch2cute_dtype_map = {
     torch.float16: cutlass.Float16,
@@ -39,7 +37,6 @@ class ReductionBase:
         vecsize = copy_bits // self.dtype.width
         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
         num_threads = self._get_num_threads()
-        num_warps = num_threads // cute.arch.WARP_SIZE
         assert num_threads % cute.arch.WARP_SIZE == 0

         threads_per_row = self._calculate_threads_per_row()
@@ -64,7 +61,7 @@ class ReductionBase:

     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-        warps_per_row =
+        warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
         return cute.make_ordered_layout(
             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
             order=(1, 0, 2),
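The changed line sizes the shared-memory reduction buffer: one partial per warp, grouped by how many warps cooperate on a row, clamped to at least 1 when a row fits inside a single warp. A plain-Python sketch of that arithmetic with example values of ours (in these kernels, mode 0 of tv_layout has shape (threads_per_row, cols_per_block)):

WARP_SIZE = 32
threads_per_row, cols_per_block = 256, 1                      # e.g. a wide row, fp16
num_warps = (threads_per_row * cols_per_block) // WARP_SIZE   # 8 warps per CTA
warps_per_row = max(threads_per_row // WARP_SIZE, 1)          # 8; clamped to 1 for narrow rows
cluster_n, stage = 4, 1
buffer_shape = (num_warps // warps_per_row, (warps_per_row, cluster_n), stage)
print(buffer_shape)  # (1, (8, 4), 1): rows per CTA x (warps per row, CTAs in cluster) x stages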
quack/rmsnorm.py
CHANGED
@@ -9,7 +9,6 @@ import cuda.bindings.driver as cuda
 import cutlass
 import cutlass.cute as cute
 from cutlass.cute.runtime import from_dlpack
-
 import quack.utils as utils
 from quack.reduction_base import ReductionBase, torch2cute_dtype_map

@@ -118,7 +117,10 @@ class RMSNorm(ReductionBase):
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
         gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
         gRstd = (
             cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
@@ -210,20 +212,18 @@ class RMSNorm(ReductionBase):
             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)


-def
+def _rmsnorm_fwd(
     x: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
     return_rstd: bool = False,
 ) -> torch.Tensor:
     """RMSNorm forward pass.
-
     Args:
         x: Input tensor of shape (M, N)
         weight: Weight tensor of shape (N,)
         eps: Small value for numerical stability
         return_rstd: Whether to return the reciprocal standard deviation
-
     Returns:
         Normalized output tensor of same shape as x
         If return_rstd is True, also returns rstd tensor of shape (M,)
@@ -259,18 +259,18 @@ def rmsnorm(
     )
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
     compile_key = (dtype, N, rstd is not None)
-    if compile_key not in
+    if compile_key not in _rmsnorm_fwd.compile_cache:
         rmsnorm_op = RMSNorm(dtype, N)
-
+        _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
             rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
         )
-
+    _rmsnorm_fwd.compile_cache[compile_key](
         x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
     )
     return (out, rstd) if return_rstd else out


-
+_rmsnorm_fwd.compile_cache = {}


 def rmsnorm_ref(x, w, eps=1e-6):
@@ -283,3 +283,383 @@ def rmsnorm_ref(x, w, eps=1e-6):
 def rstd_ref(x, eps=1e-6):
     x_f32 = x.float()
     return 1.0 / torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1) + eps)
+
+
+def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
+    """Reference implementation for RMSNorm backward pass."""
+    x_f32 = x.float()
+    x_hat = x_f32 * rstd.unsqueeze(1)
+    wdy = dout * w
+    c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
+    dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)
+
+    # dL/dW
+    dw = (dout * x_hat).sum(dim=0)
+    return dx.to(x.dtype), dw.to(w.dtype)
+
+
+class RMSNormBackward(ReductionBase):
+    def __init__(self, dtype: cutlass.Numeric, N: int):
+        # 1 stage for computing mean of x_hat * wdy
+        super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _set_cluster_n(self):
+        N = self.N
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 32 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mDout: cute.Tensor,
+        mRstd: cute.Tensor,
+        mDx: cute.Tensor,
+        mDw: cute.Tensor,
+        sm_count: cutlass.Constexpr,
+        stream: cuda.CUstream,
+    ):
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+
+        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+
+        mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
+        mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+
+        num_blocks = (
+            sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+        )
+
+        self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+            grid=[num_blocks, self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mDout: cute.Tensor,
+        mRstd: cute.Tensor,
+        mDx: cute.Tensor,
+        mDw: cute.Tensor,
+        sm_count: cutlass.Constexpr,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, cluster_y, _ = cute.arch.block_idx()
+        gdim, _, _ = cute.arch.grid_dim()
+
+        shape = mX.shape
+        M, N = shape[0], shape[1]
+
+        idX = cute.make_identity_tensor(shape)
+
+        smem = cutlass.utils.SmemAllocator()
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+        )
+
+        copy_atom_load_W = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+        )
+
+        copy_atom_store_dX = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mDx.element_type, num_bits_per_copy=128
+        )
+
+        copy_atom_dw = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_dw = cute.make_tiled_copy(copy_atom_dw, tv_layout, tiler_mn).get_slice(tidx)
+        thr_store_dx = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
+
+        gW = cute.local_tile(mW, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+        tWgW = thr_copy_W.partition_S(gW)
+        tWrW = cute.make_fragment_like(tWgW)
+        tXrW = thr_copy_X.retile(tWrW)
+
+        gW_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+
+        tWpW = utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+        cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+        weight = tXrW.load().to(cute.Float32)
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+        tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+
+        gDw = cute.local_tile(mDw, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+        tDwgDw = thr_copy_dw.partition_D(gDw)
+        tDwrDw = cute.make_fragment_like(tDwgDw)
+        dw_accumulator = thr_copy_X.retile(tDwrDw)
+        dw_accumulator.fill(0.0)
+
+        M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+
+        jump = sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+
+        if cutlass.const_expr(self.cluster_n > 1):
+            cute.arch.cluster_arrive()
+            cute.arch.cluster_wait()
+
+        ## need to update range_dynamic since it will be deprecated soon
+        for row_offset in cutlass.range_dynamic(bidx, M_pad, jump):
+            gX = cute.local_tile(
+                mX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            gDout = cute.local_tile(
+                mDout, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            gRstd = cute.local_tile(
+                mRstd, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            gDx = cute.local_tile(
+                mDx, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            cX = cute.local_tile(
+                idX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+
+            tXgX = thr_copy_X.partition_S(gX)
+            thrDout = thr_copy_X.partition_S(gDout)
+            tXrRstd = thr_copy_W.partition_S(gRstd)
+            thrDx = thr_store_dx.partition_D(gDx)
+            tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+            tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
+
+            tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+
+            if tXcX[0][0] < shape[0]:
+                cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+                cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+
+            x = tXrX.load().to(cute.Float32)
+            dout = frgDout.load().to(cute.Float32)
+
+            rstd = tXrRstd[0]
+            x_hat = x * rstd
+            wdy = dout * weight
+
+            threads_per_row = tv_layout.shape[0][0]
+
+            row = tXcX[0][0]
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+
+            mean_xhat_wdy = (
+                utils.row_reduce(
+                    x_hat * wdy,
+                    cute.ReductionOp.ADD,
+                    threads_per_row,
+                    reduction_buffer[None, None, 0],
+                    mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                    init_val=0.0,
+                    hook_fn=cute.arch.cluster_wait
+                    if cutlass.const_expr(self.cluster_n > 1)
+                    else None,
+                )
+                / shape[1]
+            )
+
+            dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+            frgDx.store(dx.to(frgDout.element_type))
+
+            if row < M:
+                cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
+
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+
+            if row < M:
+                dw_row = dout * x_hat
+                current_dw = dw_accumulator.load().to(cute.Float32)
+                updated_dw = current_dw + dw_row
+                dw_accumulator.store(updated_dw.to(dw_accumulator.element_type))
+
+            """
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+            """
+            """
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+            """
+
+        cute.autovec_copy(dw_accumulator, tDwrDw)
+        cute.copy(copy_atom_dw, tDwrDw, tDwgDw, pred=tDwpDw)
+
+
+def _rmsnorm_backward(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    dout: torch.Tensor,
+    rstd: torch.Tensor,
+) -> (torch.Tensor, torch.Tensor):
+    """RMSNorm backward pass.
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,)
+        dout: Upstream gradients tensor of shape (M, N)
+        rstd: Reciprocal standard deviation tensor of shape (M,)
+    Returns:
+        Tuple of (dx, dw) where:
+        - dx: Input gradients tensor of same shape as x
+        - dw: Weight gradients tensor of same shape as weight
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert weight.dim() == 1, "Weight must be 1D"
+    assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+    assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert weight.dtype == torch.float32, "Weight must be float32"
+
+    M, N = x.shape
+    dx = torch.empty_like(x)
+
+    device = x.device
+
+    sm_count = torch.cuda.get_device_properties(device).multi_processor_count * 8
+    dw_partial = torch.zeros((sm_count, N), device=device, dtype=weight.dtype)
+
+    dtype = torch2cute_dtype_map[x.dtype]
+
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+
+    x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
+
+    weight_tensor = utils.convert_from_dlpack(
+        weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+    )
+
+    dw_partial_tensor = convert_from_dlpack(dw_partial)
+    rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+
+    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    compile_key = (dtype, N)
+    if compile_key not in _rmsnorm_backward.compile_cache:
+        rmsnorm_backward_op = RMSNormBackward(dtype, N)
+        _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
+            rmsnorm_backward_op,
+            x_tensor,
+            weight_tensor,
+            dout_tensor,
+            rstd_tensor,
+            dx_tensor,
+            dw_partial_tensor,
+            sm_count,
+            current_stream,
+        )
+
+    _rmsnorm_backward.compile_cache[compile_key](
+        x_tensor,
+        weight_tensor,
+        dout_tensor,
+        rstd_tensor,
+        dx_tensor,
+        dw_partial_tensor,
+        current_stream,
+    )
+
+    dw = dw_partial.sum(dim=0).to(weight.dtype)
+    return dx, dw
+
+
+_rmsnorm_backward.compile_cache = {}
+
+
+class RMSNormFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, eps):
+        out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
+        ctx.save_for_backward(x, weight, rstd)
+        ctx.eps = eps
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, weight, rstd = ctx.saved_tensors
+        dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
+        # dw is returned for weight gradient, None for eps gradient
+        return dx, dw, None
+
+
+def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """RMSNorm forward pass with automatic differentiation support.
+
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,)
+        eps: Small value for numerical stability
+
+    Returns:
+        Normalized output tensor of same shape as x
+    """
+    return RMSNormFunction.apply(x, weight, eps)
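rmsnorm is now differentiable: the forward pass saves rstd, the backward kernel writes per-SM partial weight gradients into dw_partial, and the host reduces them. A minimal usage sketch that checks against the rmsnorm_bwd_ref shipped in this file; shapes and tolerances are ours, for illustration:

import torch
from quack.rmsnorm import rmsnorm, rmsnorm_bwd_ref

x = torch.randn(2048, 4096, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(4096, device="cuda", dtype=torch.float32, requires_grad=True)

out = rmsnorm(x, w, eps=1e-6)   # forward saves rstd for the backward kernel
dout = torch.randn_like(out)
out.backward(dout)              # runs RMSNormBackward, then sums dw_partial over SMs

rstd = (x.detach().float().pow(2).mean(dim=-1) + 1e-6).rsqrt()
dx_ref, dw_ref = rmsnorm_bwd_ref(x.detach(), w.detach(), dout, rstd)
torch.testing.assert_close(x.grad, dx_ref, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(w.grad, dw_ref, atol=1e-2, rtol=1e-2)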
quack/softmax.py
CHANGED
@@ -98,7 +98,10 @@ class Softmax(ReductionBase):
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
@@ -312,9 +315,11 @@ class SoftmaxBackward(ReductionBase):
         shape = mdY.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-
-
+        mdY, mY, mdX = [
+            utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX)
         ]
+        gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))

         smem = cutlass.utils.SmemAllocator()
         sdY = smem.allocate_tensor(
quack/utils.py
CHANGED
@@ -23,20 +23,6 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
 )


-@cute.jit
-def max_constexpr(
-    a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric]
-) -> cutlass.Constexpr[cute.Numeric]:
-    return a if a > b else b
-
-
-@cute.jit
-def min_constexpr(
-    a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric]
-) -> cutlass.Constexpr[cute.Numeric]:
-    return a if a < b else b
-
-
 @cute.jit
 def warp_reduce(
     val: cute.TensorSSA | cute.Numeric,
@@ -196,7 +182,7 @@ def row_reduce(
     val = warp_reduce(
         val,
         warp_op,
-        width=
+        width=min(threads_per_row, cute.arch.WARP_SIZE),
     )
     if cutlass.const_expr(hook_fn is not None):
         hook_fn()
@@ -226,7 +212,7 @@ def online_softmax_reduce(
     max_x = warp_reduce(
         x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
        cute.arch.fmax,
-        width=
+        width=min(threads_per_row, cute.arch.WARP_SIZE),
     )
     log2_e = math.log2(math.e)
     exp_x = exp2f(x * log2_e - (max_x * log2_e))
@@ -234,7 +220,7 @@ def online_softmax_reduce(
     sum_exp_x = warp_reduce(
         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
         operator.add,
-        width=
+        width=min(threads_per_row, cute.arch.WARP_SIZE),
     )
     if cutlass.const_expr(hook_fn is not None):
         hook_fn()
@@ -303,7 +289,6 @@ def online_softmax_reduce(
 @cute.jit
 def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
     """exp2f calculation for both vector and scalar.
-
     :param x: input value
     :type x: cute.TensorSSA or Float32
     :return: exp2 value
@@ -405,3 +390,19 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
         vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
     )
     return res0, res1
+
+
+@dsl_user_op
+def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+    flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
+    flat_stride = cute.flatten_to_tuple(tensor.stride)
+    offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
+    assert isinstance(tensor.iterator, cute.Pointer)
+    # HACK: we assume that applying the offset does not change the pointer alignment
+    new_ptr = cute.make_ptr(
+        tensor.element_type,
+        tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+        tensor.memspace,
+        assumed_align=tensor.iterator.max_alignment,
+    )
+    return cute.make_tensor(new_ptr, tensor.layout)
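The new domain_offset_i64 helper folds a CTA's row offset into the base pointer using 64-bit arithmetic, because an element offset such as row * stride no longer fits in a signed 32-bit integer once a tensor exceeds 2^31 elements. A small plain-Python illustration of the wraparound it avoids; the shapes below are ours:

M, N = 300_000, 16_384                      # ~4.9e9 elements, more than 2**31
row = 200_000
offset = row * N                            # 3_276_800_000: true element offset of this row
wrapped = (offset + 2**31) % 2**32 - 2**31  # what a signed 32-bit multiply would produce
byte_offset = offset * 2                    # times element width // 8, e.g. 2 bytes for bf16
print(offset, wrapped, byte_offset)         # 3276800000 -1018167296 6553600000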
quack_kernels-0.1.6.dist-info/RECORD
ADDED
@@ -0,0 +1,12 @@
+quack/__init__.py,sha256=EV_43VfcxQUsFu_Nfq0944ImloZ2T9X94-8T9IQM0I8,203
+quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
+quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+quack/reduction_base.py,sha256=fFuGXPR3lDq2yw_m86ujmkni6R51jzNAzy_r9R6C8tA,3563
+quack/rmsnorm.py,sha256=rbRP_O-EJYvrvxYPFwfy0yCbG7Qs_DgbzgacxTLvih4,24159
+quack/softmax.py,sha256=b-QEQiGjOEPidolE2--K21kwEaLJ-3wQoGIz_BsEcSI,16742
+quack/utils.py,sha256=laz_lqeggiIOY_xCs3c3VvCPP8bJmSKBfjFpFR1Al80,15946
+quack_kernels-0.1.6.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.1.6.dist-info/METADATA,sha256=NKhSoudW9lNdYryNiMkyjKXYl5n37cuFMbyLv3zr5js,289
+quack_kernels-0.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.1.6.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.1.6.dist-info/RECORD,,
quack_kernels-0.1.4.dist-info/RECORD
REMOVED
@@ -1,11 +0,0 @@
-quack/__init__.py,sha256=cFLxO6nA_faFqHf4N-Fy7G0j8ykuYPB1uOt9uoJ2dkQ,203
-quack/cross_entropy.py,sha256=HnF2OErEzb10SWxY6HoYE42lnvlw2DsWCks7mylPwnI,9511
-quack/reduction_base.py,sha256=Rsj9ZeSHcKAXGn1p7mY1vrrBqxevi4feLjY0JJhKnmY,3663
-quack/rmsnorm.py,sha256=TkOZsXJwcsoZMLnmEWQ-pEF0r-iiZhGrCNLSFCXfv6s,10676
-quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
-quack/utils.py,sha256=zVc9U-5No19trE585KqDdXx9chAruXPRIPMZdO7mkRg,15603
-quack_kernels-0.1.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.1.4.dist-info/METADATA,sha256=xl62C5WFgiUbnOICAzjldsljJ9j1Fb_JxZVksHLCI8I,289
-quack_kernels-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.1.4.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.1.4.dist-info/RECORD,,
{quack_kernels-0.1.4.dist-info → quack_kernels-0.1.6.dist-info}/WHEEL
File without changes

{quack_kernels-0.1.4.dist-info → quack_kernels-0.1.6.dist-info}/licenses/LICENSE
File without changes

{quack_kernels-0.1.4.dist-info → quack_kernels-0.1.6.dist-info}/top_level.txt
File without changes