quack-kernels 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the contents of two publicly released versions of this package as they appear in a supported public registry. It is provided for informational purposes only.
quack/__init__.py CHANGED
@@ -1,5 +1,11 @@
- __version__ = "0.1.2"
+ __version__ = "0.1.3"
 
  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
  from quack.cross_entropy import cross_entropy
+
+ __all__ = [
+     "rmsnorm",
+     "softmax",
+     "cross_entropy",
+ ]
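
For reference (illustration only, not part of the diff): the new __all__ pins exactly which names a star-import of the package exposes, and __version__ is bumped to match the wheel.

    # Illustration only: behavior implied by the new __init__.py
    import quack
    print(quack.__version__)      # -> "0.1.3"
    from quack import *           # binds exactly rmsnorm, softmax, cross_entropy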
quack/cross_entropy.py CHANGED
@@ -1,7 +1,6 @@
  import math
  import torch
- import operator
- from typing import Callable, Union, Optional
+ from typing import Optional, Type
 
  import cuda.bindings.driver as cuda
 
@@ -10,169 +9,192 @@ import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack
 
  import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
 
- @cute.kernel
- def cross_entropy_kernel(
-     mX: cute.Tensor, # (M, N)
-     mTarget: cute.Tensor, # (M,)
-     mLoss: cute.Tensor, # (M,)
-     mLSE: Optional[cute.Tensor], # (M,)
-     tv_layout: cute.Layout,
-     tiler_mn: cute.Shape,
-     cluster_n: cutlass.Constexpr = 1,
- ):
-     tidx, _, _ = cute.arch.thread_idx()
-     bidx, cluster_y, _ = cute.arch.block_idx()
-
-     shape: cute.Shape = mX.shape
-     idX = cute.make_identity_tensor(shape)
-     # slice for CTAs
-     gX, cX = [
-         cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
-         for mT in (mX, idX)
-     ]
-
-     smem = cutlass.utils.SmemAllocator()
-     sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
-     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-     reduction_buffer_layout = cute.make_ordered_layout(
+ class CrossEntropy(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
          # 2 stages: 1 for max, 1 for sum
-         (num_warps // warps_per_row, (warps_per_row, cluster_n), 2),
-         order=(1, 0, 2)
-     )
-     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
-     if cutlass.const_expr(cluster_n > 1):
-         # 1 mbar for max reduction, 1 mbar for sum reduction
-         mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=2)
-     else:
-         mbar_ptr = None
-
-     # declare the atoms which will be used later for memory copy
-     copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-     thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-
-     #### Thread View
-     tXgX = thr_copy_X.partition_S(gX)
-     tXsX = thr_copy_X.partition_D(sX)
-     tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-     tXrX = cute.make_fragment_like(tXgX)
-
-     if cluster_n > 1:
-         if tidx < 2:
-             cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
-         cute.arch.mbarrier_init_fence()
-         if tidx < 2:
-             cute.arch.mbarrier_init_tx_bytes(mbar_ptr + tidx, num_warps * cluster_n * cutlass.Float32.width // 8)
-         # Cluster arrive after barrier init
-         cute.arch.cluster_arrive_relaxed()
-
-     row = tXcX[0][0]
-     target = cute.Int32.zero
-     if row < shape[0] and tXcX[0][1] == 0:
-         target = cute.Int32(mTarget[row])
-
-     is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * cluster_n)
-     tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
-     if row < shape[0]:
-         cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
-     cute.arch.cp_async_commit_group()
-     cute.arch.cp_async_wait_group(0)
-     cute.autovec_copy(tXsX, tXrX)
-     x = tXrX.load().to(cute.Float32)
-     # Fill OOB values with -inf
-     if cutlass.const_expr(not is_even_N):
-         tXrX_fp32 = cute.make_fragment_like(tXrX, cutlass.Float32)
-         tXrX_fp32.store(x)
-         for rest_v in range(tXpX.shape[0]):
-             for rest_k in range(tXpX.shape[2]):
-                 if not tXpX[rest_v, 0, rest_k]:
-                     tXrX_fp32[(None, rest_v), None, rest_k].fill(-cutlass.Float32.inf)
-         x = tXrX_fp32.load()
-
-     target_logit = cute.Float32.zero
-     if row < shape[0] and tXcX[0][1] == 0:
-         target_logit = cute.Float32(mX[row, target])
-
-     threads_per_row = tv_layout.shape[0][0]
-     max_x = utils.row_reduce(
-         x,
-         cute.ReductionOp.MAX,
-         threads_per_row,
-         reduction_buffer[None, None, 0],
-         mbar_ptr + 0 if cluster_n > 1 else None,
-         init_val=-cutlass.Float32.inf,
-         hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
-     )
-     log2_e = math.log2(math.e)
-     # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-     exp_x = utils.exp2f((x - max_x) * log2_e) # a bit faster, idk why
-     denom = utils.row_reduce(
-         exp_x,
-         cute.ReductionOp.ADD,
-         threads_per_row,
-         reduction_buffer[None, None, 1],
-         mbar_ptr + 1 if cluster_n > 1 else None,
-         init_val=0.0,
-     )
+         super().__init__(
+             dtype,
+             N,
+             stage=2 if not online_softmax else 1,
+             reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+         )
+         self.online_softmax = online_softmax
+         self.reload_from = None if N <= 16384 or online_softmax else "smem"
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )
 
-     if tXcX[0][1] == 0 and row < shape[0] and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0):
-         ln_2 = math.log(2.0)
-         lse = max_x + utils.log2f(denom) * ln_2
-         loss_val = lse - target_logit
-         mLoss[row] = loss_val.to(mLoss.element_type)
-         if cutlass.const_expr(mLSE is not None):
-             mLSE[row] = lse
-
-
- @cute.jit
- def cross_entropy_interface(
-     mX: cute.Tensor,
-     mTarget: cute.Tensor,
-     mLoss: cute.Tensor,
-     mLSE: Optional[cute.Tensor],
-     stream: cuda.CUstream,
-     N: cutlass.Constexpr,
-     copy_bits: cutlass.Constexpr = 128
- ):
-     vecsize = copy_bits // mX.element_type.width
-     assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-     num_threads = 128 if N <= 16384 else 256
-
-     num_warps = num_threads // cute.arch.WARP_SIZE
-     assert num_threads % cute.arch.WARP_SIZE == 0
-     threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-
-     if cutlass.const_expr(mX.element_type.width == 16):
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-     else: # fp32
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else 8))
-
-     num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-     cols_per_block = num_threads // threads_per_row
-     tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) # This rounds up N
-     tv_layout = cute.make_layout(
-         ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-         stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-     )
+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else: # fp32
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mTarget: cute.Tensor,
+         mLoss: cute.Tensor,
+         mLSE: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )
 
-     smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
-     cross_entropy_kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn, cluster_n).launch(
-         grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
-         block=[cute.size(tv_layout, mode=[0]), 1, 1],
-         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
-         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-         smem=smem_allocated,
-         stream=stream,
-     )
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor, # (M, N)
+         mTarget: cute.Tensor, # (M,)
+         mLoss: cute.Tensor, # (M,)
+         mLSE: Optional[cute.Tensor], # (M,)
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, cluster_y, _ = cute.arch.block_idx()
+
+         shape: cute.Shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, cX = [
+             cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+             for mT in (mX, idX)
+         ]
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+         )
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+
+         #### Thread View
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+         tXrX = cute.make_fragment_like(tXgX)
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
 
- torch2cute_dtype_map = {
-     torch.float16: cutlass.Float16,
-     torch.bfloat16: cutlass.BFloat16,
-     torch.float32: cutlass.Float32,
- }
+         row = tXcX[0][0]
+         target = cute.Int32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target = cute.Int32(mTarget[row])
+
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+         )
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         # Fill OOB values with -inf
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+
+         target_logit = cute.Float32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target_logit = cute.Float32(mX[row, target])
+
+         threads_per_row = tv_layout.shape[0][0]
+         if cutlass.const_expr(not self.online_softmax):
+             max_x = utils.row_reduce(
+                 x,
+                 cute.ReductionOp.MAX,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr + 0 if self.cluster_n > 1 else None,
+                 init_val=-cutlass.Float32.inf,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+             if cutlass.const_expr(self.reload_from == "smem"):
+                 cute.autovec_copy(tXsX, tXrX)
+                 x = tXrX.load().to(cute.Float32)
+             log2_e = math.log2(math.e)
+             # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+             # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
+             # exp_x = utils.exp2f((x - max_x) * log2_e)
+             # This would use ffma instead of fadd then fmul
+             exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+             denom = utils.row_reduce(
+                 exp_x,
+                 cute.ReductionOp.ADD,
+                 threads_per_row,
+                 reduction_buffer[None, None, 1],
+                 mbar_ptr + 1 if self.cluster_n > 1 else None,
+                 init_val=0.0,
+             )
+         else:
+             max_x, denom, _ = utils.online_softmax_reduce(
+                 x,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+
+         if (
+             tXcX[0][1] == 0
+             and row < shape[0]
+             and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+         ):
+             ln_2 = math.log(2.0)
+             lse = max_x + utils.log2f(denom) * ln_2
+             loss_val = lse - target_logit
+             mLoss[row] = loss_val.to(mLoss.element_type)
+             if cutlass.const_expr(mLSE is not None):
+                 mLSE[row] = lse
 
 
  def cross_entropy(
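
Note on the new online-softmax path (not part of the diff): CrossEntropy now defaults to online_softmax=True, which replaces the two-stage max-then-sum reduction with a single pass through utils.online_softmax_reduce; the switch to reduction_dtype=cutlass.Int64 apparently lets the running (max, partial sum) pair travel through one 64-bit reduction slot. The recurrence below is the standard one-pass formulation such a helper would implement, written as a plain-Python sketch for illustration only; it is not the library's code.

    import math

    def online_merge(m, s, block):
        # Merge a new block of values into the running (max, sum-of-exp) pair.
        m_new = max(m, max(block))
        # Rescale the old partial sum to the new max, then add the new block's terms.
        s_new = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in block)
        return m_new, s_new

With (m, s) accumulated over the whole row, lse = m + log(s) and the per-row loss is lse - x[target], which matches the epilogue in the kernel above (there the log is computed via log2f and scaled by ln 2).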
@@ -194,27 +216,36 @@ def cross_entropy(
      assert x.shape[0] == target.shape[0], "Batch dimensions must match"
      assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
      assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
-     assert target.dtype == torch.int64, "Target must be int64"
+     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
      M, N = x.shape
      device = x.device
      loss = torch.empty(M, device=device, dtype=torch.float32)
      lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
      dtype = torch2cute_dtype_map[x.dtype]
      convert_from_dlpack = lambda tensor: (
-         from_dlpack(tensor.detach(), assumed_align=16)
-         .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
      )
-     x_tensor, = [convert_from_dlpack(tensor) for tensor in (x,)]
+     x_tensor = convert_from_dlpack(x)
      loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
-     lse_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0) if lse is not None else None
+     lse_tensor = (
+         from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+         if lse is not None
+         else None
+     )
      target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
      stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-     compile_key = (dtype, N, lse_tensor is not None)
+
+     compile_key = (dtype, N, lse is not None)
      if compile_key not in cross_entropy.compile_cache:
+         cross_entropy_op = CrossEntropy(dtype, N)
          cross_entropy.compile_cache[compile_key] = cute.compile(
-             cross_entropy_interface, x_tensor, target_tensor, loss_tensor, lse_tensor, stream, N
+             cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
          )
-     cross_entropy.compile_cache[compile_key](x_tensor, target_tensor, loss_tensor, lse_tensor, stream)
+     cross_entropy.compile_cache[compile_key](
+         x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+     )
      return loss if not return_lse else (loss, lse)
 
 
quack/reduction_base.py ADDED
@@ -0,0 +1,98 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ import torch
+ from typing import Type, Tuple, Optional
+
+ import cutlass
+ import cutlass.cute as cute
+
+ import quack.utils as utils
+
+
+ torch2cute_dtype_map = {
+     torch.float16: cutlass.Float16,
+     torch.bfloat16: cutlass.BFloat16,
+     torch.float32: cutlass.Float32,
+ }
+
+
+ class ReductionBase:
+     def __init__(
+         self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
+     ):
+         self.dtype = dtype
+         self.N = N
+         self.stage = stage
+         self.reduction_dtype = reduction_dtype
+
+     def _calculate_threads_per_row(self):
+         raise NotImplementedError()
+
+     def _set_cluster_n(self):
+         self.cluster_n = 1
+
+     def _get_num_threads(self):
+         return 128 if self.N <= 16384 else 256
+
+     def _get_tv_layout(self):
+         copy_bits = 128
+         vecsize = copy_bits // self.dtype.width
+         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+         num_threads = self._get_num_threads()
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         assert num_threads % cute.arch.WARP_SIZE == 0
+
+         threads_per_row = self._calculate_threads_per_row()
+         num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
+         cols_per_block = num_threads // threads_per_row
+         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+         tv_layout = cute.make_layout(
+             ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+             stride=(
+                 (vecsize * cols_per_block, 1),
+                 (cols_per_block, cols_per_block * vecsize * threads_per_row),
+             ),
+         )
+         return tiler_mn, tv_layout
+
+     def _smem_size_in_bytes(self, tiler_mn, num_warps):
+         return (
+             cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+             + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+             + self.stage * (cutlass.Int64.width // 8)
+         )
+
+     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+         return cute.make_ordered_layout(
+             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+             order=(1, 0, 2),
+         )
+
+     def _allocate_reduction_buffer_and_mbar(
+         self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+     ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
+         reduction_buffer = smem.allocate_tensor(
+             self.reduction_dtype,
+             self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
+             byte_alignment=4,
+         )
+         if cutlass.const_expr(self.cluster_n > 1):
+             mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+         else:
+             mbar_ptr = None
+         return reduction_buffer, mbar_ptr
+
+     @cute.jit
+     def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
+         if cutlass.const_expr(self.cluster_n > 1):
+             if tidx < self.stage:
+                 cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
+             cute.arch.mbarrier_init_fence()
+             if tidx < self.stage:
+                 cute.arch.mbarrier_init_tx_bytes(
+                     mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
+                 )
+             # Cluster arrive after barrier init
+             cute.arch.cluster_arrive_relaxed()
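
The new base class factors out the launch bookkeeping shared by row-wise reduction kernels (TV layout, smem sizing, reduction buffer and mbarrier allocation, cluster init). A hypothetical subclass, sketched below for illustration only (RowSum is not part of the package; it assumes it lives alongside ReductionBase in quack/reduction_base.py), only has to choose a thread layout and, optionally, a cluster size:

    class RowSum(ReductionBase):
        def __init__(self, dtype, N):
            super().__init__(dtype, N, stage=1)   # one reduction stage: a single row sum

        def _calculate_threads_per_row(self):
            # Mirror the pattern CrossEntropy uses: wider rows get more threads per row.
            return 32 if self.N <= 4096 else 128

        # _set_cluster_n() is inherited and defaults to cluster_n = 1 (one CTA per row tile);
        # a kernel body would then call self._get_tv_layout(), self._smem_size_in_bytes(...),
        # self._allocate_reduction_buffer_and_mbar(...) and self._initialize_cluster(...)
        # in the same way CrossEntropy.kernel does above.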