quack-kernels 0.1.6-py3-none-any.whl → 0.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.6"
+ __version__ = "0.1.7"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
quack/reduction_base.py CHANGED
@@ -68,7 +68,7 @@ class ReductionBase:
  )

  def _allocate_reduction_buffer_and_mbar(
- self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+ self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False
  ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
  reduction_buffer = smem.allocate_tensor(
  self.reduction_dtype,
@@ -76,20 +76,28 @@ class ReductionBase:
  byte_alignment=4,
  )
  if cutlass.const_expr(self.cluster_n > 1):
- mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+ mbar_ptr = smem.allocate_array(
+ cutlass.Int64, num_elems=self.stage if not is_persistent else self.stage * 2
+ )
  else:
  mbar_ptr = None
  return reduction_buffer, mbar_ptr

  @cute.jit
- def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
+ def _initialize_cluster(
+ self,
+ tidx: cutlass.Int32,
+ mbar_ptr: cute.Pointer,
+ num_warps: int,
+ is_persistent: bool = False,
+ ):
  if cutlass.const_expr(self.cluster_n > 1):
- if tidx < self.stage:
+ if tidx < self.stage: # Initialize full barrier
  cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+ if cutlass.const_expr(is_persistent): # Initialize empty barrier
+ cute.arch.mbarrier_init(
+ mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
+ )
  cute.arch.mbarrier_init_fence()
- if tidx < self.stage:
- cute.arch.mbarrier_arrive_and_expect_tx(
- mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
- )
  # Cluster arrive after barrier init
  cute.arch.cluster_arrive_relaxed()
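Editor's note: with is_persistent=True the allocation above holds two mbarriers per stage, a "full" barrier (data for that stage has arrived) and an "empty" barrier (all consumers are done with that stage), and the persistent RMSNormBackward kernel below tracks a stage index plus two phase bits. The snippet is a minimal, library-free sketch (plain Python, not the cute DSL) of how that bookkeeping is expected to evolve; the toggling pattern mirrors the stage/phase logic added in quack/rmsnorm.py.

stage, consumer_phase, producer_phase = 0, 0, 1
for it in range(6):
    # consumer waits on full[stage] with consumer_phase before reading the reduction buffer;
    # producer waits on empty[stage] with producer_phase before overwriting it
    print(f"iter {it}: stage={stage}, consumer_phase={consumer_phase}, producer_phase={producer_phase}")
    stage ^= 1
    if stage == 0:
        # both stages have been used once, so the expected barrier phases flip
        consumer_phase ^= 1
        producer_phase ^= 1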
quack/rmsnorm.py CHANGED
@@ -1,6 +1,5 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

-
  import torch
  from typing import Optional

@@ -157,6 +156,7 @@ class RMSNorm(ReductionBase):

  # allocate fragments for gmem->rmem
  tWrW = cute.make_fragment_like(tWgW)
+ tWrW.fill(0.0)
  tXrW = thr_copy_X.retile(tWrW)
  tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]

@@ -300,8 +300,14 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):

  class RMSNormBackward(ReductionBase):
  def __init__(self, dtype: cutlass.Numeric, N: int):
- # 1 stage for computing mean of x_hat * wdy
- super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+ # 2 stages for double buffering when computing mean of x_hat * wdy
+ super().__init__(dtype, N, stage=2, reduction_dtype=cutlass.Float32)
+ if self.N > 128 * 1024 and self.dtype.width >= 32:
+ # Not enough smem
+ raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
+
+ def _get_num_threads(self):
+ return 128 if self.N <= 4096 else 256

  def _calculate_threads_per_row(self):
  N = self.N
@@ -311,44 +317,38 @@ class RMSNormBackward(ReductionBase):
  else (
  16
  if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
  )
  )

  def _set_cluster_n(self):
  N = self.N
- if cutlass.const_expr(self.dtype.width == 16):
- cluster_n = (
- 1
- if N <= 16 * 1024
- else (
- 2
- if N <= 32 * 1024
- else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
- )
- )
- else: # fp32
- cluster_n = (
- 1
- if N <= 32 * 1024
- else (
- 2
- if N <= 64 * 1024
- else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
- )
- )
+ cluster_n = (
+ 1
+ if N <= 8 * 1024
+ else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)))
+ )
  self.cluster_n = cluster_n

+ def _smem_size_in_bytes(self, tiler_mn, num_warps):
+ return (
+ # Multiply by 2 since we need space for X and dOut,
+ # and multiply by another 2 due to double buffering
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 * 2
+ + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8) * 2 # mult 2 as we need 2 mbar per stage
+ )
+
  @cute.jit
  def __call__(
  self,
  mX: cute.Tensor,
  mW: cute.Tensor,
- mDout: cute.Tensor,
+ mdOut: cute.Tensor,
  mRstd: cute.Tensor,
- mDx: cute.Tensor,
- mDw: cute.Tensor,
- sm_count: cutlass.Constexpr,
+ mdX: cute.Tensor,
+ mdW: cute.Tensor,
+ sm_count: cutlass.Int32,
  stream: cuda.CUstream,
  ):
  self._set_cluster_n()
@@ -359,14 +359,8 @@ class RMSNormBackward(ReductionBase):
  mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
  mW = cute.make_tensor(mW.iterator, mW_expanded_layout)

- mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
- mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-
- num_blocks = (
- sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
- )
-
- self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+ num_blocks = sm_count
+ self.kernel(mX, mW, mdOut, mRstd, mdX, mdW, tv_layout, tiler_mn).launch(
  grid=[num_blocks, self.cluster_n, 1],
  block=[num_threads, 1, 1],
  cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
@@ -379,177 +373,244 @@ class RMSNormBackward(ReductionBase):
  self,
  mX: cute.Tensor,
  mW: cute.Tensor,
- mDout: cute.Tensor,
+ mdOut: cute.Tensor,
  mRstd: cute.Tensor,
- mDx: cute.Tensor,
- mDw: cute.Tensor,
- sm_count: cutlass.Constexpr,
+ mdX: cute.Tensor,
+ mdW: cute.Tensor,
  tv_layout: cute.Layout,
  tiler_mn: cute.Shape,
  ):
  tidx, _, _ = cute.arch.thread_idx()
- bidx, cluster_y, _ = cute.arch.block_idx()
+ bidx_start, _, _ = cute.arch.block_idx()
  gdim, _, _ = cute.arch.grid_dim()
+ if cutlass.const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = cutlass.const_expr(0)

  shape = mX.shape
  M, N = shape[0], shape[1]
+ is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)

  idX = cute.make_identity_tensor(shape)

  smem = cutlass.utils.SmemAllocator()
- reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+ smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
+ sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
+ sdOut = smem.allocate_tensor(mdOut.element_type, smem_layout, byte_alignment=16)
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout, is_persistent=True
+ )
+ if cutlass.const_expr(mbar_ptr is not None):
+ mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
+ else:
+ mbar_full_ptr, mbar_empty_ptr = None, None

  copy_atom_load_X = cute.make_copy_atom(
  cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
  )
-
+ copy_atom_load_X_async = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+ )
  copy_atom_load_W = cute.make_copy_atom(
  cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
  )
-
  copy_atom_store_dX = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mDx.element_type, num_bits_per_copy=128
+ cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=128
  )
-
- copy_atom_dw = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+ copy_atom_store_dW = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=128
  )

  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_X_async = cute.make_tiled_copy(
+ copy_atom_load_X_async, tv_layout, tiler_mn
+ ).get_slice(tidx)
  thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_dw = cute.make_tiled_copy(copy_atom_dw, tv_layout, tiler_mn).get_slice(tidx)
- thr_store_dx = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_dW = cute.make_tiled_copy(copy_atom_store_dW, tv_layout, tiler_mn).get_slice(tidx)
+ thr_store_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)

- gW = cute.local_tile(mW, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
  tWgW = thr_copy_W.partition_S(gW)
  tWrW = cute.make_fragment_like(tWgW)
+ # Need this, otherwise rW can have arbitrary values that changes the reduction
+ if not is_even_N:
+ tWrW.fill(0.0)
  tXrW = thr_copy_X.retile(tWrW)

- gW_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
-
- tWpW = utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+ gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+ tWpW = (
+ utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+ if not is_even_N
+ else None
+ )
  cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
  weight = tXrW.load().to(cute.Float32)

  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE

- self._initialize_cluster(tidx, mbar_ptr, num_warps)
-
- dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
- tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+ self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)

- gDw = cute.local_tile(mDw, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
- tDwgDw = thr_copy_dw.partition_D(gDw)
- tDwrDw = cute.make_fragment_like(tDwgDw)
- dw_accumulator = thr_copy_X.retile(tDwrDw)
- dw_accumulator.fill(0.0)
-
- M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+ dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+ tdWpdW = (
+ utils.predicate_k(thr_copy_dW.partition_S(dw_coord), limit=shape[1])
+ if not is_even_N
+ else None
+ )

- jump = sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+ gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
+ tdWgdW = thr_copy_dW.partition_D(gdW)
+ tdWrdW = cute.make_fragment_like(tdWgdW, cutlass.Float32)
+ tXrdW = thr_copy_X.retile(tdWrdW)

- if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
+ gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
+ gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
+ gdX = cute.local_tile(mdX, tiler_mn, (None, cluster_y))
+ cX = cute.local_tile(idX, tiler_mn, (None, cluster_y))
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXgdOut = thr_copy_X.partition_S(gdOut)
+ tXsdOut = thr_copy_X.partition_D(sdOut)
+ tXgdX = thr_store_dX.partition_D(gdX)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
+ # This doesn't change across iterations
+ tXpX = (
+ utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+ if not is_even_N
+ else None
+ )

- ## need to update range_dynamic since it will be deprecated soon
- for row_offset in cutlass.range_dynamic(bidx, M_pad, jump):
- gX = cute.local_tile(
- mX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ tXrX, tXrdOut, tXrdX = [
+ cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdOut, tXgdX)
+ ]
+
+ # Prefetch the first batch
+ row = tXcX[None, None, None, bidx_start][0][0]
+ if row < M:
+ tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
+ tXgdOut_cur = utils.coord_offset_i64(bidx_start, tXgdOut, dim=3)[None, None, None, 0]
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgX_cur,
+ tXsX[None, None, None, 0],
+ pred=tXpX,
  )
- gDout = cute.local_tile(
- mDout, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgdOut_cur,
+ tXsdOut[None, None, None, 0],
+ pred=tXpX,
  )
- gRstd = cute.local_tile(
- mRstd, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
- )
- gDx = cute.local_tile(
- mDx, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
- )
- cX = cute.local_tile(
- idX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
- )
-
- tXgX = thr_copy_X.partition_S(gX)
- thrDout = thr_copy_X.partition_S(gDout)
- tXrRstd = thr_copy_W.partition_S(gRstd)
- thrDx = thr_store_dx.partition_D(gDx)
- tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
- tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
-
- tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ elif tiler_mn[0] > 1:
+ # Fill with zero, otherwise smem will be uninitialized, and we could read this back
+ # later into registers, causing wrong dW.
+ utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+ utils.fill_oob(tXsdOut[None, None, None, 0], None, fill_value=mdOut.element_type.zero)
+ cute.arch.cp_async_commit_group()

- if tXcX[0][0] < shape[0]:
- cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
- cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_wait()

+ threads_per_row = tv_layout.shape[0][0]
+ tXrdW.fill(0.0)
+ stage = cutlass.Int32(0)
+ producer_phase = cutlass.Int32(1)
+ consumer_phase = cutlass.Int32(0)
+ for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
+ row = tXcX[None, None, None, bidx][0][0]
+ rstd = cutlass.Float.zero
+ if row + gdim * tiler_mn[0] < M: # Prefetch the next batch
+ tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
+ tXgdOut_cur = utils.coord_offset_i64(bidx + gdim, tXgdOut, dim=3)[
+ None, None, None, 0
+ ]
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgX_cur,
+ tXsX[None, None, None, stage ^ 1],
+ pred=tXpX,
+ )
+ cute.copy(
+ copy_atom_load_X_async,
+ tXgdOut_cur,
+ tXsdOut[None, None, None, stage ^ 1],
+ pred=tXpX,
+ )
+ elif tiler_mn[0] > 1:
+ utils.fill_oob(
+ tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
+ )
+ utils.fill_oob(
+ tXsdOut[None, None, None, stage ^ 1], None, fill_value=mdOut.element_type.zero
+ )
+ cute.arch.cp_async_commit_group()
+ if row < M or tiler_mn[0] == 1:
+ rstd = mRstd[row]
+ cute.arch.cp_async_wait_group(1)
+ cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
  x = tXrX.load().to(cute.Float32)
- dout = frgDout.load().to(cute.Float32)
-
- rstd = tXrRstd[0]
+ cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
+ dout = tXrdOut.load().to(cute.Float32)
  x_hat = x * rstd
  wdy = dout * weight
-
- threads_per_row = tv_layout.shape[0][0]
-
- row = tXcX[0][0]
  if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
- cute.arch.barrier()
-
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
  mean_xhat_wdy = (
  utils.row_reduce(
  x_hat * wdy,
  cute.ReductionOp.ADD,
  threads_per_row,
- reduction_buffer[None, None, 0],
- mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+ reduction_buffer[None, None, stage],
+ mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None,
+ phase=consumer_phase,
  init_val=0.0,
- hook_fn=cute.arch.cluster_wait
- if cutlass.const_expr(self.cluster_n > 1)
- else None,
  )
  / shape[1]
  )
-
- dx = (wdy - x_hat * mean_xhat_wdy) * rstd
- frgDx.store(dx.to(frgDout.element_type))
-
- if row < M:
- cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
-
  if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
- cute.arch.barrier()
-
- if row < M:
- dw_row = dout * x_hat
- current_dw = dw_accumulator.load().to(cute.Float32)
- updated_dw = current_dw + dw_row
- dw_accumulator.store(updated_dw.to(dw_accumulator.element_type))
-
- """
- if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
- cute.arch.barrier()
- """
- """
- if cutlass.const_expr(self.cluster_n > 1):
- cute.arch.cluster_arrive()
- cute.arch.cluster_wait()
- else:
+ # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
+ # Requires adjusting the thread_count when initializing the mbar
+ cute.arch.sync_warp()
+ lane_idx = cute.arch.lane_idx()
+ if lane_idx < self.cluster_n:
+ cute.arch.mbarrier_arrive(
+ mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
+ )
+ dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+ tXrdX.store(dx.to(tXrdOut.element_type))
+ if row < M or tiler_mn[0] == 1:
+ tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
+ cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+ tXrdW.store(tXrdW.load() + dout * x_hat)
+ stage ^= 1
+ if stage == 0:
+ consumer_phase ^= 1
+ producer_phase ^= 1
+
+ if cutlass.const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
+
+ if cutlass.const_expr(tiler_mn[0] > 1):
+ # reduction of dw_partial within the same threadblock
+ sdW = cute.make_tensor(
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ tXsdW = thr_copy_X.partition_D(sdW)
  cute.arch.barrier()
- """
-
- cute.autovec_copy(dw_accumulator, tDwrDw)
- cute.copy(copy_atom_dw, tDwrDw, tDwgDw, pred=tDwpDw)
+ row = tXcX[None, None, None, 0][0][0]
+ if row > 0:
+ cute.autovec_copy(tXrdW, tXsdW)
+ cute.arch.barrier()
+ if row == 0:
+ for i in cutlass.range_constexpr(1, cutlass.const_expr(tiler_mn[0])):
+ tXrdW_other = cute.make_fragment_like(tXrdW)
+ tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
+ cute.autovec_copy(tXsdW_other, tXrdW_other)
+ tXrdW.store(tXrdW.load() + tXrdW_other.load())
+ cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)
+ else:
+ cute.copy(copy_atom_store_dW, tdWrdW, tdWgdW, pred=tdWpdW)


  def _rmsnorm_backward(
@@ -581,8 +642,19 @@ def _rmsnorm_backward(

  device = x.device

- sm_count = torch.cuda.get_device_properties(device).multi_processor_count * 8
- dw_partial = torch.zeros((sm_count, N), device=device, dtype=weight.dtype)
+ # This should be tuned on how many CTAs can be launched on each SM
+ sm_count_multiple = (
+ 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+ )
+ sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+ # By right, if we're using cluster, this should be cluster_count not sm_count.
+ # But for cluster >= 4, due to quantization we would need to query active max cluster.
+ # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
+ # avoid wave quantization.
+ sm_count = (
+ sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+ )
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=weight.dtype)

  dtype = torch2cute_dtype_map[x.dtype]
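Editor's note: as a worked example of the grid sizing above (the thresholds and multipliers are the ones in the diff; the 132-SM device count is an assumption chosen only for illustration):

def persistent_grid_size(N, multi_processor_count=132):
    # Mirrors the sm_count logic added to _rmsnorm_backward; 132 SMs is an
    # assumed H100-class figure used purely for this example.
    sm_count_multiple = (
        16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
    )
    sm_count = multi_processor_count
    return (
        sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
    )

print(persistent_grid_size(1024))   # 132 * 8  = 1056 persistent CTAs along grid.x
print(persistent_grid_size(16384))  # 132 // 2 = 66
print(persistent_grid_size(65536))  # 132 * 2  = 264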

@@ -625,6 +697,7 @@ def _rmsnorm_backward(
  rstd_tensor,
  dx_tensor,
  dw_partial_tensor,
+ sm_count,
  current_stream,
  )
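Editor's note: the per-row math of the rewritten backward kernel is unchanged; the following is a minimal PyTorch sketch of what it computes, assuming rstd = 1/sqrt(mean(x^2) + eps) as saved by the forward pass (the in-tree rmsnorm_bwd_ref takes the saved rstd directly), with the per-CTA dw_partial rows summed on the host afterwards.

import torch

def rmsnorm_bwd_sketch(x, w, dout, eps=1e-6):
    # Illustrative reference only, not the kernel code path.
    x, w, dout = x.float(), w.float(), dout.float()
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    x_hat = x * rstd
    wdy = dout * w
    mean_xhat_wdy = (x_hat * wdy).mean(dim=-1, keepdim=True)
    dx = (wdy - x_hat * mean_xhat_wdy) * rstd   # same formula as in the kernel
    dw = (dout * x_hat).sum(dim=0)              # reduction of the dw partials over all rows
    return dx, dw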
quack/softmax.py CHANGED
@@ -133,9 +133,7 @@ class Softmax(ReductionBase):

  is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
  tXpX = (
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
- if cutlass.const_expr(not is_even_N)
- else None
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
  )
  if tXcX[0][0] < shape[0]:
  cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
quack/utils.py CHANGED
@@ -120,12 +120,20 @@ def cluster_reduce(
  reduction_buffer: cute.Tensor,
  mbar_ptr: cute.Pointer,
  init_val: cute.Numeric = 0.0,
+ phase: Optional[cutlass.Int32] = None,
  ) -> cute.Numeric:
  """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
  lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
- warps_per_row, cluster_n = reduction_buffer.shape[1]
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
  row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
  if lane_idx < cluster_n:
  store_shared_remote(
  val,
@@ -133,7 +141,7 @@ def cluster_reduce(
  mbar_ptr,
  peer_cta_rank_in_cluster=lane_idx,
  )
- cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
  block_reduce_val = init_val
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
  for i in cutlass.range_constexpr(num_iter):
@@ -149,13 +157,14 @@ def block_or_cluster_reduce(
  op: Callable,
  reduction_buffer: cute.Tensor,
  mbar_ptr: Optional[cute.Pointer],
+ phase: Optional[cutlass.Int32] = None,
  init_val: cute.Numeric = 0.0,
  ) -> cute.Numeric:
  """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
  if cutlass.const_expr(mbar_ptr is None):
  return block_reduce(val, op, reduction_buffer, init_val=init_val)
  else:
- return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val)
+ return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)


  @cute.jit
@@ -165,6 +174,7 @@ def row_reduce(
  threads_per_row: cutlass.Constexpr[int],
  reduction_buffer: Optional[cute.Tensor] = None,
  mbar_ptr: Optional[cute.Pointer] = None,
+ phase: Optional[cutlass.Int32] = None,
  init_val: cute.Numeric = 0.0,
  hook_fn: Optional[Callable] = None,
  ) -> cute.Numeric:
@@ -193,7 +203,7 @@ def row_reduce(
  ), "mbar_ptr must be provided for cluster reduction"
  if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
  val = block_or_cluster_reduce(
- val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
+ val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
  )
  return val

@@ -205,6 +215,7 @@ def online_softmax_reduce(
  reduction_buffer: Optional[cute.Tensor] = None,
  mbar_ptr: Optional[cute.Pointer] = None,
  hook_fn: Optional[Callable] = None,
+ phase: Optional[cutlass.Int32] = None,
  return_exp_x: bool = False,
  ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
  assert x.dtype == Float32, "x must be of type Float32"
@@ -225,7 +236,7 @@ def online_softmax_reduce(
  if cutlass.const_expr(hook_fn is not None):
  hook_fn()
  if cutlass.const_expr(reduction_buffer is not None):
- warps_per_row, cluster_n = reduction_buffer.shape[1]
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
  assert (
  cluster_n == 1 or mbar_ptr is not None
  ), "mbar_ptr must be provided for cluster reduction"
@@ -251,6 +262,13 @@ def online_softmax_reduce(
  max_x = max_x_final
  else:
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
  if lane_idx < cluster_n:
  store_shared_remote(
  f32x2_to_i64(max_x, sum_exp_x),
@@ -258,7 +276,7 @@ def online_softmax_reduce(
  mbar_ptr,
  peer_cta_rank_in_cluster=lane_idx,
  )
- cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
  max_x_single_warp = cute.make_fragment(num_iter, Float32)
  max_x_single_warp.fill(-Float32.inf)
@@ -351,7 +369,7 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:


  @cute.jit
- def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
+ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
  """Fill out-of-bounds values in shared memory tensor.

  Args:
@@ -361,9 +379,12 @@ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) ->
  """
  tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
  tXrX_fill.fill(fill_value)
- for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
- for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
- if not tXpX[rest_v, 0, rest_k]:
+ for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
+ for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
+ if cutlass.const_expr(tXpX is not None):
+ if not tXpX[rest_v, 0, rest_k]:
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+ else:
  cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])


@@ -396,6 +417,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
  def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
  flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
  flat_stride = cute.flatten_to_tuple(tensor.stride)
+ assert len(flat_coord_i64) == len(
+ flat_stride
+ ), "Coordinate and stride must have the same length"
  offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
  assert isinstance(tensor.iterator, cute.Pointer)
  # HACK: we assume that applying the offset does not change the pointer alignment
@@ -406,3 +430,19 @@ def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=No
  assumed_align=tensor.iterator.max_alignment,
  )
  return cute.make_tensor(new_ptr, tensor.layout)
+
+
+ @dsl_user_op
+ def coord_offset_i64(
+ idx: cute.typing.Int, tensor: cute.Tensor, dim: int, *, loc=None, ip=None
+ ) -> cute.Tensor:
+ offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
+ assert isinstance(tensor.iterator, cute.Pointer)
+ # HACK: we assume that applying the offset does not change the pointer alignment
+ new_ptr = cute.make_ptr(
+ tensor.element_type,
+ tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+ tensor.memspace,
+ assumed_align=tensor.iterator.max_alignment,
+ )
+ return cute.make_tensor(new_ptr, tensor.layout)
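Editor's note: coord_offset_i64, like the existing domain_offset_i64, computes the element offset in 64-bit before building the new pointer. A plausible motivation (an inference, not stated in the source) is that index * stride can exceed the signed 32-bit range for large 2-D tensors, as this plain-Python sketch shows:

# Assumed sizes, for illustration only.
row_index, cols = 3_500_000, 1024
offset = row_index * cols          # 3_584_000_000 elements into the tensor
print(offset > 2**31 - 1)          # True: this offset does not fit in a signed 32-bit index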
quack_kernels-0.1.7.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.6
+ Version: 0.1.7
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
quack_kernels-0.1.7.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
+ quack/__init__.py,sha256=R9cZd_vslI5oZjjS-ojfWAd9tCZAqsLUiFVqEbUaGnw,203
+ quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
+ quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+ quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
+ quack/rmsnorm.py,sha256=3jiwWhVmaG0n5vuUnGGrpg3StAB4lnzziNF97QVMLGQ,28870
+ quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
+ quack/utils.py,sha256=aiyzBc9BEwq8s965elfiR331hAaLLBKL9kDHjuls86Q,17791
+ quack_kernels-0.1.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.1.7.dist-info/METADATA,sha256=9RlqUmX3-7BI2aZk88r84B8o2FzZkQgkfV1UxwN8GlE,289
+ quack_kernels-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.1.7.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.1.7.dist-info/RECORD,,
quack_kernels-0.1.6.dist-info/RECORD REMOVED
@@ -1,12 +0,0 @@
- quack/__init__.py,sha256=EV_43VfcxQUsFu_Nfq0944ImloZ2T9X94-8T9IQM0I8,203
- quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
- quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
- quack/reduction_base.py,sha256=fFuGXPR3lDq2yw_m86ujmkni6R51jzNAzy_r9R6C8tA,3563
- quack/rmsnorm.py,sha256=rbRP_O-EJYvrvxYPFwfy0yCbG7Qs_DgbzgacxTLvih4,24159
- quack/softmax.py,sha256=b-QEQiGjOEPidolE2--K21kwEaLJ-3wQoGIz_BsEcSI,16742
- quack/utils.py,sha256=laz_lqeggiIOY_xCs3c3VvCPP8bJmSKBfjFpFR1Al80,15946
- quack_kernels-0.1.6.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.1.6.dist-info/METADATA,sha256=NKhSoudW9lNdYryNiMkyjKXYl5n37cuFMbyLv3zr5js,289
- quack_kernels-0.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.1.6.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.1.6.dist-info/RECORD,,