quack-kernels 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,5 +1,11 @@
- __version__ = "0.1.2"
+ __version__ = "0.1.4"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
  from quack.cross_entropy import cross_entropy
+
+ __all__ = [
+     "rmsnorm",
+     "softmax",
+     "cross_entropy",
+ ]
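
A quick sanity check of the new public surface (an illustrative snippet, not part of the package):

    import quack

    print(quack.__version__)        # "0.1.4"
    print(sorted(quack.__all__))    # ['cross_entropy', 'rmsnorm', 'softmax']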
quack/cross_entropy.py CHANGED
@@ -1,7 +1,6 @@
  import math
  import torch
- import operator
- from typing import Callable, Union, Optional
+ from typing import Optional, Type

  import cuda.bindings.driver as cuda

@@ -10,169 +9,195 @@ import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack

  import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map


- @cute.kernel
- def cross_entropy_kernel(
-     mX: cute.Tensor,  # (M, N)
-     mTarget: cute.Tensor,  # (M,)
-     mLoss: cute.Tensor,  # (M,)
-     mLSE: Optional[cute.Tensor],  # (M,)
-     tv_layout: cute.Layout,
-     tiler_mn: cute.Shape,
-     cluster_n: cutlass.Constexpr = 1,
- ):
-     tidx, _, _ = cute.arch.thread_idx()
-     bidx, cluster_y, _ = cute.arch.block_idx()
-
-     shape: cute.Shape = mX.shape
-     idX = cute.make_identity_tensor(shape)
-     # slice for CTAs
-     gX, cX = [
-         cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
-         for mT in (mX, idX)
-     ]
-
-     smem = cutlass.utils.SmemAllocator()
-     sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
-     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-     reduction_buffer_layout = cute.make_ordered_layout(
+ class CrossEntropy(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
          # 2 stages: 1 for max, 1 for sum
-         (num_warps // warps_per_row, (warps_per_row, cluster_n), 2),
-         order=(1, 0, 2)
-     )
-     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
-     if cutlass.const_expr(cluster_n > 1):
-         # 1 mbar for max reduction, 1 mbar for sum reduction
-         mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=2)
-     else:
-         mbar_ptr = None
-
-     # declare the atoms which will be used later for memory copy
-     copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-     thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-
-     #### Thread View
-     tXgX = thr_copy_X.partition_S(gX)
-     tXsX = thr_copy_X.partition_D(sX)
-     tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-     tXrX = cute.make_fragment_like(tXgX)
-
-     if cluster_n > 1:
-         if tidx < 2:
-             cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
-         cute.arch.mbarrier_init_fence()
-         if tidx < 2:
-             cute.arch.mbarrier_init_tx_bytes(mbar_ptr + tidx, num_warps * cluster_n * cutlass.Float32.width // 8)
-         # Cluster arrive after barrier init
-         cute.arch.cluster_arrive_relaxed()
-
-     row = tXcX[0][0]
-     target = cute.Int32.zero
-     if row < shape[0] and tXcX[0][1] == 0:
-         target = cute.Int32(mTarget[row])
-
-     is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * cluster_n)
-     tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
-     if row < shape[0]:
-         cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
-     cute.arch.cp_async_commit_group()
-     cute.arch.cp_async_wait_group(0)
-     cute.autovec_copy(tXsX, tXrX)
-     x = tXrX.load().to(cute.Float32)
-     # Fill OOB values with -inf
-     if cutlass.const_expr(not is_even_N):
-         tXrX_fp32 = cute.make_fragment_like(tXrX, cutlass.Float32)
-         tXrX_fp32.store(x)
-         for rest_v in range(tXpX.shape[0]):
-             for rest_k in range(tXpX.shape[2]):
-                 if not tXpX[rest_v, 0, rest_k]:
-                     tXrX_fp32[(None, rest_v), None, rest_k].fill(-cutlass.Float32.inf)
-         x = tXrX_fp32.load()
-
-     target_logit = cute.Float32.zero
-     if row < shape[0] and tXcX[0][1] == 0:
-         target_logit = cute.Float32(mX[row, target])
-
-     threads_per_row = tv_layout.shape[0][0]
-     max_x = utils.row_reduce(
-         x,
-         cute.ReductionOp.MAX,
-         threads_per_row,
-         reduction_buffer[None, None, 0],
-         mbar_ptr + 0 if cluster_n > 1 else None,
-         init_val=-cutlass.Float32.inf,
-         hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
-     )
-     log2_e = math.log2(math.e)
-     # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-     exp_x = utils.exp2f((x - max_x) * log2_e)  # a bit faster, idk why
-     denom = utils.row_reduce(
-         exp_x,
-         cute.ReductionOp.ADD,
-         threads_per_row,
-         reduction_buffer[None, None, 1],
-         mbar_ptr + 1 if cluster_n > 1 else None,
-         init_val=0.0,
-     )
-
-     if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
-         ln_2 = math.log(2.0)
-         lse = max_x + utils.log2f(denom) * ln_2
-         loss_val = lse - target_logit
-         mLoss[row] = loss_val.to(mLoss.element_type)
-         if cutlass.const_expr(mLSE is not None):
-             mLSE[row] = lse
-
-
- @cute.jit
- def cross_entropy_interface(
-     mX: cute.Tensor,
-     mTarget: cute.Tensor,
-     mLoss: cute.Tensor,
-     mLSE: Optional[cute.Tensor],
-     stream: cuda.CUstream,
-     N: cutlass.Constexpr,
-     copy_bits: cutlass.Constexpr = 128
- ):
-     vecsize = copy_bits // mX.element_type.width
-     assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-     num_threads = 128 if N <= 16384 else 256
-
-     num_warps = num_threads // cute.arch.WARP_SIZE
-     assert num_threads % cute.arch.WARP_SIZE == 0
-     threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-
-     if cutlass.const_expr(mX.element_type.width == 16):
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-     else:  # fp32
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else 8))
-
-     num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-     cols_per_block = num_threads // threads_per_row
-     tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)  # This rounds up N
-     tv_layout = cute.make_layout(
-         ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-         stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-     )
+         super().__init__(
+             dtype,
+             N,
+             stage=2 if not online_softmax else 1,
+             reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+         )
+         self.online_softmax = online_softmax
+         self.reload_from = None if N <= 16384 or online_softmax else "smem"
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )

-     smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
-     cross_entropy_kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn, cluster_n).launch(
-         grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
-         block=[cute.size(tv_layout, mode=[0]), 1, 1],
-         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
-         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-         smem=smem_allocated,
-         stream=stream,
-     )
+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else:  # fp32
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mTarget: cute.Tensor,
+         mLoss: cute.Tensor,
+         mLSE: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )

+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,  # (M, N)
+         mTarget: cute.Tensor,  # (M,)
+         mLoss: cute.Tensor,  # (M,)
+         mLSE: Optional[cute.Tensor],  # (M,)
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, _, _ = cute.arch.block_idx()
+         if cutlass.const_expr(self.cluster_n > 1):
+             cluster_y = cute.arch.block_idx()[1]
+         else:
+             cluster_y = cutlass.const_expr(0)
+
+         shape: cute.Shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

- torch2cute_dtype_map = {
-     torch.float16: cutlass.Float16,
-     torch.bfloat16: cutlass.BFloat16,
-     torch.float32: cutlass.Float32,
- }
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+         )
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+
+         #### Thread View
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+         tXrX = cute.make_fragment_like(tXgX)
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         row = tXcX[0][0]
+         target = cute.Int32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target = cute.Int32(mTarget[row])
+
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+             if cutlass.const_expr(not is_even_N)
+             else None
+         )
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         # Fill OOB values with -inf
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+
+         target_logit = cute.Float32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target_logit = cute.Float32(mX[row, target])
+
+         threads_per_row = tv_layout.shape[0][0]
+         if cutlass.const_expr(not self.online_softmax):
+             max_x = utils.row_reduce(
+                 x,
+                 cute.ReductionOp.MAX,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                 init_val=-cutlass.Float32.inf,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+             if cutlass.const_expr(self.reload_from == "smem"):
+                 cute.autovec_copy(tXsX, tXrX)
+                 x = tXrX.load().to(cute.Float32)
+             log2_e = math.log2(math.e)
+             # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+             # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
+             # exp_x = utils.exp2f((x - max_x) * log2_e)
+             # This would use ffma instead of fadd then fmul
+             exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+             denom = utils.row_reduce(
+                 exp_x,
+                 cute.ReductionOp.ADD,
+                 threads_per_row,
+                 reduction_buffer[None, None, 1],
+                 mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+                 init_val=0.0,
+             )
+         else:
+             max_x, denom, _ = utils.online_softmax_reduce(
+                 x,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+
+         if (
+             tXcX[0][1] == 0
+             and row < shape[0]
+             and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+         ):
+             ln_2 = math.log(2.0)
+             lse = max_x + utils.log2f(denom) * ln_2
+             loss_val = lse - target_logit
+             mLoss[row] = loss_val.to(mLoss.element_type)
+             if cutlass.const_expr(mLSE is not None):
+                 mLSE[row] = lse


  def cross_entropy(
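
Aside: the exp2/log2 rewrites in the kernel hunk above follow from standard identities. A small standalone check (illustrative only, not part of the diff):

    import math

    log2_e = math.log2(math.e)
    # exp(x - m) == 2**((x - m) * log2(e)) == 2**(x*log2_e - m*log2_e); the last form maps to one ffma per element
    x, m = 3.7, 5.0
    assert abs(math.exp(x - m) - 2 ** (x * log2_e - m * log2_e)) < 1e-12
    # the loss path recovers logsumexp as lse = m + log2(denom) * ln(2), with denom = sum(exp(x_i - m))
    xs = [0.5, 1.5, -2.0]
    m = max(xs)
    denom = sum(2 ** ((xi - m) * log2_e) for xi in xs)
    lse = m + math.log2(denom) * math.log(2.0)
    assert abs(lse - math.log(sum(math.exp(xi) for xi in xs))) < 1e-12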
@@ -194,27 +219,36 @@ def cross_entropy(
      assert x.shape[0] == target.shape[0], "Batch dimensions must match"
      assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
      assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
-     assert target.dtype == torch.int64, "Target must be int64"
+     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
      M, N = x.shape
      device = x.device
      loss = torch.empty(M, device=device, dtype=torch.float32)
      lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
      dtype = torch2cute_dtype_map[x.dtype]
      convert_from_dlpack = lambda tensor: (
-         from_dlpack(tensor.detach(), assumed_align=16)
-         .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
      )
-     x_tensor, = [convert_from_dlpack(tensor) for tensor in (x,)]
+     x_tensor = convert_from_dlpack(x)
      loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
-     lse_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0) if lse is not None else None
+     lse_tensor = (
+         from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+         if lse is not None
+         else None
+     )
      target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
      stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-     compile_key = (dtype, N, lse_tensor is not None)
+
+     compile_key = (dtype, N, lse is not None)
      if compile_key not in cross_entropy.compile_cache:
+         cross_entropy_op = CrossEntropy(dtype, N)
          cross_entropy.compile_cache[compile_key] = cute.compile(
-             cross_entropy_interface, x_tensor, target_tensor, loss_tensor, lse_tensor, stream, N
+             cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
          )
-     cross_entropy.compile_cache[compile_key](x_tensor, target_tensor, loss_tensor, lse_tensor, stream)
+     cross_entropy.compile_cache[compile_key](
+         x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+     )
      return loss if not return_lse else (loss, lse)


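For orientation, a hedged usage sketch of the host-side API above (shapes and dtypes are illustrative assumptions; keyword defaults are not visible in this hunk):

    import torch
    from quack.cross_entropy import cross_entropy

    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)             # (M, N) logits
    target = torch.randint(0, 4096, (8,), device="cuda", dtype=torch.int32)   # int32 now accepted alongside int64

    loss = cross_entropy(x, target)                         # float32 per-row losses, shape (M,)
    loss, lse = cross_entropy(x, target, return_lse=True)   # optionally also return the row-wise log-sum-exp
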
quack/reduction_base.py ADDED
@@ -0,0 +1,98 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ import torch
+ from typing import Type, Tuple, Optional
+
+ import cutlass
+ import cutlass.cute as cute
+
+ import quack.utils as utils
+
+
+ torch2cute_dtype_map = {
+     torch.float16: cutlass.Float16,
+     torch.bfloat16: cutlass.BFloat16,
+     torch.float32: cutlass.Float32,
+ }
+
+
+ class ReductionBase:
+     def __init__(
+         self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
+     ):
+         self.dtype = dtype
+         self.N = N
+         self.stage = stage
+         self.reduction_dtype = reduction_dtype
+
+     def _calculate_threads_per_row(self):
+         raise NotImplementedError()
+
+     def _set_cluster_n(self):
+         self.cluster_n = 1
+
+     def _get_num_threads(self):
+         return 128 if self.N <= 16384 else 256
+
+     def _get_tv_layout(self):
+         copy_bits = 128
+         vecsize = copy_bits // self.dtype.width
+         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+         num_threads = self._get_num_threads()
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         assert num_threads % cute.arch.WARP_SIZE == 0
+
+         threads_per_row = self._calculate_threads_per_row()
+         num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
+         cols_per_block = num_threads // threads_per_row
+         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+         tv_layout = cute.make_layout(
+             ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+             stride=(
+                 (vecsize * cols_per_block, 1),
+                 (cols_per_block, cols_per_block * vecsize * threads_per_row),
+             ),
+         )
+         return tiler_mn, tv_layout
+
+     def _smem_size_in_bytes(self, tiler_mn, num_warps):
+         return (
+             cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+             + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+             + self.stage * (cutlass.Int64.width // 8)
+         )
+
+     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+         return cute.make_ordered_layout(
+             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+             order=(1, 0, 2),
+         )
+
+     def _allocate_reduction_buffer_and_mbar(
+         self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+     ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
+         reduction_buffer = smem.allocate_tensor(
+             self.reduction_dtype,
+             self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
+             byte_alignment=4,
+         )
+         if cutlass.const_expr(self.cluster_n > 1):
+             mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+         else:
+             mbar_ptr = None
+         return reduction_buffer, mbar_ptr
+
+     @cute.jit
+     def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
+         if cutlass.const_expr(self.cluster_n > 1):
+             if tidx < self.stage:
+                 cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+             cute.arch.mbarrier_init_fence()
+             if tidx < self.stage:
+                 cute.arch.mbarrier_arrive_and_expect_tx(
+                     mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
+                 )
+             # Cluster arrive after barrier init
+             cute.arch.cluster_arrive_relaxed()
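
The contract the new base class expects from subclasses is small (CrossEntropy above is the real example). A minimal hypothetical subclass, for illustration only:

    import cutlass
    import cutlass.cute as cute
    from quack.reduction_base import ReductionBase


    class RowSumLike(ReductionBase):  # hypothetical name, not part of the package
        def __init__(self, dtype, N):
            super().__init__(dtype, N, stage=1)  # a single reduction stage

        def _calculate_threads_per_row(self):
            # the one required override: how many threads cooperate on a row for this N
            return 32 if self.N <= 4096 else 128

        # _set_cluster_n() can keep the default (cluster_n = 1); __call__/kernel code would then use
        # self._get_tv_layout(), self._smem_size_in_bytes(...) and
        # self._allocate_reduction_buffer_and_mbar(...) exactly as CrossEntropy does above.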