quack-kernels 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.4"
+ __version__ = "0.1.5"
 
  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
quack/cross_entropy.py CHANGED
@@ -1,3 +1,5 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
  import math
  import torch
  from typing import Optional, Type
@@ -200,7 +202,7 @@ class CrossEntropy(ReductionBase):
  mLSE[row] = lse
 
 
- def cross_entropy(
+ def _cross_entropy(
  x: torch.Tensor,
  target: torch.Tensor,
  return_lse: bool = False,
@@ -241,15 +243,300 @@ def cross_entropy(
  stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
  compile_key = (dtype, N, lse is not None)
- if compile_key not in cross_entropy.compile_cache:
+ if compile_key not in _cross_entropy.compile_cache:
  cross_entropy_op = CrossEntropy(dtype, N)
- cross_entropy.compile_cache[compile_key] = cute.compile(
+ _cross_entropy.compile_cache[compile_key] = cute.compile(
  cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
  )
- cross_entropy.compile_cache[compile_key](
+ _cross_entropy.compile_cache[compile_key](
  x_tensor, target_tensor, loss_tensor, lse_tensor, stream
  )
  return loss if not return_lse else (loss, lse)
 
 
- cross_entropy.compile_cache = {}
+ _cross_entropy.compile_cache = {}
+
+
+ class CrossEntropyBackward:
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ self.dtype = dtype
+ self.N = N
+ self.vecsize = 128 // dtype.width
+
+ def _calculate_threads_per_row(self):
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ )
+ )
+
+ def _get_tv_layout(self):
+ N = self.N
+ vecsize = self.vecsize
+ num_threads = 128 if N <= 16384 else 256
+ threads_per_row = self._calculate_threads_per_row()
+ cols_per_block = num_threads // threads_per_row
+ num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+ tv_layout = cute.make_layout(
+ ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+ stride=(
+ (vecsize * cols_per_block, 1),
+ (cols_per_block, cols_per_block * vecsize * threads_per_row),
+ ),
+ )
+ return tiler_mn, tv_layout
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mTarget: cute.Tensor,
+ mDLoss: cute.Tensor,
+ mdX: cute.Tensor,
+ mLSE: cute.Tensor,
+ stream: cuda.CUstream,
+ ):
+ assert mX.element_type == self.dtype
+ assert mdX.element_type == self.dtype
+
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = cute.size(tv_layout, mode=[0])
+
+ mDLoss = cute.make_tensor(
+ mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+ )
+ mTarget = cute.make_tensor(
+ mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+ )
+ mLSE = cute.make_tensor(
+ mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+ )
+
+ smem_size = cute.size_in_bytes(
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
+ )
+
+ self.kernel(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ mX.shape,
+ tv_layout,
+ tiler_mn,
+ ).launch(
+ grid=[
+ cute.ceil_div(mX.shape[0], tiler_mn[0]),
+ cute.ceil_div(mX.shape[1], tiler_mn[1]),
+ 1,
+ ],
+ block=[num_threads, 1, 1],
+ smem=smem_size,
+ stream=stream,
+ )
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ mLSE: cute.Tensor, # (M,)
+ shape: cute.Shape,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, bidy, _ = cute.arch.block_idx()
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ )
+
+ idX = cute.make_identity_tensor(shape)
+
+ gX, gdX, cX, gTarget, gDLoss, gLse = [
+ cute.local_tile(mT, tiler_mn, (bidx, bidy))
+ for mT in (mX, mdX, idX, mTarget, mDLoss, mLSE)
+ ]
+
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+ )
+ copy_atom_load_X_async = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+ )
+ copy_atom_store_O = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+ )
+
+ thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_X_async = cute.make_tiled_copy(
+ copy_atom_load_X_async, tv_layout, tiler_mn
+ ).get_slice(tidx)
+ thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+ #### Thread View
+ tXgX = thr_copy_X_async.partition_S(gX)
+ tXsX = thr_copy_X_async.partition_S(sX)
+
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+ tXcFull = thr_copy_X.partition_S(cX) # improve
+
+ tXgO = thr_copy_O.partition_D(gdX)
+
+ # allocate fragments for gmem->rmem
+ tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+ is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
+ row = tXcX[0][0]
+
+ tXpX = (
+ utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
+ if not is_even_N
+ else None
+ )
+
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+ if cutlass.const_expr(not is_even_N):
+ utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(cute.Float32)
+
+ label = cute.Int32.zero
+ dloss = cute.Float32.zero
+ lse = cute.Float32.zero
+ if row < shape[0]:
+ label = cute.Int32(mTarget[row])
+ dloss = cute.Float32(mDLoss[row])
+ lse = cute.Float32(mLSE[row])
+
+ log2_e = math.log2(math.e)
+ probs = utils.exp2f((x - lse) * log2_e)
+ prob_shifted = probs - 1.0
+
+ mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+ for i in cutlass.range_constexpr(cute.size(tXcFull)):
+ mask[i] = tXcFull[i][1] == label
+
+ mask = mask.load()
+ grad = cute.where(mask, prob_shifted, probs)
+ grad = grad * dloss
+
+ tXrO.store(grad.to(tXrO.element_type))
+ tOpO = (
+ utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+ )
+ if row < shape[0]:
+ cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+ def _cross_entropy_backward(
+ x: torch.Tensor,
+ target: torch.Tensor,
+ dloss: torch.Tensor,
+ lse: torch.Tensor,
+ inplace_backward: bool = False,
+ ) -> torch.Tensor:
+ """Cross entropy backward pass.
+ Args:
+ x: Input logits tensor of shape (M, N)
+ target: Target class indices tensor of shape (M,)
+ dloss: Upstream gradients tensor of shape (M,)
+ lse: Log-sum-exp values tensor of shape (M,)
+ Returns:
+ Input gradients tensor of shape (M, N)
+ """
+ assert x.dim() == 2, "Input must be 2D"
+ assert target.dim() == 1, "Target must be 1D"
+ assert dloss.dim() == 1, "dloss must be 1D"
+ assert lse.dim() == 1, "lse must be 1D"
+ assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+ assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
+ assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
+ assert (
+ x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
+ ), "Tensors must be on CUDA device"
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+ assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+
+ M, N = x.shape
+ dx = torch.empty_like(x) if not inplace_backward else x
+ dtype = torch2cute_dtype_map[x.dtype]
+
+ convert_from_dlpack = lambda tensor: (
+ from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0, stride_order=(0, 1)
+ )
+ )
+ x_tensor = convert_from_dlpack(x)
+ dx_tensor = convert_from_dlpack(dx)
+ dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
+ mode=0
+ )
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype, N)
+ if compile_key not in _cross_entropy_backward.compile_cache:
+ cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
+ _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
+ cross_entropy_backward_op,
+ x_tensor,
+ target_tensor,
+ dloss_tensor,
+ dx_tensor,
+ lse_tensor,
+ stream,
+ )
+ _cross_entropy_backward.compile_cache[compile_key](
+ x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
+ )
+ return dx
+
+
+ _cross_entropy_backward.compile_cache = {}
+
+
+ class CrossEntropyFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, target, inplace_backward=False):
+ loss, lse = _cross_entropy(x, target, return_lse=True)
+ ctx.save_for_backward(x, target, lse)
+ ctx.inplace_backward = inplace_backward
+ return loss
+
+ @staticmethod
+ def backward(ctx, dloss):
+ x, target, lse = ctx.saved_tensors
+ dx = _cross_entropy_backward(x, target, dloss, lse, inplace_backward=ctx.inplace_backward)
+ return dx, None, None
+
+
+ def cross_entropy(
+ x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+ ) -> torch.Tensor:
+ """Cross entropy loss with automatic differentiation support.
+
+ Args:
+ x: Input logits tensor of shape (M, N)
+ target: Target class indices tensor of shape (M,)
+
+ Returns:
+ Cross entropy loss tensor of shape (M,)
+ """
+ return CrossEntropyFunction.apply(x, target, inplace_backward)
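
The 0.1.5 release wraps the forward kernel in `CrossEntropyFunction`, so `cross_entropy` now differentiates through PyTorch autograd: the forward pass saves the per-row log-sum-exp, and `CrossEntropyBackward` reconstructs softmax probabilities from it and scales `softmax(x) - one_hot(target)` by the upstream gradient. Below is a minimal usage sketch checked against PyTorch's own cross entropy; the sizes, dtype, and tolerances are illustrative assumptions, not values taken from the package's tests.

```python
import torch
from quack.cross_entropy import cross_entropy  # autograd-enabled wrapper added in 0.1.5

M, N = 8, 4096  # assumed sizes; N must satisfy the kernel's vectorization constraints
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, N, (M,), device="cuda")

loss = cross_entropy(x, target)   # per-row losses, shape (M,), via CrossEntropyFunction.apply
loss.sum().backward()             # backward launches the CrossEntropyBackward kernel

# Reference gradient: dL/dx = dloss[:, None] * (softmax(x) - one_hot(target))
x_ref = x.detach().float().requires_grad_()
torch.nn.functional.cross_entropy(x_ref, target, reduction="sum").backward()
torch.testing.assert_close(x.grad.float(), x_ref.grad, atol=1e-2, rtol=1e-2)
```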
quack/reduction_base.py CHANGED
@@ -6,8 +6,6 @@ from typing import Type, Tuple, Optional
  import cutlass
  import cutlass.cute as cute
 
- import quack.utils as utils
-
 
  torch2cute_dtype_map = {
  torch.float16: cutlass.Float16,
@@ -39,7 +37,6 @@ class ReductionBase:
  vecsize = copy_bits // self.dtype.width
  assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
  num_threads = self._get_num_threads()
- num_warps = num_threads // cute.arch.WARP_SIZE
  assert num_threads % cute.arch.WARP_SIZE == 0
 
  threads_per_row = self._calculate_threads_per_row()
@@ -64,7 +61,7 @@ class ReductionBase:
 
  def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
- warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+ warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
  return cute.make_ordered_layout(
  (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
  order=(1, 0, 2),
quack/rmsnorm.py CHANGED
@@ -9,7 +9,6 @@ import cuda.bindings.driver as cuda
  import cutlass
  import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack
-
  import quack.utils as utils
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
@@ -210,20 +209,18 @@ class RMSNorm(ReductionBase):
  cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
 
 
- def rmsnorm(
+ def _rmsnorm_fwd(
  x: torch.Tensor,
  weight: torch.Tensor,
  eps: float = 1e-6,
  return_rstd: bool = False,
  ) -> torch.Tensor:
  """RMSNorm forward pass.
-
  Args:
  x: Input tensor of shape (M, N)
  weight: Weight tensor of shape (N,)
  eps: Small value for numerical stability
  return_rstd: Whether to return the reciprocal standard deviation
-
  Returns:
  Normalized output tensor of same shape as x
  If return_rstd is True, also returns rstd tensor of shape (M,)
@@ -259,18 +256,18 @@ def rmsnorm(
  )
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
  compile_key = (dtype, N, rstd is not None)
- if compile_key not in rmsnorm.compile_cache:
+ if compile_key not in _rmsnorm_fwd.compile_cache:
  rmsnorm_op = RMSNorm(dtype, N)
- rmsnorm.compile_cache[compile_key] = cute.compile(
+ _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
  rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
  )
- rmsnorm.compile_cache[compile_key](
+ _rmsnorm_fwd.compile_cache[compile_key](
  x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
  )
  return (out, rstd) if return_rstd else out
 
 
- rmsnorm.compile_cache = {}
+ _rmsnorm_fwd.compile_cache = {}
 
 
  def rmsnorm_ref(x, w, eps=1e-6):
@@ -283,3 +280,383 @@ def rmsnorm_ref(x, w, eps=1e-6):
  def rstd_ref(x, eps=1e-6):
  x_f32 = x.float()
  return 1.0 / torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1) + eps)
+
+
+ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
+ """Reference implementation for RMSNorm backward pass."""
+ x_f32 = x.float()
+ x_hat = x_f32 * rstd.unsqueeze(1)
+ wdy = dout * w
+ c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
+ dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)
+
+ # dL/dW
+ dw = (dout * x_hat).sum(dim=0)
+ return dx.to(x.dtype), dw.to(w.dtype)
+
+
+ class RMSNormBackward(ReductionBase):
+ def __init__(self, dtype: cutlass.Numeric, N: int):
+ # 1 stage for computing mean of x_hat * wdy
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+ def _calculate_threads_per_row(self):
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ )
+ )
+
+ def _set_cluster_n(self):
+ N = self.N
+ if cutlass.const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else: # fp32
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mDout: cute.Tensor,
+ mRstd: cute.Tensor,
+ mDx: cute.Tensor,
+ mDw: cute.Tensor,
+ sm_count: cutlass.Constexpr,
+ stream: cuda.CUstream,
+ ):
+ self._set_cluster_n()
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = cute.size(tv_layout, mode=[0])
+ num_warps = num_threads // cute.arch.WARP_SIZE
+
+ mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+
+ mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
+ mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+
+ num_blocks = (
+ sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+ )
+
+ self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+ grid=[num_blocks, self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mDout: cute.Tensor,
+ mRstd: cute.Tensor,
+ mDx: cute.Tensor,
+ mDw: cute.Tensor,
+ sm_count: cutlass.Constexpr,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, cluster_y, _ = cute.arch.block_idx()
+ gdim, _, _ = cute.arch.grid_dim()
+
+ shape = mX.shape
+ M, N = shape[0], shape[1]
+
+ idX = cute.make_identity_tensor(shape)
+
+ smem = cutlass.utils.SmemAllocator()
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+ )
+
+ copy_atom_load_W = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+ )
+
+ copy_atom_store_dX = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mDx.element_type, num_bits_per_copy=128
+ )
+
+ copy_atom_dw = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+ )
+
+ thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_dw = cute.make_tiled_copy(copy_atom_dw, tv_layout, tiler_mn).get_slice(tidx)
+ thr_store_dx = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
+
+ gW = cute.local_tile(mW, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+ tWgW = thr_copy_W.partition_S(gW)
+ tWrW = cute.make_fragment_like(tWgW)
+ tXrW = thr_copy_X.retile(tWrW)
+
+ gW_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+
+ tWpW = utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+ cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+ weight = tXrW.load().to(cute.Float32)
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+ tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+
+ gDw = cute.local_tile(mDw, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+ tDwgDw = thr_copy_dw.partition_D(gDw)
+ tDwrDw = cute.make_fragment_like(tDwgDw)
+ dw_accumulator = thr_copy_X.retile(tDwrDw)
+ dw_accumulator.fill(0.0)
+
+ M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+
+ jump = sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_arrive()
+ cute.arch.cluster_wait()
+
+ ## need to update range_dynamic since it will be deprecated soon
+ for row_offset in cutlass.range_dynamic(bidx, M_pad, jump):
+ gX = cute.local_tile(
+ mX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ )
+ gDout = cute.local_tile(
+ mDout, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ )
+ gRstd = cute.local_tile(
+ mRstd, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ )
+ gDx = cute.local_tile(
+ mDx, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ )
+ cX = cute.local_tile(
+ idX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ )
+
+ tXgX = thr_copy_X.partition_S(gX)
+ thrDout = thr_copy_X.partition_S(gDout)
+ tXrRstd = thr_copy_W.partition_S(gRstd)
+ thrDx = thr_store_dx.partition_D(gDx)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+ tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
+
+ tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+
+ x = tXrX.load().to(cute.Float32)
+ dout = frgDout.load().to(cute.Float32)
+
+ rstd = tXrRstd[0]
+ x_hat = x * rstd
+ wdy = dout * weight
+
+ threads_per_row = tv_layout.shape[0][0]
+
+ row = tXcX[0][0]
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_arrive()
+ cute.arch.cluster_wait()
+ else:
+ cute.arch.barrier()
+
+ mean_xhat_wdy = (
+ utils.row_reduce(
+ x_hat * wdy,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ hook_fn=cute.arch.cluster_wait
+ if cutlass.const_expr(self.cluster_n > 1)
+ else None,
+ )
+ / shape[1]
+ )
+
+ dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+ frgDx.store(dx.to(frgDout.element_type))
+
+ if row < M:
+ cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
+
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_arrive()
+ cute.arch.cluster_wait()
+ else:
+ cute.arch.barrier()
+
+ if row < M:
+ dw_row = dout * x_hat
+ current_dw = dw_accumulator.load().to(cute.Float32)
+ updated_dw = current_dw + dw_row
+ dw_accumulator.store(updated_dw.to(dw_accumulator.element_type))
+
+ """
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_arrive()
+ cute.arch.cluster_wait()
+ else:
+ cute.arch.barrier()
+ """
+ """
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_arrive()
+ cute.arch.cluster_wait()
+ else:
+ cute.arch.barrier()
+ """
+
+ cute.autovec_copy(dw_accumulator, tDwrDw)
+ cute.copy(copy_atom_dw, tDwrDw, tDwgDw, pred=tDwpDw)
+
+
+ def _rmsnorm_backward(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ dout: torch.Tensor,
+ rstd: torch.Tensor,
+ ) -> (torch.Tensor, torch.Tensor):
+ """RMSNorm backward pass.
+ Args:
+ x: Input tensor of shape (M, N)
+ weight: Weight tensor of shape (N,)
+ dout: Upstream gradients tensor of shape (M, N)
+ rstd: Reciprocal standard deviation tensor of shape (M,)
+ Returns:
+ Tuple of (dx, dw) where:
+ - dx: Input gradients tensor of same shape as x
+ - dw: Weight gradients tensor of same shape as weight
+ """
+ assert x.dim() == 2, "Input must be 2D"
+ assert weight.dim() == 1, "Weight must be 1D"
+ assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+ assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+ assert weight.dtype == torch.float32, "Weight must be float32"
+
+ M, N = x.shape
+ dx = torch.empty_like(x)
+
+ device = x.device
+
+ sm_count = torch.cuda.get_device_properties(device).multi_processor_count * 8
+ dw_partial = torch.zeros((sm_count, N), device=device, dtype=weight.dtype)
+
+ dtype = torch2cute_dtype_map[x.dtype]
+
+ convert_from_dlpack = lambda tensor: (
+ from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0, stride_order=(0, 1)
+ )
+ )
+
+ x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
+
+ weight_tensor = utils.convert_from_dlpack(
+ weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+ )
+
+ dw_partial_tensor = convert_from_dlpack(dw_partial)
+ rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype, N)
+ if compile_key not in _rmsnorm_backward.compile_cache:
+ rmsnorm_backward_op = RMSNormBackward(dtype, N)
+ _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
+ rmsnorm_backward_op,
+ x_tensor,
+ weight_tensor,
+ dout_tensor,
+ rstd_tensor,
+ dx_tensor,
+ dw_partial_tensor,
+ sm_count,
+ current_stream,
+ )
+
+ _rmsnorm_backward.compile_cache[compile_key](
+ x_tensor,
+ weight_tensor,
+ dout_tensor,
+ rstd_tensor,
+ dx_tensor,
+ dw_partial_tensor,
+ current_stream,
+ )
+
+ dw = dw_partial.sum(dim=0).to(weight.dtype)
+ return dx, dw
+
+
+ _rmsnorm_backward.compile_cache = {}
+
+
+ class RMSNormFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, weight, eps):
+ out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
+ ctx.save_for_backward(x, weight, rstd)
+ ctx.eps = eps
+ return out
+
+ @staticmethod
+ def backward(ctx, dout):
+ x, weight, rstd = ctx.saved_tensors
+ dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
+ # dw is returned for weight gradient, None for eps gradient
+ return dx, dw, None
+
+
+ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+ """RMSNorm forward pass with automatic differentiation support.
+
+ Args:
+ x: Input tensor of shape (M, N)
+ weight: Weight tensor of shape (N,)
+ eps: Small value for numerical stability
+
+ Returns:
+ Normalized output tensor of same shape as x
+ """
+ return RMSNormFunction.apply(x, weight, eps)
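
Likewise, `rmsnorm` is now routed through `RMSNormFunction`: the forward pass keeps the per-row `rstd`, the backward kernel computes `dx = (dout * w - x_hat * mean(x_hat * dout * w)) * rstd`, and the per-SM partial weight gradients in `dw_partial` are summed on the host. Below is a minimal usage sketch checked against the `rstd_ref` and `rmsnorm_bwd_ref` helpers defined in this module; the sizes and tolerances are illustrative assumptions.

```python
import torch
from quack.rmsnorm import rmsnorm, rmsnorm_bwd_ref, rstd_ref

M, N = 256, 8192  # assumed sizes
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(N, device="cuda", dtype=torch.float32, requires_grad=True)  # backward asserts fp32 weight

out = rmsnorm(x, w)    # forward via RMSNormFunction, saves rstd for the backward kernel
dout = torch.randn_like(out)
out.backward(dout)     # runs RMSNormBackward, then reduces dw_partial over SMs

# compare against the pure-PyTorch references shipped in this module
dx_ref, dw_ref = rmsnorm_bwd_ref(x.detach(), w.detach(), dout.float(), rstd_ref(x.detach()))
torch.testing.assert_close(x.grad.float(), dx_ref.float(), atol=1e-2, rtol=1e-2)
torch.testing.assert_close(w.grad.float(), dw_ref.float(), atol=1e-2, rtol=1e-2)
```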
quack/utils.py CHANGED
@@ -23,20 +23,6 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
  )
 
 
- @cute.jit
- def max_constexpr(
- a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric]
- ) -> cutlass.Constexpr[cute.Numeric]:
- return a if a > b else b
-
-
- @cute.jit
- def min_constexpr(
- a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric]
- ) -> cutlass.Constexpr[cute.Numeric]:
- return a if a < b else b
-
-
  @cute.jit
  def warp_reduce(
  val: cute.TensorSSA | cute.Numeric,
@@ -196,7 +182,7 @@ def row_reduce(
  val = warp_reduce(
  val,
  warp_op,
- width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
  )
  if cutlass.const_expr(hook_fn is not None):
  hook_fn()
@@ -226,7 +212,7 @@ def online_softmax_reduce(
  max_x = warp_reduce(
  x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
  cute.arch.fmax,
- width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
  )
  log2_e = math.log2(math.e)
  exp_x = exp2f(x * log2_e - (max_x * log2_e))
@@ -234,7 +220,7 @@ def online_softmax_reduce(
  sum_exp_x = warp_reduce(
  exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
  operator.add,
- width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
  )
  if cutlass.const_expr(hook_fn is not None):
  hook_fn()
@@ -303,7 +289,6 @@ def online_softmax_reduce(
  @cute.jit
  def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
  """exp2f calculation for both vector and scalar.
-
  :param x: input value
  :type x: cute.TensorSSA or Float32
  :return: exp2 value
quack_kernels-0.1.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.4
+ Version: 0.1.5
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
quack_kernels-0.1.5.dist-info/RECORD ADDED
@@ -0,0 +1,11 @@
+ quack/__init__.py,sha256=GPoImcynY5-OkMep5RhQhXrnZyxgqZG3RoHhsYQFSL4,203
+ quack/cross_entropy.py,sha256=WkngPY8uk4RCjCFtHtB7h9GF_8xt4NnyvDzvw73gIL4,19320
+ quack/reduction_base.py,sha256=fFuGXPR3lDq2yw_m86ujmkni6R51jzNAzy_r9R6C8tA,3563
+ quack/rmsnorm.py,sha256=N9NavrR85ws4cZgkfpeRLjYkVSq2yfyzJQWvfKf98pY,23935
+ quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
+ quack/utils.py,sha256=6EyWgf0z3wcbhGUivHmWB8hVBnEzMyOhmAuZ2Te82k0,15226
+ quack_kernels-0.1.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.1.5.dist-info/METADATA,sha256=WI-2CP1mRH05V9Fjdx7HsErNOkrc6fUhheoH4ynlo-U,289
+ quack_kernels-0.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.1.5.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.1.5.dist-info/RECORD,,
quack_kernels-0.1.4.dist-info/RECORD DELETED
@@ -1,11 +0,0 @@
- quack/__init__.py,sha256=cFLxO6nA_faFqHf4N-Fy7G0j8ykuYPB1uOt9uoJ2dkQ,203
- quack/cross_entropy.py,sha256=HnF2OErEzb10SWxY6HoYE42lnvlw2DsWCks7mylPwnI,9511
- quack/reduction_base.py,sha256=Rsj9ZeSHcKAXGn1p7mY1vrrBqxevi4feLjY0JJhKnmY,3663
- quack/rmsnorm.py,sha256=TkOZsXJwcsoZMLnmEWQ-pEF0r-iiZhGrCNLSFCXfv6s,10676
- quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
- quack/utils.py,sha256=zVc9U-5No19trE585KqDdXx9chAruXPRIPMZdO7mkRg,15603
- quack_kernels-0.1.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.1.4.dist-info/METADATA,sha256=xl62C5WFgiUbnOICAzjldsljJ9j1Fb_JxZVksHLCI8I,289
- quack_kernels-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.1.4.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.1.4.dist-info/RECORD,,