quack-kernels 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/cross_entropy.py +304 -14
- quack/reduction_base.py +3 -6
- quack/rmsnorm.py +398 -20
- quack/softmax.py +25 -17
- quack/utils.py +17 -29
- {quack_kernels-0.1.3.dist-info → quack_kernels-0.1.5.dist-info}/METADATA +2 -2
- quack_kernels-0.1.5.dist-info/RECORD +11 -0
- quack_kernels-0.1.3.dist-info/RECORD +0 -11
- {quack_kernels-0.1.3.dist-info → quack_kernels-0.1.5.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.3.dist-info → quack_kernels-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.3.dist-info → quack_kernels-0.1.5.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/cross_entropy.py
CHANGED
@@ -1,3 +1,5 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
 import math
 import torch
 from typing import Optional, Type

@@ -77,7 +79,7 @@ class CrossEntropy(ReductionBase):
         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
             stream=stream,
         )

@@ -93,15 +95,16 @@ class CrossEntropy(ReductionBase):
         tiler_mn: cute.Shape,
     ):
         tidx, _, _ = cute.arch.thread_idx()
-        bidx,
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
 
         shape: cute.Shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-        gX, cX = [
-            cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
-            for mT in (mX, idX)
-        ]
+        gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(

@@ -131,7 +134,9 @@ class CrossEntropy(ReductionBase):
 
         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
         tXpX = (
-            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+            if cutlass.const_expr(not is_even_N)
+            else None
         )
         if row < shape[0]:
             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)

@@ -154,7 +159,7 @@ class CrossEntropy(ReductionBase):
                 cute.ReductionOp.MAX,
                 threads_per_row,
                 reduction_buffer[None, None, 0],
-                mbar_ptr + 0 if self.cluster_n > 1 else None,
+                mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
                 init_val=-cutlass.Float32.inf,
                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
             )

@@ -172,7 +177,7 @@ class CrossEntropy(ReductionBase):
                 cute.ReductionOp.ADD,
                 threads_per_row,
                 reduction_buffer[None, None, 1],
-                mbar_ptr + 1 if self.cluster_n > 1 else None,
+                mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
                 init_val=0.0,
             )
         else:

@@ -197,7 +202,7 @@ class CrossEntropy(ReductionBase):
             mLSE[row] = lse
 
 
