quack-kernels 0.1.1-py3-none-any.whl → 0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.1.1"
+__version__ = "0.1.2"
 
 from quack.rmsnorm import rmsnorm
 from quack.softmax import softmax
quack/cross_entropy.py CHANGED
@@ -1,7 +1,7 @@
 import math
 import torch
 import operator
-from typing import Callable, Union
+from typing import Callable, Union, Optional
 
 import cuda.bindings.driver as cuda
 
@@ -17,37 +17,29 @@ def cross_entropy_kernel(
     mX: cute.Tensor, # (M, N)
     mTarget: cute.Tensor, # (M,)
     mLoss: cute.Tensor, # (M,)
+    mLSE: Optional[cute.Tensor], # (M,)
     tv_layout: cute.Layout,
     tiler_mn: cute.Shape,
     cluster_n: cutlass.Constexpr = 1,
 ):
     tidx, _, _ = cute.arch.thread_idx()
     bidx, cluster_y, _ = cute.arch.block_idx()
-    gdim, _, _ = cute.arch.grid_dim()
 
     shape: cute.Shape = mX.shape
-    idX = cute.make_identity_tensor(mX.shape)
-    gX, cX = [cute.zipped_divide(mT, tiler_mn) for mT in (mX, idX)]
-    blkX, blkCrd = [gT[(None, None), bidx if cluster_n == 1 else (bidx, cluster_y)] for gT in (gX, cX)]
-
-    # declare the atoms which will be used later for memory copy
-    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_scalar = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=gX.element_type.width)
-
-    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
+    idX = cute.make_identity_tensor(shape)
+    # slice for CTAs
+    gX, cX = [
+        cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
+        for mT in (mX, idX)
+    ]
 
     smem = cutlass.utils.SmemAllocator()
-
-    # Don't use blkX.layout here, because the stride is N, not N_rounded
-    sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
+    sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-
     reduction_buffer_layout = cute.make_ordered_layout(
         # 2 stages: 1 for max, 1 for sum
-        (num_warps // warps_per_row, warps_per_row if cluster_n == 1 else (warps_per_row, cluster_n), 2),
+        (num_warps // warps_per_row, (warps_per_row, cluster_n), 2),
         order=(1, 0, 2)
     )
     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
@@ -57,14 +49,15 @@ def cross_entropy_kernel(
     else:
         mbar_ptr = None
 
-    #### Thread View
-    tXgX = thr_copy_X_async.partition_S(blkX)
-    tXsX = thr_copy_X_async.partition_S(sX)
-
-    tXcX = thr_copy_X.partition_S(blkCrd)[(0, None), None, None]
+    # declare the atoms which will be used later for memory copy
+    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
+    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
 
-    # allocate fragments for gmem->rmem
-    tXrX = cute.make_fragment_like(tXgX) # only logits fragment needed
+    #### Thread View
+    tXgX = thr_copy_X.partition_S(gX)
+    tXsX = thr_copy_X.partition_D(sX)
+    tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+    tXrX = cute.make_fragment_like(tXgX)
 
     if cluster_n > 1:
         if tidx < 2:
@@ -80,54 +73,57 @@ def cross_entropy_kernel(
     if row < shape[0] and tXcX[0][1] == 0:
         target = cute.Int32(mTarget[row])
 
-    tXpX = cute.make_fragment_like(tXgX[(0, None), None, None], cutlass.Boolean)
-    for i in range(cute.size(tXpX)):
-        tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
+    is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * cluster_n)
+    tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
     if row < shape[0]:
-        cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+        cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
         cute.arch.cp_async_commit_group()
     cute.arch.cp_async_wait_group(0)
     cute.autovec_copy(tXsX, tXrX)
     x = tXrX.load().to(cute.Float32)
+    # Fill OOB values with -inf
+    if cutlass.const_expr(not is_even_N):
+        tXrX_fp32 = cute.make_fragment_like(tXrX, cutlass.Float32)
+        tXrX_fp32.store(x)
+        for rest_v in range(tXpX.shape[0]):
+            for rest_k in range(tXpX.shape[2]):
+                if not tXpX[rest_v, 0, rest_k]:
+                    tXrX_fp32[(None, rest_v), None, rest_k].fill(-cutlass.Float32.inf)
+        x = tXrX_fp32.load()
 
     target_logit = cute.Float32.zero
     if row < shape[0] and tXcX[0][1] == 0:
         target_logit = cute.Float32(mX[row, target])
 
-    max_x = utils.warp_reduce(
-        x.reduce(cute.ReductionOp.MAX, init_val=float('-inf'), reduction_profile=0),
-        cute.arch.fmax,
-        width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
+    threads_per_row = tv_layout.shape[0][0]
+    max_x = utils.row_reduce(
+        x,
+        cute.ReductionOp.MAX,
+        threads_per_row,
+        reduction_buffer[None, None, 0],
+        mbar_ptr + 0 if cluster_n > 1 else None,
+        init_val=-cutlass.Float32.inf,
+        hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
     )
-    if cutlass.const_expr(cluster_n > 1):
-        cute.arch.cluster_wait()
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        max_mbar_ptr = mbar_ptr + 0 if cluster_n > 1 else None
-        max_x = utils.block_or_cluster_reduce(
-            max_x, cute.arch.fmax, reduction_buffer[None, None, 0], max_mbar_ptr, init_val=-cutlass.Float32.inf
-        )
     log2_e = math.log2(math.e)
     # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
     exp_x = utils.exp2f((x - max_x) * log2_e) # a bit faster, idk why
-    denom = utils.warp_reduce(
-        exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
-        operator.add,
-        width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
+    denom = utils.row_reduce(
+        exp_x,
+        cute.ReductionOp.ADD,
+        threads_per_row,
+        reduction_buffer[None, None, 1],
+        mbar_ptr + 1 if cluster_n > 1 else None,
+        init_val=0.0,
     )
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        sum_mbar_ptr = mbar_ptr + 1 if cluster_n > 1 else None
-        denom = utils.block_or_cluster_reduce(
-            denom, operator.add, reduction_buffer[None, None, 1], sum_mbar_ptr, init_val=0.0
-        )
 
-    if tXcX[0][1] == 0 and row < shape[0]:
+    if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
         ln_2 = math.log(2.0)
-        loss_val = -target_logit + max_x + utils.log2f(denom) * ln_2
-        if cutlass.const_expr(cluster_n == 1):
-            mLoss[row] = loss_val.to(mLoss.element_type)
-        else:
-            if cute.arch.block_idx_in_cluster() == 0:
-                mLoss[row] = loss_val.to(mLoss.element_type)
+        lse = max_x + utils.log2f(denom) * ln_2
+        loss_val = lse - target_logit
+        mLoss[row] = loss_val.to(mLoss.element_type)
+        if cutlass.const_expr(mLSE is not None):
+            mLSE[row] = lse
 
 
 @cute.jit
@@ -135,6 +131,7 @@ def cross_entropy_interface(
     mX: cute.Tensor,
     mTarget: cute.Tensor,
     mLoss: cute.Tensor,
+    mLSE: Optional[cute.Tensor],
     stream: cuda.CUstream,
     N: cutlass.Constexpr,
     copy_bits: cutlass.Constexpr = 128
@@ -161,7 +158,7 @@ def cross_entropy_interface(
     )
 
     smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
-    cross_entropy_kernel(mX, mTarget, mLoss, tv_layout, tiler_mn, cluster_n).launch(
+    cross_entropy_kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn, cluster_n).launch(
         grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
         block=[cute.size(tv_layout, mode=[0]), 1, 1],
         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
@@ -181,6 +178,7 @@ torch2cute_dtype_map = {
 def cross_entropy(
     x: torch.Tensor,
     target: torch.Tensor,
+    return_lse: bool = False,
 ) -> torch.Tensor:
     """Cross entropy forward pass.
 
@@ -199,7 +197,8 @@ def cross_entropy(
     assert target.dtype == torch.int64, "Target must be int64"
     M, N = x.shape
     device = x.device
-    loss = torch.empty(M, device=device, dtype=x.dtype)
+    loss = torch.empty(M, device=device, dtype=torch.float32)
+    lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda tensor: (
         from_dlpack(tensor.detach(), assumed_align=16)
@@ -207,15 +206,16 @@ def cross_entropy(
     )
     x_tensor, = [convert_from_dlpack(tensor) for tensor in (x,)]
     loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+    lse_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0) if lse is not None else None
     target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-    compile_key = (dtype, N)
+    compile_key = (dtype, N, lse_tensor is not None)
     if compile_key not in cross_entropy.compile_cache:
         cross_entropy.compile_cache[compile_key] = cute.compile(
-            cross_entropy_interface, x_tensor, target_tensor, loss_tensor, stream, N
+            cross_entropy_interface, x_tensor, target_tensor, loss_tensor, lse_tensor, stream, N
         )
-    cross_entropy.compile_cache[compile_key](x_tensor, target_tensor, loss_tensor, stream)
-    return loss
+    cross_entropy.compile_cache[compile_key](x_tensor, target_tensor, loss_tensor, lse_tensor, stream)
+    return loss if not return_lse else (loss, lse)
 
 
 cross_entropy.compile_cache = {}
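
The user-facing change in this file is the new return_lse flag on cross_entropy: the per-row loss is now returned in float32, and the per-row log-sum-exp can be returned alongside it. A minimal usage sketch (assumes quack-kernels 0.1.2 on a CUDA device; the torch reference lines are illustrative only and are not part of the package):

    import torch
    import torch.nn.functional as F
    from quack.cross_entropy import cross_entropy

    x = torch.randn(1024, 8192, device="cuda", dtype=torch.bfloat16)
    target = torch.randint(0, 8192, (1024,), device="cuda")

    loss = cross_entropy(x, target)                         # per-row loss, float32, shape (1024,)
    loss, lse = cross_entropy(x, target, return_lse=True)   # also returns the per-row logsumexp

    # Reference identities the kernel implements:
    #   lse[i]  = logsumexp(x[i, :])
    #   loss[i] = lse[i] - x[i, target[i]]
    ref_lse = torch.logsumexp(x.float(), dim=-1)
    ref_loss = F.cross_entropy(x.float(), target, reduction="none")
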
quack/rmsnorm.py CHANGED
@@ -16,13 +16,11 @@ import quack.utils as utils
 
 @cute.kernel
 def rmsnorm_kernel(
-    gX: cute.Tensor,
-    gW: cute.Tensor,
-    gO: cute.Tensor,
-    gRstd: cute.Tensor,
-    cX: cute.Tensor, # coordinate tensor
+    mX: cute.Tensor,
+    mW: cute.Tensor,
+    mO: cute.Tensor,
+    mRstd: cute.Tensor,
     eps: cute.Float32,
-    shape: cute.Shape,
     tv_layout: cute.Layout,
     tiler_mn: cute.Shape,
     cluster_n: cutlass.Constexpr = 1,
@@ -31,42 +29,45 @@ def rmsnorm_kernel(
 ):
     tidx, _, _ = cute.arch.thread_idx()
     bidx, cluster_y, _ = cute.arch.block_idx()
-    gdim, _, _ = cute.arch.grid_dim()
-
-    # slice for CTAs
-    # logical id -> address
-    blkX, blkOut, blkRstd, blkCrd = [gT[(None, None), bidx if cluster_n == 1 else (bidx, cluster_y)] for gT in (gX, gO, gRstd, cX)]
-    blkW = gW[(None, None), 0 if cluster_n == 1 else (0, cluster_y)]
-
-    # declare the atoms which will be used later for memory copy
-    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_load_W = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gW.element_type, num_bits_per_copy=128)
-    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
-
-    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
 
     smem = cutlass.utils.SmemAllocator()
-    # Don't use blkX.layout here, because the stride is N, not N_rounded
-    sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
+    sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-    # reduction_buffer_layout = cute.make_ordered_layout((num_warps // warps_per_row, warps_per_row), order=(1, 0))
-    reduction_buffer_layout = cute.make_ordered_layout((num_warps // warps_per_row, warps_per_row if cluster_n == 1 else (warps_per_row, cluster_n)), order=(1, 0))
+    reduction_buffer_layout = cute.make_ordered_layout(
+        (num_warps // warps_per_row, (warps_per_row, cluster_n)),
+        order=(1, 0)
+    )
     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
     if cutlass.const_expr(cluster_n > 1):
         mbar_ptr = smem.allocate(cutlass.Int64.width // 8, byte_alignment=8)
     else:
         mbar_ptr = None
 
-    tWgW = thr_copy_W.partition_S(blkW)
-    tXgX = thr_copy_X_async.partition_S(blkX)
-    tXsX = thr_copy_X_async.partition_S(sX)
-    tXgO, tXrRstd = [thr_copy_O.partition_D(blk) for blk in (blkOut, blkRstd)]
-    tXcX = thr_copy_X.partition_S(blkCrd)[(0, None), None, None]
+    shape = mX.shape
+    idX = cute.make_identity_tensor(shape)
+    # slice for CTAs
+    gX, gO, gRstd, cX = [
+        cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
+        for mT in (mX, mO, mRstd, idX)
+    ]
+    gW = cute.local_tile(mW, tiler_mn, (0, 0 if cluster_n == 1 else cluster_y))
+
+    # declare the atoms which will be used later for memory copy
+    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128)
+    copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128)
+    copy_atom_load_W = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128)
+    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128)
+
+    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
+    thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+    tWgW = thr_copy_W.partition_S(gW)
+    tXgX = thr_copy_X.partition_S(gX)
+    tXsX = thr_copy_X.partition_D(sX)
+    tXgO, tXrRstd = [thr_copy_O.partition_D(gT) for gT in (gO, gRstd)]
+    tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
     # allocate fragments for gmem->rmem
     tWrW = cute.make_fragment_like(tWgW)
@@ -82,44 +83,33 @@ def rmsnorm_kernel(
     # Cluster arrive after barrier init
     cute.arch.cluster_arrive_relaxed()
 
-    tXpX = cute.make_fragment_like(tXgX[(0, None), None, None], cutlass.Boolean)
-    for i in range(cute.size(tXpX)):
-        tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
-    # tXrX.fill(0.0)
-    if tXcX[0][0] < shape[0]:
-        # cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+    tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+    row = tXcX[0][0]
+    if row < shape[0]:
         cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
         cute.arch.cp_async_commit_group()
 
-    tWpW = cute.make_fragment_like(tWgW[(0, None), None, None], cutlass.Boolean)
-    tWcX = thr_copy_W.partition_S(blkCrd)[(0, None), None, None]
-    for i in range(cute.size(tWpW)):
-        tWpW[i] = cute.elem_less(tWcX[i][1], shape[1])
+    tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
     if not delay_w_load:
         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
 
     cute.arch.cp_async_wait_group(0)
     cute.autovec_copy(tXsX, tXrX)
     x = tXrX.load().to(cute.Float32)
-    sum_sq_x = utils.warp_reduce(
-        (x * x).reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
-        operator.add,
-        width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
+    threads_per_row = tv_layout.shape[0][0]
+    sum_sq_x = utils.row_reduce(
+        x * x,
+        cute.ReductionOp.ADD,
+        threads_per_row,
+        reduction_buffer,
+        mbar_ptr,
+        init_val=0.0,
+        hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
     )
-    if cutlass.const_expr(cluster_n > 1):
-        cute.arch.cluster_wait()
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        sum_sq_x = utils.block_or_cluster_reduce(
-            sum_sq_x, operator.add, reduction_buffer, mbar_ptr, init_val=0.0
-        )
     rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
     # Only the thread corresponding to column 0 writes out the rstd to gmem
-    if tXcX[0][1] == 0 and tXcX[0][0] < shape[0]:
-        if cutlass.const_expr(cluster_n == 1):
-            tXrRstd[0] = rstd
-        else:
-            if cute.arch.block_idx_in_cluster() == 0:
-                tXrRstd[0] = rstd
+    if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
+        tXrRstd[0] = rstd
     if delay_w_load:
         cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
     if reload_from == "smem":
@@ -132,20 +122,16 @@ def rmsnorm_kernel(
     w = tXrW.load().to(cute.Float32)
     y = x_hat * w
     tXrO.store(y.to(tXrO.element_type))
-    tOcX = thr_copy_O.partition_S(blkCrd)[(0, None), None, None]
-    tOpO = cute.make_fragment_like(tXgO[(0, None), None, None], cutlass.Boolean)
-    for i in range(cute.size(tOpO)):
-        tOpO[i] = cute.elem_less(tOcX[i][1], shape[1])
-    if tXcX[0][0] < shape[0]:
+    tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+    if row < shape[0]:
         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
 
 
 @cute.jit
 def rmsnorm_interface(
-    # mX_: cute.Tensor,
     mX: cute.Tensor,
     mW: cute.Tensor,
-    mOut: cute.Tensor,
+    mO: cute.Tensor,
     mRstd: cute.Tensor,
     stream: cuda.CUstream,
     N: cutlass.Constexpr,
@@ -180,21 +166,18 @@ def rmsnorm_interface(
     mW_expanded = cute.make_tensor(mW.iterator, mW_expanded_layout)
     mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((N,), stride=(0,)))
     mRstd_expanded = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-    idX = cute.make_identity_tensor(mX.shape)
-    gX, gW, gO, gRstd, cX = [cute.zipped_divide(mT, tiler_mn) for mT in (mX, mW_expanded, mOut, mRstd_expanded, idX)] # ((TileM,TileN),(RestM,RestN))
 
     # reload_from = None if N <= 16384 else ("smem" if N <= 32768 else "gmem")
     reload_from = None if N <= 16384 else "smem"
     # delay_w_load = N > 64 * 1024
     delay_w_load = False
     N_rounded = tiler_mn[1]
-    rmsnorm_kernel(gX, gW, gO, gRstd, cX, eps, mX.shape, tv_layout, tiler_mn, cluster_n, reload_from).launch(
-        grid=[cute.size(gX, mode=[1, 0]), cluster_n, 1],
+    rmsnorm_kernel(mX, mW_expanded, mO, mRstd_expanded, eps, tv_layout, tiler_mn, cluster_n, reload_from).launch(
+        grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
         block=[cute.size(tv_layout, mode=[0]), 1, 1],
         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-        # We don't want to use gX.layout[0] here since that has stride in N, not N_rounded, leading IMA on smem
-        smem=cute.size_in_bytes(mX.element_type, cute.make_layout(gX.shape[0])) + num_warps * cluster_n * (cutlass.Float32.width // 8) + (cutlass.Int64.width // 8),
+        smem=cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + num_warps * cluster_n * (cutlass.Float32.width // 8) + (cutlass.Int64.width // 8),
         stream=stream,
     )
 
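
The changes above are internal to the kernel and launch code; the Python-level rmsnorm wrapper exported from quack/__init__.py is not part of this diff, so its exact signature is assumed below. A usage sketch, with a plain-PyTorch restatement of the math the kernel implements (rstd = rsqrt(mean(x^2) + eps), y = x * rstd * w):

    import torch
    from quack import rmsnorm

    x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)

    y = rmsnorm(x, w, eps=1e-6)  # (x, weight, eps) signature assumed, not shown in this diff

    # Reference computation in float32, for comparison only
    x_f32 = x.float()
    rstd = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
    ref = (x_f32 * rstd * w.float()).to(x.dtype)
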
quack/softmax.py CHANGED
@@ -15,40 +15,30 @@ import quack.utils as utils
 
 @cute.kernel
 def softmax_kernel(
-    gX: cute.Tensor,
-    gO: cute.Tensor,
-    cX: cute.Tensor, # coordinate tensor
-    shape: cute.Shape,
+    mX: cute.Tensor,
+    mO: cute.Tensor,
     tv_layout: cute.Layout,
     tiler_mn: cute.Shape,
     cluster_n: cutlass.Constexpr = 1,
 ):
     tidx, _, _ = cute.arch.thread_idx()
     bidx, cluster_y, _ = cute.arch.block_idx()
-    gdim, _, _ = cute.arch.grid_dim()
 
+    shape = mX.shape
+    idX = cute.make_identity_tensor(shape)
     # slice for CTAs
-    # logical id -> address
-    blkX, blkOut, blkCrd = [gT[(None, None), bidx if cluster_n == 1 else (bidx, cluster_y)] for gT in (gX, gO, cX)]
-
-    # declare the atoms which will be used later for memory copy
-    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
-
-    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+    gX, gO, cX = [
+        cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
+        for mT in (mX, mO, idX)
+    ]
 
     smem = cutlass.utils.SmemAllocator()
-    # Don't use blkX.layout here, because the stride is N, not N_rounded
-    sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
+    sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-
     reduction_buffer_layout = cute.make_ordered_layout(
         # 2 stages: 1 for max, 1 for sum
-        (num_warps // warps_per_row, warps_per_row if cluster_n == 1 else (warps_per_row, cluster_n), 2),
+        (num_warps // warps_per_row, (warps_per_row, cluster_n), 2),
         order=(1, 0, 2)
     )
     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
@@ -58,10 +48,17 @@ def softmax_kernel(
     else:
         mbar_ptr = None
 
-    tXgX = thr_copy_X_async.partition_S(blkX)
-    tXsX = thr_copy_X_async.partition_S(sX)
-    tXgO = thr_copy_O.partition_D(blkOut)
-    tXcX = thr_copy_X.partition_S(blkCrd)[(0, None), None, None]
+    # declare the atoms which will be used later for memory copy
+    copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128)
+    copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
+
+    thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+    thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+    tXgX = thr_copy_X.partition_S(gX)
+    tXsX = thr_copy_X.partition_D(sX)
+    tXgO = thr_copy_O.partition_D(gO)
+    tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
     # allocate fragments for gmem->rmem
     tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
@@ -75,49 +72,48 @@ def softmax_kernel(
     # Cluster arrive after barrier init
     cute.arch.cluster_arrive_relaxed()
 
-    tXpX = cute.make_fragment_like(tXgX[(0, None), None, None], cutlass.Boolean)
-    for i in range(cute.size(tXpX)):
-        tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
-
+    is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * cluster_n)
+    tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
     if tXcX[0][0] < shape[0]:
-        cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+        cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
         cute.arch.cp_async_commit_group()
     cute.arch.cp_async_wait_group(0)
 
     cute.autovec_copy(tXsX, tXrX)
     x = tXrX.load().to(cute.Float32)
-    max_x = utils.warp_reduce(
-        x.reduce(cute.ReductionOp.MAX, init_val=float('-inf'), reduction_profile=0),
-        cute.arch.fmax,
-        width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
+    # Fill OOB values with -inf
+    if cutlass.const_expr(not is_even_N):
+        tXrX_fp32 = cute.make_fragment_like(tXrX, cutlass.Float32)
+        tXrX_fp32.store(x)
+        for rest_v in range(tXpX.shape[0]):
+            for rest_k in range(tXpX.shape[2]):
+                if not tXpX[rest_v, 0, rest_k]:
+                    tXrX_fp32[(None, rest_v), None, rest_k].fill(-cutlass.Float32.inf)
+        x = tXrX_fp32.load()
+    threads_per_row = tv_layout.shape[0][0]
+    max_x = utils.row_reduce(
+        x,
+        cute.ReductionOp.MAX,
+        threads_per_row,
+        reduction_buffer[None, None, 0],
+        mbar_ptr + 0 if cluster_n > 1 else None,
+        init_val=-cutlass.Float32.inf,
+        hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
     )
-    if cutlass.const_expr(cluster_n > 1):
-        cute.arch.cluster_wait()
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        max_mbar_ptr = mbar_ptr + 0 if cluster_n > 1 else None
-        max_x = utils.block_or_cluster_reduce(
-            max_x, cute.arch.fmax, reduction_buffer[None, None, 0], max_mbar_ptr, init_val=-cutlass.Float32.inf
-        )
     log2_e = math.log2(math.e)
     exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-    denom = utils.warp_reduce(
-        exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
-        operator.add,
-        width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
+    denom = utils.row_reduce(
+        exp_x,
+        cute.ReductionOp.ADD,
+        threads_per_row,
+        reduction_buffer[None, None, 1],
+        mbar_ptr + 1 if cluster_n > 1 else None,
+        init_val=0.0,
     )
-    if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-        sum_mbar_ptr = mbar_ptr + 1 if cluster_n > 1 else None
-        denom = utils.block_or_cluster_reduce(
-            denom, operator.add, reduction_buffer[None, None, 1], sum_mbar_ptr, init_val=0.0
-        )
     inv = 1.0 / denom
     y = exp_x * inv
-
     tXrO.store(y.to(tXrO.element_type))
-    tOcX = thr_copy_O.partition_S(blkCrd)[(0, None), None, None]
-    tOpO = cute.make_fragment_like(tXgO[(0, None), None, None], cutlass.Boolean)
-    for i in range(cute.size(tOpO)):
-        tOpO[i] = cute.elem_less(tOcX[i][1], shape[1])
+    tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
     if tXcX[0][0] < shape[0]:
         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
 
@@ -125,7 +121,7 @@ def softmax_kernel(
 @cute.jit
 def softmax_interface(
     mX: cute.Tensor,
-    mOut: cute.Tensor,
+    mO: cute.Tensor,
     stream: cuda.CUstream,
     N: cutlass.Constexpr,
     copy_bits: cutlass.Constexpr = 128
@@ -149,12 +145,9 @@ def softmax_interface(
         stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
     )
 
-    idX = cute.make_identity_tensor(mX.shape)
-    gX, gO, cX = [cute.zipped_divide(mT, tiler_mn) for mT in (mX, mOut, idX)] # ((TileM,TileN),(RestM,RestN))
-
-    smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(gX.shape[0])) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
-    softmax_kernel(gX, gO, cX, mX.shape, tv_layout, tiler_mn, cluster_n).launch(
-        grid=[cute.size(gX, mode=[1, 0]), cluster_n, 1],
+    smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
+    softmax_kernel(mX, mO, tv_layout, tiler_mn, cluster_n).launch(
+        grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
         block=[cute.size(tv_layout, mode=[0]), 1, 1],
        # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
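
The softmax changes are likewise internal (tiling, predication, and the shared row_reduce helper); the exported entry point is unchanged in name. A usage sketch (assumes the softmax exported from quack/__init__.py takes a 2D (M, N) tensor and normalizes over the last dimension, matching the kernel above; the torch line is a reference only):

    import torch
    from quack import softmax

    x = torch.randn(4096, 16384, device="cuda", dtype=torch.bfloat16)
    y = softmax(x)

    # Reference: numerically stable row softmax, exp(x - max) / sum(exp(x - max))
    ref = torch.softmax(x.float(), dim=-1).to(x.dtype)
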
quack/utils.py CHANGED
@@ -1,5 +1,6 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
+import operator
 import math
 from typing import Type, Callable, Optional
 
@@ -57,7 +58,7 @@ def block_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor,
     """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)
     """
     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-    warps_per_row = reduction_buffer.shape[1]
+    warps_per_row = cute.size(reduction_buffer.shape[1])
     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
     if lane_idx == 0:
         reduction_buffer[row_idx, col_idx] = val
@@ -142,6 +143,46 @@ def block_or_cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: c
     return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val)
 
 
+@cute.jit
+def row_reduce(
+    x: cute.TensorSSA | cute.Numeric,
+    op: cute.ReductionOp,
+    threads_per_row: cutlass.Constexpr[int],
+    reduction_buffer: Optional[cute.Tensor] = None,
+    mbar_ptr: Optional[cute.Pointer] = None,
+    init_val: cute.Numeric = 0.0,
+    hook_fn: Optional[Callable] = None,
+) -> cute.Numeric:
+    """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
+    """
+    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+        val = x.reduce(op, init_val=init_val, reduction_profile=0)
+    else:
+        val = x
+    warp_op = {
+        cute.ReductionOp.ADD: operator.add,
+        cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == cute.Float32) else max,
+        cute.ReductionOp.MIN: min,
+        cute.ReductionOp.MUL: operator.mul,
+    }[op]
+    val = warp_reduce(
+        val,
+        warp_op,
+        width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+    )
+    if cutlass.const_expr(hook_fn is not None):
+        hook_fn()
+    if cutlass.const_expr(reduction_buffer is not None):
+        warps_per_row, cluster_n = reduction_buffer.shape[1]
+        assert cluster_n == 1 or mbar_ptr is not None, "mbar_ptr must be provided for cluster reduction"
+        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+            val = block_or_cluster_reduce(
+                val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
+            )
+    return val
+
+
+
 def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float32:
     """exp2f calculation for both vector and scalar.
 
@@ -188,3 +229,18 @@ def rsqrt(a: float | cute.Float32, *, loc=None, ip=None) -> cute.Float32:
             asm_dialect=llvm.AsmDialect.AD_ATT,
         )
     )
+
+
+def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
+    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
+    tApA = cute.make_fragment(
+        cute.make_layout(
+            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
+            stride=(cute.size(tAcA, mode=[2]), 0, 1),
+        ),
+        cutlass.Boolean,
+    )
+    for rest_v in range(tApA.shape[0]):
+        for rest_k in range(tApA.shape[2]):
+            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
+    return tApA
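
The softmax and cross-entropy kernels above now fill out-of-bounds columns with -inf before reducing when N is not a multiple of the tile width. The reason is arithmetic rather than CuTe-specific: -inf is the identity for a max reduction, and exp(-inf) = 0 contributes nothing to the sum, so a row padded with -inf has the same max and log-sum-exp as the exact row. A small NumPy sketch of that argument (illustrative only, independent of the CuTe code path):

    import numpy as np

    N, tile = 1000, 128
    padded = N + (-N) % tile          # round N up to a whole number of tiles
    x = np.random.randn(N).astype(np.float32)
    x_pad = np.full(padded, -np.inf, dtype=np.float32)
    x_pad[:N] = x

    def lse(v):
        m = v.max()
        return m + np.log(np.exp(v - m).sum())

    assert np.isclose(lse(x), lse(x_pad))  # padding with -inf changes neither max nor logsumexp
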
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.1.1
+Version: 0.1.2
 Requires-Python: >=3.9
 License-File: LICENSE
 Requires-Dist: nvidia-cutlass-dsl==4.0.0
@@ -0,0 +1,10 @@
+quack/__init__.py,sha256=Nf01m1CGrOjSkqGJom6P65hSLkckljRMhlkSoqqlO9k,137
+quack/cross_entropy.py,sha256=gdo8sR9KT5TsrShbgAmy-bwRZLu0gTs_ykXBF2RMbFI,8900
+quack/rmsnorm.py,sha256=JhwJSAPDDpB_hV90xU9ymiLU-zu4WScrSHc5JX2JarY,10470
+quack/softmax.py,sha256=C8e8ZNaF5ePJ1NlrWZN1goCcvsx1C60FWlRyuFCcYoM,7737
+quack/utils.py,sha256=PRdu-P7azA_PeHUNdtoy1zyxZwg_QyVrSiVwE1iXaWo,8961
+quack_kernels-0.1.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.1.2.dist-info/METADATA,sha256=3WjugLu1IhLlgsg2qUcLBZq1HI4-BIyyJIuQc5Hk-rU,186
+quack_kernels-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.1.2.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.1.2.dist-info/RECORD,,
@@ -1,10 +0,0 @@
-quack/__init__.py,sha256=y3Oa4OVPqaGU_P1miI435DzfpMgIwKVmU8-Eogv58jQ,137
-quack/cross_entropy.py,sha256=V0kG8DCNh2735sPIDwe68NB50rAqDF3XQApnGyo-sKg,9220
-quack/rmsnorm.py,sha256=RNqcT-q4uvMbF6ejpzuqQH8l8VVuTRlnueXf28V47sc,11954
-quack/softmax.py,sha256=QABgOESH5JjDm3yuUkyZZKXXpzn7CTuMSs0NEBnFD80,8536
-quack/utils.py,sha256=ofV7QLDuq80h3nEA3TwZW-ti8CnYwMgnz1dpxpvhHpk,6859
-quack_kernels-0.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.1.1.dist-info/METADATA,sha256=XG3zS0_q48TzkoR7CemzaJGVYHS731yVOrzH49_uRK8,186
-quack_kernels-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.1.1.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.1.1.dist-info/RECORD,,