quack-kernels 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,5 +1,11 @@
- __version__ = "0.1.1"
+ __version__ = "0.1.3"
 
  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
  from quack.cross_entropy import cross_entropy
+
+ __all__ = [
+     "rmsnorm",
+     "softmax",
+     "cross_entropy",
+ ]
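
Note: the new `__all__` makes these three kernels the package's explicit public surface. A minimal usage sketch (the shapes, dtypes, and the rmsnorm/softmax call signatures below are illustrative assumptions, not taken from this diff):

    import torch
    from quack import rmsnorm, softmax, cross_entropy

    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
    tgt = torch.randint(0, 4096, (8,), device="cuda")

    y = rmsnorm(x, w)             # assumed signature: (input, weight)
    p = softmax(x)                # assumed: softmax over the last dimension
    loss = cross_entropy(x, tgt)  # per-row loss; see quack/cross_entropy.py below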
quack/cross_entropy.py CHANGED
@@ -1,7 +1,6 @@
  import math
  import torch
- import operator
- from typing import Callable, Union
+ from typing import Optional, Type
 
  import cuda.bindings.driver as cuda
 
@@ -10,177 +9,198 @@ import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack
 
  import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
 
- @cute.kernel
- def cross_entropy_kernel(
-     mX: cute.Tensor, # (M, N)
-     mTarget: cute.Tensor, # (M,)
-     mLoss: cute.Tensor, # (M,)
-     tv_layout: cute.Layout,
-     tiler_mn: cute.Shape,
-     cluster_n: cutlass.Constexpr = 1,
- ):
-     tidx, _, _ = cute.arch.thread_idx()
-     bidx, cluster_y, _ = cute.arch.block_idx()
-     gdim, _, _ = cute.arch.grid_dim()
-
-     shape: cute.Shape = mX.shape
-     idX = cute.make_identity_tensor(mX.shape)
-     gX, cX = [cute.zipped_divide(mT, tiler_mn) for mT in (mX, idX)]
-     blkX, blkCrd = [gT[(None, None), bidx if cluster_n == 1 else (bidx, cluster_y)] for gT in (gX, cX)]
-
-     # declare the atoms which will be used later for memory copy
-     copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
-     copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-     copy_atom_scalar = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=gX.element_type.width)
-
-     thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-
-     smem = cutlass.utils.SmemAllocator()
-
-     # Don't use blkX.layout here, because the stride is N, not N_rounded
-     sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
-     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-
-     reduction_buffer_layout = cute.make_ordered_layout(
+ class CrossEntropy(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
          # 2 stages: 1 for max, 1 for sum
-         (num_warps // warps_per_row, warps_per_row if cluster_n == 1 else (warps_per_row, cluster_n), 2),
-         order=(1, 0, 2)
-     )
-     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
-     if cutlass.const_expr(cluster_n > 1):
-         # 1 mbar for max reduction, 1 mbar for sum reduction
-         mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=2)
-     else:
-         mbar_ptr = None
-
-     #### Thread View
-     tXgX = thr_copy_X_async.partition_S(blkX)
-     tXsX = thr_copy_X_async.partition_S(sX)
-
-     tXcX = thr_copy_X.partition_S(blkCrd)[(0, None), None, None]
-
-     # allocate fragments for gmem->rmem
-     tXrX = cute.make_fragment_like(tXgX) # only logits fragment needed
-
-     if cluster_n > 1:
-         if tidx < 2:
-             cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
-         cute.arch.mbarrier_init_fence()
-         if tidx < 2:
-             cute.arch.mbarrier_init_tx_bytes(mbar_ptr + tidx, num_warps * cluster_n * cutlass.Float32.width // 8)
-         # Cluster arrive after barrier init
-         cute.arch.cluster_arrive_relaxed()
-
-     row = tXcX[0][0]
-     target = cute.Int32.zero
-     if row < shape[0] and tXcX[0][1] == 0:
-         target = cute.Int32(mTarget[row])
-
-     tXpX = cute.make_fragment_like(tXgX[(0, None), None, None], cutlass.Boolean)
-     for i in range(cute.size(tXpX)):
-         tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
-     if row < shape[0]:
-         cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
-     cute.arch.cp_async_commit_group()
-     cute.arch.cp_async_wait_group(0)
-     cute.autovec_copy(tXsX, tXrX)
-     x = tXrX.load().to(cute.Float32)
-
-     target_logit = cute.Float32.zero
-     if row < shape[0] and tXcX[0][1] == 0:
-         target_logit = cute.Float32(mX[row, target])
-
-     max_x = utils.warp_reduce(
-         x.reduce(cute.ReductionOp.MAX, init_val=float('-inf'), reduction_profile=0),
-         cute.arch.fmax,
-         width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
-     )
-     if cutlass.const_expr(cluster_n > 1):
-         cute.arch.cluster_wait()
-     if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-         max_mbar_ptr = mbar_ptr + 0 if cluster_n > 1 else None
-         max_x = utils.block_or_cluster_reduce(
-             max_x, cute.arch.fmax, reduction_buffer[None, None, 0], max_mbar_ptr, init_val=-cutlass.Float32.inf
+         super().__init__(
+             dtype,
+             N,
+             stage=2 if not online_softmax else 1,
+             reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
          )
-     log2_e = math.log2(math.e)
-     # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-     exp_x = utils.exp2f((x - max_x) * log2_e) # a bit faster, idk why
-     denom = utils.warp_reduce(
-         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
-         operator.add,
-         width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
-     )
-     if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-         sum_mbar_ptr = mbar_ptr + 1 if cluster_n > 1 else None
-         denom = utils.block_or_cluster_reduce(
-             denom, operator.add, reduction_buffer[None, None, 1], sum_mbar_ptr, init_val=0.0
+         self.online_softmax = online_softmax
+         self.reload_from = None if N <= 16384 or online_softmax else "smem"
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
          )
 
-     if tXcX[0][1] == 0 and row < shape[0]:
-         ln_2 = math.log(2.0)
-         loss_val = -target_logit + max_x + utils.log2f(denom) * ln_2
-         if cutlass.const_expr(cluster_n == 1):
-             mLoss[row] = loss_val.to(mLoss.element_type)
-         else:
-             if cute.arch.block_idx_in_cluster() == 0:
-                 mLoss[row] = loss_val.to(mLoss.element_type)
-
-
- @cute.jit
- def cross_entropy_interface(
-     mX: cute.Tensor,
-     mTarget: cute.Tensor,
-     mLoss: cute.Tensor,
-     stream: cuda.CUstream,
-     N: cutlass.Constexpr,
-     copy_bits: cutlass.Constexpr = 128
- ):
-     vecsize = copy_bits // mX.element_type.width
-     assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-     num_threads = 128 if N <= 16384 else 256
-
-     num_warps = num_threads // cute.arch.WARP_SIZE
-     assert num_threads % cute.arch.WARP_SIZE == 0
-     threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-
-     if cutlass.const_expr(mX.element_type.width == 16):
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-     else: # fp32
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else 8))
-
-     num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-     cols_per_block = num_threads // threads_per_row
-     tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) # This rounds up N
-     tv_layout = cute.make_layout(
-         ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-         stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-     )
+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else: # fp32
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mTarget: cute.Tensor,
+         mLoss: cute.Tensor,
+         mLSE: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )
 
-     smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
-     cross_entropy_kernel(mX, mTarget, mLoss, tv_layout, tiler_mn, cluster_n).launch(
-         grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
-         block=[cute.size(tv_layout, mode=[0]), 1, 1],
-         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
-         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-         smem=smem_allocated,
-         stream=stream,
-     )
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor, # (M, N)
+         mTarget: cute.Tensor, # (M,)
+         mLoss: cute.Tensor, # (M,)
+         mLSE: Optional[cute.Tensor], # (M,)
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, cluster_y, _ = cute.arch.block_idx()
+
+         shape: cute.Shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, cX = [
+             cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+             for mT in (mX, idX)
+         ]
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+         )
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+
+         #### Thread View
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+         tXrX = cute.make_fragment_like(tXgX)
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
 
+         row = tXcX[0][0]
+         target = cute.Int32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target = cute.Int32(mTarget[row])
 
- torch2cute_dtype_map = {
-     torch.float16: cutlass.Float16,
-     torch.bfloat16: cutlass.BFloat16,
-     torch.float32: cutlass.Float32,
- }
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+         )
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         # Fill OOB values with -inf
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+
+         target_logit = cute.Float32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target_logit = cute.Float32(mX[row, target])
+
+         threads_per_row = tv_layout.shape[0][0]
+         if cutlass.const_expr(not self.online_softmax):
+             max_x = utils.row_reduce(
+                 x,
+                 cute.ReductionOp.MAX,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr + 0 if self.cluster_n > 1 else None,
+                 init_val=-cutlass.Float32.inf,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+             if cutlass.const_expr(self.reload_from == "smem"):
+                 cute.autovec_copy(tXsX, tXrX)
+                 x = tXrX.load().to(cute.Float32)
+             log2_e = math.log2(math.e)
+             # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+             # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
+             # exp_x = utils.exp2f((x - max_x) * log2_e)
+             # This would use ffma instead of fadd then fmul
+             exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+             denom = utils.row_reduce(
+                 exp_x,
+                 cute.ReductionOp.ADD,
+                 threads_per_row,
+                 reduction_buffer[None, None, 1],
+                 mbar_ptr + 1 if self.cluster_n > 1 else None,
+                 init_val=0.0,
+             )
+         else:
+             max_x, denom, _ = utils.online_softmax_reduce(
+                 x,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+
+         if (
+             tXcX[0][1] == 0
+             and row < shape[0]
+             and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+         ):
+             ln_2 = math.log(2.0)
+             lse = max_x + utils.log2f(denom) * ln_2
+             loss_val = lse - target_logit
+             mLoss[row] = loss_val.to(mLoss.element_type)
+             if cutlass.const_expr(mLSE is not None):
+                 mLSE[row] = lse
 
 
  def cross_entropy(
      x: torch.Tensor,
      target: torch.Tensor,
+     return_lse: bool = False,
  ) -> torch.Tensor:
      """Cross entropy forward pass.
 
@@ -196,26 +216,37 @@ def cross_entropy(
      assert x.shape[0] == target.shape[0], "Batch dimensions must match"
      assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
      assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
-     assert target.dtype == torch.int64, "Target must be int64"
+     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
      M, N = x.shape
      device = x.device
-     loss = torch.empty(M, device=device, dtype=x.dtype)
+     loss = torch.empty(M, device=device, dtype=torch.float32)
+     lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
      dtype = torch2cute_dtype_map[x.dtype]
      convert_from_dlpack = lambda tensor: (
-         from_dlpack(tensor.detach(), assumed_align=16)
-         .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
      )
-     x_tensor, = [convert_from_dlpack(tensor) for tensor in (x,)]
+     x_tensor = convert_from_dlpack(x)
      loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+     lse_tensor = (
+         from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+         if lse is not None
+         else None
+     )
      target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
      stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-     compile_key = (dtype, N)
+
+     compile_key = (dtype, N, lse is not None)
      if compile_key not in cross_entropy.compile_cache:
+         cross_entropy_op = CrossEntropy(dtype, N)
          cross_entropy.compile_cache[compile_key] = cute.compile(
-             cross_entropy_interface, x_tensor, target_tensor, loss_tensor, stream, N
+             cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
          )
-     cross_entropy.compile_cache[compile_key](x_tensor, target_tensor, loss_tensor, stream)
-     return loss
+     cross_entropy.compile_cache[compile_key](
+         x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+     )
+     return loss if not return_lse else (loss, lse)
 
 
  cross_entropy.compile_cache = {}
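
Note: the rewritten kernel computes the per-row loss through the log-sum-exp, matching the code above: with m_i = \max_j x_{ij}, lse_i = m_i + \log \sum_j \exp(x_{ij} - m_i) and loss_i = lse_i - x_{i, y_i}; `mLSE` optionally stores lse_i. A sketch of the updated host API (the diff confirms the float32 loss, int32/int64 targets, and `return_lse`; the shapes below are illustrative assumptions):

    import torch
    from quack import cross_entropy

    logits = torch.randn(32, 32768, device="cuda", dtype=torch.bfloat16)
    target = torch.randint(0, 32768, (32,), device="cuda", dtype=torch.int32)

    loss = cross_entropy(logits, target)                        # (32,) float32
    loss, lse = cross_entropy(logits, target, return_lse=True)  # both (32,) float32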
quack/reduction_base.py ADDED
@@ -0,0 +1,98 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ import torch
+ from typing import Type, Tuple, Optional
+
+ import cutlass
+ import cutlass.cute as cute
+
+ import quack.utils as utils
+
+
+ torch2cute_dtype_map = {
+     torch.float16: cutlass.Float16,
+     torch.bfloat16: cutlass.BFloat16,
+     torch.float32: cutlass.Float32,
+ }
+
+
+ class ReductionBase:
+     def __init__(
+         self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
+     ):
+         self.dtype = dtype
+         self.N = N
+         self.stage = stage
+         self.reduction_dtype = reduction_dtype
+
+     def _calculate_threads_per_row(self):
+         raise NotImplementedError()
+
+     def _set_cluster_n(self):
+         self.cluster_n = 1
+
+     def _get_num_threads(self):
+         return 128 if self.N <= 16384 else 256
+
+     def _get_tv_layout(self):
+         copy_bits = 128
+         vecsize = copy_bits // self.dtype.width
+         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+         num_threads = self._get_num_threads()
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         assert num_threads % cute.arch.WARP_SIZE == 0
+
+         threads_per_row = self._calculate_threads_per_row()
+         num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
+         cols_per_block = num_threads // threads_per_row
+         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+         tv_layout = cute.make_layout(
+             ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+             stride=(
+                 (vecsize * cols_per_block, 1),
+                 (cols_per_block, cols_per_block * vecsize * threads_per_row),
+             ),
+         )
+         return tiler_mn, tv_layout
+
+     def _smem_size_in_bytes(self, tiler_mn, num_warps):
+         return (
+             cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+             + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+             + self.stage * (cutlass.Int64.width // 8)
+         )
+
+     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+         return cute.make_ordered_layout(
+             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+             order=(1, 0, 2),
+         )
+
+     def _allocate_reduction_buffer_and_mbar(
+         self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+     ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
+         reduction_buffer = smem.allocate_tensor(
+             self.reduction_dtype,
+             self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
+             byte_alignment=4,
+         )
+         if cutlass.const_expr(self.cluster_n > 1):
+             mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+         else:
+             mbar_ptr = None
+         return reduction_buffer, mbar_ptr
+
+     @cute.jit
+     def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
+         if cutlass.const_expr(self.cluster_n > 1):
+             if tidx < self.stage:
+                 cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
+             cute.arch.mbarrier_init_fence()
+             if tidx < self.stage:
+                 cute.arch.mbarrier_init_tx_bytes(
+                     mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
+                 )
+             # Cluster arrive after barrier init
+             cute.arch.cluster_arrive_relaxed()
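
Note: `ReductionBase` factors out the tile/TV-layout construction, shared-memory sizing, and cluster mbarrier setup that `CrossEntropy` builds on. A minimal sketch of the subclass contract, based only on the hooks visible above; the class name, the threads-per-row heuristic, and the elided kernel body are placeholders, not part of the package:

    import cutlass
    import cutlass.cute as cute
    from quack.reduction_base import ReductionBase

    class MyRowReduction(ReductionBase):  # hypothetical subclass
        def __init__(self, dtype, N: int):
            # one reduction stage, accumulated in fp32 shared-memory buffers
            super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)

        def _calculate_threads_per_row(self):
            # required override: map N to threads per row (heuristic assumed here)
            return 32 if self.N <= 4096 else 128

        def __call__(self, mX: cute.Tensor, mOut: cute.Tensor, stream):
            # the real kernels decorate this with @cute.jit
            self._set_cluster_n()                        # base default: cluster_n = 1
            tiler_mn, tv_layout = self._get_tv_layout()
            num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
            smem_bytes = self._smem_size_in_bytes(tiler_mn, num_warps)
            # a @cute.kernel method would be launched here with smem=smem_bytes,
            # as CrossEntropy.__call__ does above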