-def
+def _cross_entropy(
     x: torch.Tensor,
     target: torch.Tensor,
     return_lse: bool = False,

@@ -238,15 +243,300 @@ def cross_entropy(
     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
     compile_key = (dtype, N, lse is not None)
-    if compile_key not in
+    if compile_key not in _cross_entropy.compile_cache:
         cross_entropy_op = CrossEntropy(dtype, N)
-
+        _cross_entropy.compile_cache[compile_key] = cute.compile(
             cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
         )
-
+    _cross_entropy.compile_cache[compile_key](
        x_tensor, target_tensor, loss_tensor, lse_tensor, stream
    )
    return loss if not return_lse else (loss, lse)
 
 
-
+_cross_entropy.compile_cache = {}
+
+
+class CrossEntropyBackward:
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+        self.dtype = dtype
+        self.N = N
+        self.vecsize = 128 // dtype.width
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _get_tv_layout(self):
+        N = self.N
+        vecsize = self.vecsize
+        num_threads = 128 if N <= 16384 else 256
+        threads_per_row = self._calculate_threads_per_row()
+        cols_per_block = num_threads // threads_per_row
+        num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+        tv_layout = cute.make_layout(
+            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+            stride=(
+                (vecsize * cols_per_block, 1),
+                (cols_per_block, cols_per_block * vecsize * threads_per_row),
+            ),
+        )
+        return tiler_mn, tv_layout
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mTarget: cute.Tensor,
+        mDLoss: cute.Tensor,
+        mdX: cute.Tensor,
+        mLSE: cute.Tensor,
+        stream: cuda.CUstream,
+    ):
+        assert mX.element_type == self.dtype
+        assert mdX.element_type == self.dtype
+
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+
+        mDLoss = cute.make_tensor(
+            mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+        mTarget = cute.make_tensor(
+            mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+        mLSE = cute.make_tensor(
+            mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+
+        smem_size = cute.size_in_bytes(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
+        )
+
+        self.kernel(
+            mX,
+            mTarget,
+            mDLoss,
+            mdX,
+            mLSE,
+            mX.shape,
+            tv_layout,
+            tiler_mn,
+        ).launch(
+            grid=[
+                cute.ceil_div(mX.shape[0], tiler_mn[0]),
+                cute.ceil_div(mX.shape[1], tiler_mn[1]),
+                1,
+            ],
+            block=[num_threads, 1, 1],
+            smem=smem_size,
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,  # (M, N)
+        mTarget: cute.Tensor,  # (M,)
+        mDLoss: cute.Tensor,  # (M,)
+        mdX: cute.Tensor,  # (M, N)
+        mLSE: cute.Tensor,  # (M,)
+        shape: cute.Shape,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, bidy, _ = cute.arch.block_idx()
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+
+        idX = cute.make_identity_tensor(shape)
+
+        gX, gdX, cX, gTarget, gDLoss, gLse = [
+            cute.local_tile(mT, tiler_mn, (bidx, bidy))
+            for mT in (mX, mdX, idX, mTarget, mDLoss, mLSE)
+        ]
+
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_X_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store_O = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_X_async = cute.make_tiled_copy(
+            copy_atom_load_X_async, tv_layout, tiler_mn
+        ).get_slice(tidx)
+        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+        #### Thread View
+        tXgX = thr_copy_X_async.partition_S(gX)
+        tXsX = thr_copy_X_async.partition_S(sX)
+
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+        tXcFull = thr_copy_X.partition_S(cX)  # improve
+
+        tXgO = thr_copy_O.partition_D(gdX)
+
+        # allocate fragments for gmem->rmem
+        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+        is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
+        row = tXcX[0][0]
+
+        tXpX = (
+            utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
+            if not is_even_N
+            else None
+        )
+
+        if row < shape[0]:
+            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+        cute.arch.cp_async_wait_group(0)
+        if cutlass.const_expr(not is_even_N):
+            utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+
+        label = cute.Int32.zero
+        dloss = cute.Float32.zero
+        lse = cute.Float32.zero
+        if row < shape[0]:
+            label = cute.Int32(mTarget[row])
+            dloss = cute.Float32(mDLoss[row])
+            lse = cute.Float32(mLSE[row])
+
+        log2_e = math.log2(math.e)
+        probs = utils.exp2f((x - lse) * log2_e)
+        prob_shifted = probs - 1.0
+
+        mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+        for i in cutlass.range_constexpr(cute.size(tXcFull)):
+            mask[i] = tXcFull[i][1] == label
+
+        mask = mask.load()
+        grad = cute.where(mask, prob_shifted, probs)
+        grad = grad * dloss
+
+        tXrO.store(grad.to(tXrO.element_type))
+        tOpO = (
+            utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+        )
+        if row < shape[0]:
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+def _cross_entropy_backward(
+    x: torch.Tensor,
+    target: torch.Tensor,
+    dloss: torch.Tensor,
+    lse: torch.Tensor,
+    inplace_backward: bool = False,
+) -> torch.Tensor:
+    """Cross entropy backward pass.
+    Args:
+        x: Input logits tensor of shape (M, N)
+        target: Target class indices tensor of shape (M,)
+        dloss: Upstream gradients tensor of shape (M,)
+        lse: Log-sum-exp values tensor of shape (M,)
+    Returns:
+        Input gradients tensor of shape (M, N)
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert target.dim() == 1, "Target must be 1D"
+    assert dloss.dim() == 1, "dloss must be 1D"
+    assert lse.dim() == 1, "lse must be 1D"
+    assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+    assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
+    assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
+    assert (
+        x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
+    ), "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+    assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+
+    M, N = x.shape
+    dx = torch.empty_like(x) if not inplace_backward else x
+    dtype = torch2cute_dtype_map[x.dtype]
+
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+    x_tensor = convert_from_dlpack(x)
+    dx_tensor = convert_from_dlpack(dx)
+    dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
+        mode=0
+    )
+    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    compile_key = (dtype, N)
+    if compile_key not in _cross_entropy_backward.compile_cache:
+        cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
+        _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
+            cross_entropy_backward_op,
+            x_tensor,
+            target_tensor,
+            dloss_tensor,
+            dx_tensor,
+            lse_tensor,
+            stream,
+        )
+    _cross_entropy_backward.compile_cache[compile_key](
+        x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
+    )
+    return dx
+
+
+_cross_entropy_backward.compile_cache = {}
+
+
+class CrossEntropyFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, target, inplace_backward=False):
+        loss, lse = _cross_entropy(x, target, return_lse=True)
+        ctx.save_for_backward(x, target, lse)
+        ctx.inplace_backward = inplace_backward
+        return loss
+
+    @staticmethod
+    def backward(ctx, dloss):
+        x, target, lse = ctx.saved_tensors
+        dx = _cross_entropy_backward(x, target, dloss, lse, inplace_backward=ctx.inplace_backward)
+        return dx, None, None
+
+
+def cross_entropy(
+    x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+) -> torch.Tensor:
+    """Cross entropy loss with automatic differentiation support.
+
+    Args:
+        x: Input logits tensor of shape (M, N)
+        target: Target class indices tensor of shape (M,)
+
+    Returns:
+        Cross entropy loss tensor of shape (M,)
+    """
+    return CrossEntropyFunction.apply(x, target, inplace_backward)
quack/reduction_base.py
CHANGED
@@ -6,8 +6,6 @@ from typing import Type, Tuple, Optional
 import cutlass
 import cutlass.cute as cute
 
-import quack.utils as utils
-
 
 torch2cute_dtype_map = {
     torch.float16: cutlass.Float16,

@@ -39,7 +37,6 @@ class ReductionBase:
         vecsize = copy_bits // self.dtype.width
         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
         num_threads = self._get_num_threads()
-        num_warps = num_threads // cute.arch.WARP_SIZE
         assert num_threads % cute.arch.WARP_SIZE == 0
 
         threads_per_row = self._calculate_threads_per_row()

@@ -64,7 +61,7 @@ class ReductionBase:
 
     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-        warps_per_row =
+        warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
         return cute.make_ordered_layout(
             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
             order=(1, 0, 2),

@@ -88,10 +85,10 @@
     def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
         if cutlass.const_expr(self.cluster_n > 1):
             if tidx < self.stage:
-                cute.arch.
+                cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
             cute.arch.mbarrier_init_fence()
             if tidx < self.stage:
-                cute.arch.
+                cute.arch.mbarrier_arrive_and_expect_tx(
                     mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
                 )
             # Cluster arrive after barrier init
quack/rmsnorm.py
CHANGED
@@ -9,7 +9,6 @@ import cuda.bindings.driver as cuda
 import cutlass
 import cutlass.cute as cute
 from cutlass.cute.runtime import from_dlpack
-
 import quack.utils as utils
 from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 

@@ -84,7 +83,7 @@ class RMSNorm(ReductionBase):
         self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
             stream=stream,
         )

@@ -103,7 +102,11 @@ class RMSNorm(ReductionBase):
         delay_w_load: cutlass.Constexpr = False,
     ):
         tidx, _, _ = cute.arch.thread_idx()
-        bidx,
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(

@@ -114,13 +117,10 @@ class RMSNorm(ReductionBase):
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-        gX, gO, cX = [
-
-            for mT in (mX, mO, idX)
-        ]
-        gW = cute.local_tile(mW, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+        gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
         gRstd = (
-            cute.local_tile(mRstd, tiler_mn, (bidx,
+            cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
             if cutlass.const_expr(mRstd is not None)
             else None
         )

@@ -167,7 +167,7 @@ class RMSNorm(ReductionBase):
         cute.arch.cp_async_commit_group()
 
         tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
-        if not delay_w_load:
+        if cutlass.const_expr(not delay_w_load):
             cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
 
         cute.arch.cp_async_wait_group(0)

@@ -192,12 +192,12 @@ class RMSNorm(ReductionBase):
             and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
         ):
             tXrRstd[0] = rstd
-        if delay_w_load:
+        if cutlass.const_expr(delay_w_load):
             cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
-        if reload_from == "smem":
+        if cutlass.const_expr(reload_from == "smem"):
             cute.autovec_copy(tXsX, tXrX)
             x = tXrX.load().to(cute.Float32)
-        elif reload_from == "gmem":
+        elif cutlass.const_expr(reload_from == "gmem"):
             cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
             x = tXrX.load().to(cute.Float32)
         x_hat = x * rstd

@@ -209,20 +209,18 @@ class RMSNorm(ReductionBase):
             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
 
 
-def
+def _rmsnorm_fwd(
     x: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
     return_rstd: bool = False,
 ) -> torch.Tensor:
     """RMSNorm forward pass.
-
     Args:
         x: Input tensor of shape (M, N)
         weight: Weight tensor of shape (N,)
         eps: Small value for numerical stability
         return_rstd: Whether to return the reciprocal standard deviation
-
     Returns:
         Normalized output tensor of same shape as x
         If return_rstd is True, also returns rstd tensor of shape (M,)

@@ -258,18 +256,18 @@ def rmsnorm(
     )
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
     compile_key = (dtype, N, rstd is not None)
-    if compile_key not in
+    if compile_key not in _rmsnorm_fwd.compile_cache:
         rmsnorm_op = RMSNorm(dtype, N)
-
+        _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
             rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
         )
-
+    _rmsnorm_fwd.compile_cache[compile_key](
        x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
    )
    return (out, rstd) if return_rstd else out
 
 
-
+_rmsnorm_fwd.compile_cache = {}
 
 
 def rmsnorm_ref(x, w, eps=1e-6):

@@ -282,3 +280,383 @@ def rmsnorm_ref(x, w, eps=1e-6):
 def rstd_ref(x, eps=1e-6):
     x_f32 = x.float()
     return 1.0 / torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1) + eps)
+
+
+def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
+    """Reference implementation for RMSNorm backward pass."""
+    x_f32 = x.float()
+    x_hat = x_f32 * rstd.unsqueeze(1)
+    wdy = dout * w
+    c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
+    dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)
+
+    # dL/dW
+    dw = (dout * x_hat).sum(dim=0)
+    return dx.to(x.dtype), dw.to(w.dtype)
+
+
+class RMSNormBackward(ReductionBase):
+    def __init__(self, dtype: cutlass.Numeric, N: int):
+        # 1 stage for computing mean of x_hat * wdy
+        super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _set_cluster_n(self):
+        N = self.N
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 32 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mDout: cute.Tensor,
+        mRstd: cute.Tensor,
+        mDx: cute.Tensor,
+        mDw: cute.Tensor,
+        sm_count: cutlass.Constexpr,
+        stream: cuda.CUstream,
+    ):
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+
+        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+
+        mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
+        mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+
+        num_blocks = (
+            sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+        )
+
+        self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+            grid=[num_blocks, self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mDout: cute.Tensor,
+        mRstd: cute.Tensor,
+        mDx: cute.Tensor,
+        mDw: cute.Tensor,
+        sm_count: cutlass.Constexpr,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, cluster_y, _ = cute.arch.block_idx()
+        gdim, _, _ = cute.arch.grid_dim()
+
+        shape = mX.shape
+        M, N = shape[0], shape[1]
+
+        idX = cute.make_identity_tensor(shape)
+
+        smem = cutlass.utils.SmemAllocator()
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+        )
+
+        copy_atom_load_W = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+        )
+
+        copy_atom_store_dX = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mDx.element_type, num_bits_per_copy=128
+        )
+
+        copy_atom_dw = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_dw = cute.make_tiled_copy(copy_atom_dw, tv_layout, tiler_mn).get_slice(tidx)
+        thr_store_dx = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
+
+        gW = cute.local_tile(mW, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+        tWgW = thr_copy_W.partition_S(gW)
+        tWrW = cute.make_fragment_like(tWgW)
+        tXrW = thr_copy_X.retile(tWrW)
+
+        gW_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+
+        tWpW = utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+        cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+        weight = tXrW.load().to(cute.Float32)
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+        tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+
+        gDw = cute.local_tile(mDw, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+        tDwgDw = thr_copy_dw.partition_D(gDw)
+        tDwrDw = cute.make_fragment_like(tDwgDw)
+        dw_accumulator = thr_copy_X.retile(tDwrDw)
+        dw_accumulator.fill(0.0)
+
+        M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+
+        jump = sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+
+        if cutlass.const_expr(self.cluster_n > 1):
+            cute.arch.cluster_arrive()
+            cute.arch.cluster_wait()
+
+        ## need to update range_dynamic since it will be deprecated soon
+        for row_offset in cutlass.range_dynamic(bidx, M_pad, jump):
+            gX = cute.local_tile(
+                mX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            gDout = cute.local_tile(
+                mDout, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            gRstd = cute.local_tile(
+                mRstd, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            gDx = cute.local_tile(
+                mDx, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            cX = cute.local_tile(
+                idX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+
+            tXgX = thr_copy_X.partition_S(gX)
+            thrDout = thr_copy_X.partition_S(gDout)
+            tXrRstd = thr_copy_W.partition_S(gRstd)
+            thrDx = thr_store_dx.partition_D(gDx)
+            tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+            tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
+
+            tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+
+            if tXcX[0][0] < shape[0]:
+                cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+                cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+
+            x = tXrX.load().to(cute.Float32)
+            dout = frgDout.load().to(cute.Float32)
+
+            rstd = tXrRstd[0]
+            x_hat = x * rstd
+            wdy = dout * weight
+
+            threads_per_row = tv_layout.shape[0][0]
+
+            row = tXcX[0][0]
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+
+            mean_xhat_wdy = (
+                utils.row_reduce(
+                    x_hat * wdy,
+                    cute.ReductionOp.ADD,
+                    threads_per_row,
+                    reduction_buffer[None, None, 0],
+                    mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                    init_val=0.0,
+                    hook_fn=cute.arch.cluster_wait
+                    if cutlass.const_expr(self.cluster_n > 1)
+                    else None,
+                )
+                / shape[1]
+            )
+
+            dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+            frgDx.store(dx.to(frgDout.element_type))
+
+            if row < M:
+                cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
+
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+
+            if row < M:
+                dw_row = dout * x_hat
+                current_dw = dw_accumulator.load().to(cute.Float32)
+                updated_dw = current_dw + dw_row
+                dw_accumulator.store(updated_dw.to(dw_accumulator.element_type))
+
+            """
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+            """
+            """
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+            """
+
+        cute.autovec_copy(dw_accumulator, tDwrDw)
+        cute.copy(copy_atom_dw, tDwrDw, tDwgDw, pred=tDwpDw)
+
+
+def _rmsnorm_backward(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    dout: torch.Tensor,
+    rstd: torch.Tensor,
+) -> (torch.Tensor, torch.Tensor):
+    """RMSNorm backward pass.
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,)
+        dout: Upstream gradients tensor of shape (M, N)
+        rstd: Reciprocal standard deviation tensor of shape (M,)
+    Returns:
+        Tuple of (dx, dw) where:
+        - dx: Input gradients tensor of same shape as x
+        - dw: Weight gradients tensor of same shape as weight
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert weight.dim() == 1, "Weight must be 1D"
+    assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+    assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert weight.dtype == torch.float32, "Weight must be float32"
+
+    M, N = x.shape
+    dx = torch.empty_like(x)
+
+    device = x.device
+
+    sm_count = torch.cuda.get_device_properties(device).multi_processor_count * 8
+    dw_partial = torch.zeros((sm_count, N), device=device, dtype=weight.dtype)
+
+    dtype = torch2cute_dtype_map[x.dtype]
+
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+
+    x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
+
+    weight_tensor = utils.convert_from_dlpack(
+        weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+    )
+
+    dw_partial_tensor = convert_from_dlpack(dw_partial)
+    rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+
+    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    compile_key = (dtype, N)
+    if compile_key not in _rmsnorm_backward.compile_cache:
+        rmsnorm_backward_op = RMSNormBackward(dtype, N)
+        _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
+            rmsnorm_backward_op,
+            x_tensor,
+            weight_tensor,
+            dout_tensor,
+            rstd_tensor,
+            dx_tensor,
+            dw_partial_tensor,
+            sm_count,
+            current_stream,
+        )
+
+    _rmsnorm_backward.compile_cache[compile_key](
+        x_tensor,
+        weight_tensor,
+        dout_tensor,
+        rstd_tensor,
+        dx_tensor,
+        dw_partial_tensor,
+        current_stream,
+    )
+
+    dw = dw_partial.sum(dim=0).to(weight.dtype)
+    return dx, dw
+
+
+_rmsnorm_backward.compile_cache = {}
+
+
+class RMSNormFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, eps):
+        out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
+        ctx.save_for_backward(x, weight, rstd)
+        ctx.eps = eps
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, weight, rstd = ctx.saved_tensors
+        dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
+        # dw is returned for weight gradient, None for eps gradient
+        return dx, dw, None
+
+
+def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """RMSNorm forward pass with automatic differentiation support.
+
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,)
+        eps: Small value for numerical stability
+
+    Returns:
+        Normalized output tensor of same shape as x
+    """
+    return RMSNormFunction.apply(x, weight, eps)
quack/softmax.py
CHANGED
@@ -75,7 +75,7 @@ class Softmax(ReductionBase):
         self.kernel(mX, mO, tv_layout, tiler_mn).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
             stream=stream,
         )

@@ -89,15 +89,16 @@ class Softmax(ReductionBase):
         tiler_mn: cute.Shape,
     ):
         tidx, _, _ = cute.arch.thread_idx()
-        bidx,
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
 
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-        gX, gO, cX = [
-            cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
-            for mT in (mX, mO, idX)
-        ]
+        gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(

@@ -129,7 +130,9 @@ class Softmax(ReductionBase):
 
         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
         tXpX = (
-            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+            if cutlass.const_expr(not is_even_N)
+            else None
         )
         if tXcX[0][0] < shape[0]:
             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)

@@ -148,7 +151,7 @@ class Softmax(ReductionBase):
             cute.ReductionOp.MAX,
             threads_per_row,
             reduction_buffer[None, None, 0],
-            mbar_ptr + 0 if self.cluster_n > 1 else None,
+            mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
             init_val=-cutlass.Float32.inf,
             hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
         )

@@ -159,7 +162,7 @@ class Softmax(ReductionBase):
             cute.ReductionOp.ADD,
             threads_per_row,
             reduction_buffer[None, None, 1],
-            mbar_ptr + 1 if self.cluster_n > 1 else None,
+            mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
             init_val=0.0,
         )
         else:

@@ -174,7 +177,9 @@ class Softmax(ReductionBase):
         y = exp_x * (1.0 / denom)
         tXrO.store(y.to(tXrO.element_type))
         tOpO = (
-            utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+            utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+            if cutlass.const_expr(not is_even_N)
+            else None
         )
         if tXcX[0][0] < shape[0]:
             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)

@@ -283,7 +288,7 @@ class SoftmaxBackward(ReductionBase):
         self.kernel(mdY, mY, mdX, tv_layout, tiler_mn).launch(
             grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
             stream=stream,
         )

@@ -298,14 +303,17 @@ class SoftmaxBackward(ReductionBase):
         tiler_mn: cute.Shape,
     ):
         tidx, _, _ = cute.arch.thread_idx()
-        bidx,
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
 
         shape = mdY.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
         gdY, gY, gdX, cX = [
-            cute.local_tile(mT, tiler_mn, (bidx,
-            for mT in (mdY, mY, mdX, idX)
+            cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
         ]
 
         smem = cutlass.utils.SmemAllocator()

@@ -344,7 +352,7 @@ class SoftmaxBackward(ReductionBase):
         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
         tdYpdY = (
             utils.predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
-            if not is_even_N
+            if cutlass.const_expr(not is_even_N)
             else None
         )
 

@@ -366,7 +374,7 @@ class SoftmaxBackward(ReductionBase):
             cute.ReductionOp.ADD,
             threads_per_row,
             reduction_buffer[None, None, 0],
-            mbar_ptr if self.cluster_n > 1 else None,
+            mbar_ptr if cutlass.const_expr(self.cluster_n > 1) else None,
             init_val=0.0,
             hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
         )

@@ -376,7 +384,7 @@ class SoftmaxBackward(ReductionBase):
         tdXrdX.store(dx.to(tdXrdX.element_type))
         tdXpdX = (
             utils.predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
-            if not is_even_N
+            if cutlass.const_expr(not is_even_N)
             else None
         )
         if tXcX[0][0] < shape[0]:
quack/utils.py
CHANGED
@@ -24,32 +24,19 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
 
 
 @cute.jit
-def max_constexpr(
-    a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric]
-) -> cutlass.Constexpr[cute.Numeric]:
-    return a if a > b else b
-
-
-@cute.jit
-def min_constexpr(
-    a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric]
-) -> cutlass.Constexpr[cute.Numeric]:
-    return a if a < b else b
-
-
 def warp_reduce(
     val: cute.TensorSSA | cute.Numeric,
     op: Callable,
     width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
 ) -> cute.TensorSSA | cute.Numeric:
-    if isinstance(val, cute.TensorSSA):
+    if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
         res = cute.make_fragment(val.shape, val.dtype)
         res.store(val)
-        for i in
+        for i in cutlass.range_constexpr(cute.size(val.shape)):
             res[i] = warp_reduce(res[i], op, width)
         return res.load()
     else:
-        for i in
+        for i in cutlass.range_constexpr(int(math.log2(width))):
             val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
     return val
 

@@ -111,15 +98,15 @@ def store_shared_remote(
     remote_mbar_ptr_i32 = set_block_rank(
         mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
     ).ir_value()
-    if isinstance(val, float):
+    if cutlass.const_expr(isinstance(val, float)):
         val = Float32(val)
     assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
-    suffix = "f32" if isinstance(val, Float32) else "s64"
+    suffix = "f32" if cutlass.const_expr(isinstance(val, Float32)) else "s64"
     llvm.inline_asm(
         None,
         [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
         f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
-        f"r,{'f' if isinstance(val, Float32) else 'l'},r",
+        f"r,{'f' if cutlass.const_expr(isinstance(val, Float32)) else 'l'},r",
         has_side_effects=True,
         is_align_stack=False,
         asm_dialect=llvm.AsmDialect.AD_ATT,

@@ -195,7 +182,7 @@ def row_reduce(
     val = warp_reduce(
         val,
         warp_op,
-        width=
+        width=min(threads_per_row, cute.arch.WARP_SIZE),
     )
     if cutlass.const_expr(hook_fn is not None):
         hook_fn()

@@ -225,7 +212,7 @@ def online_softmax_reduce(
     max_x = warp_reduce(
         x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
         cute.arch.fmax,
-        width=
+        width=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    log2_e = math.log2(math.e)
    exp_x = exp2f(x * log2_e - (max_x * log2_e))

@@ -233,7 +220,7 @@ def online_softmax_reduce(
     sum_exp_x = warp_reduce(
         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
         operator.add,
-        width=
+        width=min(threads_per_row, cute.arch.WARP_SIZE),
     )
     if cutlass.const_expr(hook_fn is not None):
         hook_fn()

@@ -299,18 +286,18 @@ def online_softmax_reduce(
     return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
 
 
+@cute.jit
 def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
     """exp2f calculation for both vector and scalar.
-
     :param x: input value
     :type x: cute.TensorSSA or Float32
     :return: exp2 value
     :rtype: cute.TensorSSA or Float32
     """
-    if isinstance(x, cute.TensorSSA):
+    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
         res = cute.make_fragment(x.shape, Float32)
         res.store(x)
-        for i in
+        for i in cutlass.range_constexpr(cute.size(x.shape)):
             res[i] = cute.arch.exp2(res[i])
         return res.load()
     else:

@@ -347,6 +334,7 @@ def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
     )
 
 
+@cute.jit
 def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
     tApA = cute.make_fragment(

@@ -356,8 +344,8 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
         ),
         cutlass.Boolean,
     )
-    for rest_v in
-        for rest_k in
+    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
+        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
             tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
     return tApA
 

@@ -373,8 +361,8 @@ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) ->
     """
     tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
     tXrX_fill.fill(fill_value)
-    for rest_v in
-    for rest_k in
+    for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
+        for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
             if not tXpX[rest_v, 0, rest_k]:
                 cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
 
{quack_kernels-0.1.3.dist-info → quack_kernels-0.1.5.dist-info}/METADATA
CHANGED

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.1.3
+Version: 0.1.5
 Requires-Python: >=3.9
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl==4.0.
+Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
 Requires-Dist: torch
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
quack_kernels-0.1.5.dist-info/RECORD

@@ -0,0 +1,11 @@
+quack/__init__.py,sha256=GPoImcynY5-OkMep5RhQhXrnZyxgqZG3RoHhsYQFSL4,203
+quack/cross_entropy.py,sha256=WkngPY8uk4RCjCFtHtB7h9GF_8xt4NnyvDzvw73gIL4,19320
+quack/reduction_base.py,sha256=fFuGXPR3lDq2yw_m86ujmkni6R51jzNAzy_r9R6C8tA,3563
+quack/rmsnorm.py,sha256=N9NavrR85ws4cZgkfpeRLjYkVSq2yfyzJQWvfKf98pY,23935
+quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
+quack/utils.py,sha256=6EyWgf0z3wcbhGUivHmWB8hVBnEzMyOhmAuZ2Te82k0,15226
+quack_kernels-0.1.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.1.5.dist-info/METADATA,sha256=WI-2CP1mRH05V9Fjdx7HsErNOkrc6fUhheoH4ynlo-U,289
+quack_kernels-0.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.1.5.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.1.5.dist-info/RECORD,,
quack_kernels-0.1.3.dist-info/RECORD

@@ -1,11 +0,0 @@
-quack/__init__.py,sha256=aUR7drzgaqmbzw9H_eoFselMUVQVF3BHc9VOzZg5d-Q,203
-quack/cross_entropy.py,sha256=_Xlyifd_YS8LaYxYlZEsuBfsi8zTH4At3i9DDggGCf8,9319
-quack/reduction_base.py,sha256=nrRsXwTpLVQkPp2Gr_FgHRPnifqkMHRodve5ciHzx58,3667
-quack/rmsnorm.py,sha256=YqGTTKHHXYzw3xnnjBRfaN9TDlhG8D_fSI9CHKAU40A,10548
-quack/softmax.py,sha256=mWaUfaY6PBtO1ioYxXxS-yodQmcBNGasWVMUg9G066Y,15938
-quack/utils.py,sha256=1-HMcFTEvGdAtqC3ucQGZ3DLa_PoJQsqwYlKd9bcXO8,15347
-quack_kernels-0.1.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.1.3.dist-info/METADATA,sha256=DDuEKHLjFx9dFTQV5YtXsnKVFZVoueO7NwhcwOtpw6g,284
-quack_kernels-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.1.3.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.1.3.dist-info/RECORD,,
{quack_kernels-0.1.3.dist-info → quack_kernels-0.1.5.dist-info}/WHEEL
File without changes

{quack_kernels-0.1.3.dist-info → quack_kernels-0.1.5.dist-info}/licenses/LICENSE
File without changes

{quack_kernels-0.1.3.dist-info → quack_kernels-0.1.5.dist-info}/top_level.txt
File without changes