quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
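
Module renames in this release (gemm_act_sm90.py → gemm_act.py, gemm_dact_sm90.py → gemm_dact.py, dense_gemm_sm90.py → gemm_sm90.py; layernorm.py and symmetric_dense_gemm_sm90.py removed) imply import-path updates in downstream code. A rough sketch of the kind of change involved; the TileSchedulerOptions move is taken from the gemm_wrapper_utils.py hunk further down, while the 0.2.2 GemmSm90 export is an assumption for illustration:

    # 0.2.2 (old module paths)
    from quack.dense_gemm_sm90 import TileSchedulerOptions  # also exported GemmSm90 (assumed)

    # 0.2.3 (new module paths, per the file list above)
    from quack.tile_scheduler import TileSchedulerOptions
    from quack.gemm_sm90 import GemmSm90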
quack/gemm_symmetric.py ADDED
@@ -0,0 +1,330 @@
+ from typing import Tuple, Optional, Callable
+ from functools import partial
+ from torch import Tensor
+ from quack.gemm_act import GemmActMixin, act_fn_map, gemm_act
+ from quack.gemm_sm90 import GemmSm90
+ from quack.gemm_sm100 import GemmSm100
+ from quack.tile_scheduler import TriangularTileScheduler
+ from quack.gemm_wrapper_utils import GemmWrapperBase
+ from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
+ from quack.varlen_utils import VarlenManager
+ import quack.copy_utils as copy_utils
+ import cutlass
+ import cutlass.cute as cute
+ import cutlass.torch as cutlass_torch
+ from cutlass.cute.runtime import make_ptr
+ from cutlass import Int32, Float32, Boolean, const_expr
+ import cutlass.utils.hopper_helpers as sm90_utils_og
+ import cutlass.utils.blackwell_helpers as sm100_utils
+ from cutlass.cutlass_dsl import if_generate
+
+
+ class GemmSymmetricMixin(GemmActMixin, GemmSm90):
+     def get_scheduler_class(self, varlen_m: bool = False):
+         return TriangularTileScheduler
+
+     @cute.jit
+     def epilogue(
+         self,
+         params: GemmActMixin.EpilogueParams,
+         epi_smem_tensors: Tuple[cute.Tensor, ...],
+         tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
+         epi_pipeline: cutlass.pipeline.PipelineAsync,
+         epi_store_pipeline: cutlass.pipeline.PipelineAsync,
+         epi_read_state: cutlass.pipeline.PipelineState,
+         epi_producer_state: cutlass.pipeline.PipelineState,
+         epi_tile: cute.Tile,
+         load_acc_subtile: Callable,
+         tRS_rD: cute.Tensor,
+         tRS_rC: Optional[cute.Tensor],
+         tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
+         tiled_copy_r2s: cute.TiledCopy,
+         tRS_sD: cute.Tensor,
+         tiled_copy_s2r: Optional[cute.TiledCopy],
+         tSR_rC: Optional[cute.Tensor],
+         tSR_sC: Optional[cute.Tensor],
+         copy_D: Optional[Callable],
+         copy_C: Optional[Callable],
+         tile_coord_mnkl: cute.Coord,
+         varlen_manager: VarlenManager,
+         epilogue_barrier: cutlass.pipeline.NamedBarrier,
+         tile_scheduler,
+         tidx: Int32,
+         is_tma_warp: Boolean,
+     ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
+         has_C = const_expr(tRS_rC is not None)
+         has_D = const_expr(copy_D is not None)
+
+         tma_atom_postact = params.tma_atom_postact
+         mPostAct_mnl = params.mPostAct_mnl
+         sRowVec, sColVec, sPostAct = epi_smem_tensors
+         get_smem_store_op = (
+             partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
+             if self.arch == 100
+             else sm90_utils_og.sm90_get_smem_store_op
+         )
+         copy_atom_postact_r2s = get_smem_store_op(
+             self.postact_layout, self.postact_dtype, self.acc_dtype
+         )
+         # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
+         # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
+         tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
+         tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
+         (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
+         batch_idx = tile_coord_mnkl[3]
+         copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
+             tma_atom_postact,
+             varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
+             self.cta_tile_shape_postact_mn,
+             params.epi_tile_postact,
+             sPostAct,
+             tile_coord_mnkl,
+             tma_desc_ptr=tma_desc_postact_ptr,
+         )
+
+         # We iterate over epi tiles in the N dimension first before the M dimension
+         epi_tile_shape = cute.zipped_divide(
+             cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
+         ).shape[1]
+         epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
+         epi_tile_num = cute.size(epi_tile_shape)
+         num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
+
+         epi_tensors = self.epi_begin(
+             params,
+             epi_smem_tensors,
+             epi_tile,
+             tiled_copy_t2r,
+             tiled_copy_r2s,
+             tile_coord_mnkl,
+             varlen_manager,
+             epilogue_barrier,
+             tidx,
+         )
+
+         if const_expr(copy_C is not None):
+             for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
+                 gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
+                 if is_tma_warp:
+                     epi_pipeline.producer_acquire(epi_producer_state)
+                     copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+                     epi_pipeline.producer_commit(epi_producer_state)
+                 epi_producer_state.advance()
+
+         def tma_store_fn(src_idx, dst_idx, tile_coord_mnkl):
+             pid_m = tile_coord_mnkl[0]
+             pid_n = tile_coord_mnkl[1]
+             # Fence and barrier to make sure shared memory store is visible to TMA store
+             cute.arch.fence_proxy(
+                 cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+             )
+             epilogue_barrier.arrive_and_wait()
+             # Copy from shared memory to global memory
+             if is_tma_warp:
+                 square_tile_m = pid_m // self.cluster_shape_mnk[0]
+                 square_tile_n = pid_n // self.cluster_shape_mnk[1]
+                 if const_expr(has_D):
+                     copy_D(src_idx=src_idx, dst_idx=dst_idx)
+                 if square_tile_m != square_tile_n:  # don't write twice to the same tile
+                     copy_postact(src_idx=src_idx, dst_idx=dst_idx)
+             # Can't use if statement here, epi_store_pipeline object isn't captured somehow
+             if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
+             if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
+             epilogue_barrier.arrive_and_wait()
+
+         delay_tma_store = True
+
+         src_idx_prev, dst_idx_prev = None, None
+         for epi_idx in cutlass.range_constexpr(epi_tile_num):
+             # The global memory coordinate for the current epi tile
+             gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
+             # Copy from acc to D registers
+             load_acc_subtile(tRS_rD, epi_idx)
+             epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
+             if const_expr(has_C):
+                 epi_pipeline.consumer_wait(epi_read_state)
+                 cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
+                 # Fence to make sure shared memory read is visible to TMA load
+                 cute.arch.fence_proxy(
+                     cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                 )
+                 cute.arch.sync_warp()
+                 with cute.arch.elect_one():
+                     epi_pipeline.consumer_release(epi_read_state)
+                 epi_read_state.advance()
+             if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
+                 gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
+                 if is_tma_warp:
+                     epi_pipeline.producer_acquire(epi_producer_state)
+                     copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+                     epi_pipeline.producer_commit(epi_producer_state)
+                 epi_producer_state.advance()
+             tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
+             epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
+             if const_expr(delay_tma_store):
+                 if const_expr(epi_idx > 0):
+                     tma_store_fn(
+                         src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
+                     )
+                 src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
+             # Copy from D registers to shared memory
+             if const_expr(has_D):
+                 copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
+             cute.copy(
+                 tiled_copy_postact_r2s,
+                 tiled_copy_postact_r2s.retile(tRS_rPostAct),
+                 tRS_sPostAct[None, None, None, epi_buffer],
+             )
+             if const_expr(not delay_tma_store):
+                 tma_store_fn(
+                     src_idx=epi_buffer, dst_idx=gmem_coord, tile_coord_mnkl=tile_coord_mnkl
+                 )
+
+         if const_expr(delay_tma_store):
+             tma_store_fn(
+                 src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
+             )
+
+         self.epi_end(
+             params,
+             epi_tensors,
+             epi_tile,
+             tiled_copy_t2r,
+             tiled_copy_r2s,
+             tile_coord_mnkl,
+             varlen_manager,
+             tidx,
+         )
+
+         return epi_read_state, epi_producer_state
+
+
+ class GemmSymmetricSm90(GemmSymmetricMixin, GemmSm90):
+     pass
+
+
+ class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100):
+     pass
+
+
+ def gemm_symmetric(
+     A: Tensor,  # (l, m, k)
+     B: Tensor,  # (l, m, k)
+     D: Optional[Tensor],  # (l, m, m)
+     C: Optional[Tensor],  # (l, m, m)
+     tile_count_semaphore: Optional[Tensor],  # (1,)
+     tile_M: int,
+     tile_N: int,
+     cluster_M: int,
+     cluster_N: int,
+     pingpong: bool = False,
+     persistent: bool = True,
+     max_swizzle_size: int = 8,
+     alpha: float | Tensor = 1.0,
+     beta: float | Tensor = 1.0,
+ ) -> None:
+     # Transpose D so the "activation" is a write to the mirrored tile
+     PostAct = D.mT
+
+     L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
+         A, B, D, C, additional_tensors={"PostAct": PostAct}
+     )
+     assert M == N, "M and N must be the same; symmetric gemm only supports square matrices"
+     GemmWrapperBase.permute_tensors(tensor_infos)
+     GemmWrapperBase.extract_dtypes(tensor_infos)
+     major_configs = {
+         "A": ("m", "k", "l"),
+         "B": ("n", "k", "l"),
+         "D": ("m", "n", "l"),
+         "C": ("m", "n", "l"),
+         "PostAct": ("m", "n", "l"),
+     }
+     GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
+
+     device_capacity = get_device_capacity(A.device)
+     assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
+     GemmCls = GemmSymmetricSm90 if device_capacity[0] == 9 else GemmSymmetricSm100
+
+     acc_dtype = Float32
+     tile_shape_mn = (tile_M, tile_N)
+     cluster_shape_mnk = (cluster_M, cluster_N, 1)
+     if not GemmCls.is_valid_dtypes(
+         tensor_infos["A"].dtype,
+         tensor_infos["B"].dtype,
+         acc_dtype,
+         tensor_infos["D"].dtype,
+         tensor_infos["A"].major,
+         tensor_infos["B"].major,
+     ):
+         raise TypeError("Skipping due to unsupported combination of types and majors")
+
+     max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+     GemmWrapperBase.create_cute_tensors({k: v for k, v in tensor_infos.items()}, major_configs)
+
+     def scalar_arg(scalar: float | Tensor):
+         if isinstance(scalar, float):
+             return Float32(scalar) if scalar != 1.0 else None
+         else:
+             assert isinstance(scalar, Tensor)
+             return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
+
+     activation = None  # Equivalent to identity
+     act_fn = act_fn_map[activation]
+     epi_args = GemmCls.EpilogueArguments(
+         tensor_infos["PostAct"].cute_tensor, act_fn, scalar_arg(alpha), scalar_arg(beta)
+     )
+     scheduler_args = GemmWrapperBase.create_scheduler_args(
+         max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
+     )
+     varlen_args = None
+
+     current_stream = cutlass_torch.current_stream()
+     compile_key = GemmWrapperBase.get_compile_key(
+         tensor_infos,
+         activation,
+         tile_shape_mn,
+         cluster_shape_mnk,
+         pingpong,
+         persistent,
+         tile_count_semaphore is not None,
+         device_capacity,
+         max_swizzle_size,
+         2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
+         2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
+         key_tensor_names=("A", "B", "D", "PostAct", "C"),
+     )
+     cache = gemm_act.compile_cache
+     if compile_key not in cache:
+         if device_capacity[0] == 9:
+             GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
+         gemm_obj = GemmCls(
+             acc_dtype,
+             tensor_infos["A"].dtype,
+             tile_shape_mn,
+             cluster_shape_mnk,
+             gather_A=False,
+         )
+         cache[compile_key] = cute.compile(
+             gemm_obj,
+             tensor_infos["A"].cute_tensor,
+             tensor_infos["B"].cute_tensor,
+             tensor_infos["D"].cute_tensor,
+             tensor_infos["C"].cute_tensor,
+             epi_args,
+             scheduler_args,
+             varlen_args,
+             current_stream,
+         )
+     cache[compile_key](
+         tensor_infos["A"].cute_tensor,
+         tensor_infos["B"].cute_tensor,
+         tensor_infos["D"].cute_tensor,
+         tensor_infos["C"].cute_tensor,
+         epi_args,
+         scheduler_args,
+         varlen_args,
+         current_stream,
+     )
+
+
+ gemm_act.compile_cache = {}
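
Read together with the TriangularTileScheduler import, gemm_symmetric only schedules one triangle of output tiles and mirrors each tile into the other triangle through the PostAct = D.mT "activation" output, skipping the mirrored write when a tile lies on the diagonal (to avoid writing the same tile twice). A minimal calling sketch inferred from the signature above; the dtypes, shapes, and tile/cluster sizes are illustrative assumptions, not values validated against the package:

    import torch
    from quack.gemm_symmetric import gemm_symmetric

    l, m, k = 2, 1024, 512
    A = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)
    D = torch.empty(l, m, m, device="cuda", dtype=torch.bfloat16)  # symmetric (m, m) result per batch

    # C=None skips the epilogue source tensor; the tile/cluster sizes are guesses and must
    # satisfy GemmCls.is_valid_dtypes and the SM90/SM100 hardware checks on the target device.
    gemm_symmetric(
        A, B, D, C=None, tile_count_semaphore=None,
        tile_M=128, tile_N=128, cluster_M=1, cluster_N=1,
    )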
quack/gemm_wrapper_utils.py CHANGED
@@ -11,7 +11,7 @@ from cutlass.cute.runtime import from_dlpack, make_ptr
 
  from quack.cute_dsl_utils import torch2cute_dtype_map
  from quack.varlen_utils import VarlenArguments
- from quack.dense_gemm_sm90 import TileSchedulerOptions
+ from quack.tile_scheduler import TileSchedulerOptions
 
 
  @dataclass
@@ -214,6 +214,7 @@ class GemmWrapperBase:
          max_active_clusters: int,
          tile_count_semaphore: Optional[Tensor] = None,
          batch_idx_permute: Optional[Tensor] = None,
+         max_swizzle_size: int = 8,
      ) -> TileSchedulerOptions:
          return TileSchedulerOptions(
              Int32(max_active_clusters),
@@ -227,6 +228,7 @@ class GemmWrapperBase:
              )
              if batch_idx_permute is not None
              else None,
+             max_swizzle_size=Int32(max_swizzle_size),
          )
 
      @staticmethod
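
The new max_swizzle_size keyword defaults to 8 and is forwarded into TileSchedulerOptions as Int32(max_swizzle_size). A call-site sketch mirroring how gemm_symmetric.py above threads it through (the value 4 is illustrative):

    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters,
        tile_count_semaphore,
        max_swizzle_size=4,  # new in 0.2.3; omit to keep the default of 8
    )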
quack/layout_utils.py ADDED
@@ -0,0 +1,287 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+
+ import cutlass
+ import cutlass.cute as cute
+
+ from cutlass import Int32, const_expr
+
+ from quack.utils import prmt
+
+
+ def transpose_view(a: cute.Tensor) -> cute.Tensor:
+     """Transpose the first two dimensions of a tensor on smem."""
+     shape = (a.shape[1], a.shape[0], *a.shape[2:])
+     order = (1, 0, *range(2, cute.rank(a)))
+     return cute.composition(a, cute.make_ordered_layout(shape, order=order))
+
+
+ def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
+     return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
+
+
+ def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
+     shape = (*a.shape[:dim], size, *a.shape[dim:])
+     stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:])
+     return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride))
+
+
+ @cute.jit
+ def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
+     assert t.element_type.width == 16
+     assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
+     t_u32 = cute.recast_tensor(t, Int32)
+
+     quad_idx = cute.arch.lane_idx() % 4
+     lane_03 = quad_idx == 0 or quad_idx == 3
+     selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
+     selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
+     # upper_map = [0, 3, 1, 2]
+     # lower_map = [1, 2, 0, 3]
+     # upper_idx = upper_map[quad_idx]
+     # indexing isn't supported so we have to do arithmetic
+     upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
+     lower_idx = upper_idx ^ 1
+
+     # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+     width = 4
+     mask = cute.arch.WARP_SIZE - width
+     clamp = cute.arch.WARP_SIZE - 1
+     mask_and_clamp = mask << 8 | clamp
+
+     for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
+         upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
+         upper0 = upper if lane_03 else lower
+         lower0 = lower if lane_03 else upper
+         upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
+         lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
+         t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
+         t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
+
+
+ @cute.jit
+ def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None:
+     """Permute and shuffle within 4 threads to change the layout from
+     T0  | T1  | T2  | T3
+     a b | c d | e f | g h
+     to
+     T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
+     a  | b  | c  | d  | e  | f  | g  | h
+     This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict.
+     """
+
+     assert t.element_type.width == 32
+     assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"
+
+     quad_idx = cute.arch.lane_idx() % 4
+     # left_map = [0, 2, 1, 3]
+     # right_map = [2, 0, 3, 1]
+     # indexing isn't supported so we have to do arithmetic
+     left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
+     right_idx = left_idx ^ 0b10
+
+     # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+     width = 4
+     mask = cute.arch.WARP_SIZE - width
+     clamp = cute.arch.WARP_SIZE - 1
+     mask_and_clamp = mask << 8 | clamp
+
+     for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
+         for r in cutlass.range(2, unroll_full=True):
+             left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
+             # a b | c d | e f | g h -> a b | c d | f e | h g
+             left0 = left if quad_idx < 2 else right
+             right0 = right if quad_idx < 2 else left
+             # a b | c d | f e | h g -> a b | f d | c e | h g
+             left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
+             # a b | f d | c e | h g -> a e | f b | c g | h d
+             right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
+             # a e | f b | c g | h d -> a e | b f | c g | d h
+             t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0
+             t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0
+         t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
+
+
+ @cute.jit
+ def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None:
+     """Permute and shuffle within 4 threads to change the layout from
+     T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
+     a  | b  | c  | d  | e  | f  | g  | h
+     to
+     T0  | T1  | T2  | T3
+     a b | c d | e f | g h
+     This is so that we can use LDSM (instead of LDS.64) to load C registers without bank conflict.
+     """
+
+     assert t.element_type.width == 32
+     assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"
+
+     quad_idx = cute.arch.lane_idx() % 4
+     # left_map = [0, 2, 1, 3]
+     # right_map = [1, 3, 0, 2]
+     # indexing isn't supported so we have to do arithmetic
+     left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
+     right_idx = left_idx ^ 0b01
+
+     # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+     width = 4
+     mask = cute.arch.WARP_SIZE - width
+     clamp = cute.arch.WARP_SIZE - 1
+     mask_and_clamp = mask << 8 | clamp
+
+     # This is just the inverse of permute_Cregs_b32_for_stsm
+     for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
+         t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
+         for r in cutlass.range(2, unroll_full=True):
+             left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
+             # a e | b f | c g | d h -> a e | f b | c g | h d
+             left0 = left if quad_idx % 2 == 0 else right
+             right0 = right if quad_idx % 2 == 0 else left
+             # a e | f b | c g | h d -> a b | f d | c e | h g
+             right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
+             # a b | f d | c e | h g -> a b | c d | f e | h g
+             left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
+             # a b | c d | f e | h g -> a b | c d | e f | g h
+             t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0
+             t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0
+
+
+ @cute.jit
+ def concat_layout(*layouts: cute.Layout) -> cute.Layout:
+     return cute.make_layout(
+         tuple(l.shape for l in layouts),
+         stride=tuple(l.stride for l in layouts),
+     )
+
+
+ def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
+     """
+     For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
+     For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
+     """
+     acc_layout_col_major = cute.make_layout(acc_layout.shape)
+     acc_layout_mn = cute.make_layout(
+         (
+             (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
+             (
+                 acc_layout_col_major.shape[0][0],
+                 *acc_layout_col_major.shape[0][2:],
+                 acc_layout_col_major.shape[2],
+             ),  # MMA_N
+             *acc_layout_col_major.shape[3:],
+         ),
+         stride=(
+             (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
+             (
+                 acc_layout_col_major.stride[0][0],
+                 *acc_layout_col_major.stride[0][2:],
+                 acc_layout_col_major.stride[2],
+             ),  # MMA_N
+             *acc_layout_col_major.stride[3:],
+         ),
+     )
+     return cute.composition(acc_layout, acc_layout_mn)
+
+
+ def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
+     return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
+
+
+ @cute.jit
+ def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
+     # For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
+     # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+     # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
+     # TODO: Sm90 FP8
+     if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
+         l = cute.logical_divide(
+             acc_layout, ((None, None, 2), None, None)
+         )  # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
+         rA_mma_view = cute.make_layout(
+             (
+                 (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]),
+                 l.shape[1],
+                 (l.shape[0][2][1], l.shape[2]),
+             ),
+             stride=(
+                 (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]),
+                 l.stride[1],
+                 (l.stride[0][2][1], l.stride[2]),
+             ),
+         )
+     else:  # Sm80
+         # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
+         l = cute.logical_divide(acc_layout, (None, None, 2))
+         rA_mma_view = cute.make_layout(
+             (
+                 (l.shape[0], l.shape[2][0]),
+                 l.shape[1],
+                 l.shape[2][1],
+             ),
+             stride=(
+                 (l.stride[0], l.stride[2][0]),
+                 l.stride[1],
+                 l.stride[2][1],
+             ),
+         )
+     return rA_mma_view
+
+
+ def convert_layout_zero_stride(
+     input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
+ ) -> cute.Layout:
+     layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input
+     # Group the modes with non-zero stride in the ref_layout together,
+     # and the modes with zero stride together
+     layout_flat = cute.flatten(layout)
+     ref_layout_flat = cute.flatten(ref_layout)
+     nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0]
+     zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0]
+     # There's an edge case when all modes are zero stride
+     new_shape = (
+         tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,),
+         tuple(layout_flat[i].shape for i in zero_modes),
+     )
+     new_stride = (
+         tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,),
+         tuple(layout_flat[i].stride for i in zero_modes),
+     )
+     out_layout = cute.make_layout(new_shape, stride=new_stride)
+     if const_expr(isinstance(input, cute.Tensor)):
+         return cute.make_tensor(input.iterator, out_layout)
+     else:
+         return out_layout
+
+
+ def mma_partition_C_vec(
+     sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
+ ) -> cute.Tensor:
+     assert cute.rank(sVec) == 2
+     assert sVec.stride[0] == 1
+     stage = sVec.shape[1]
+     shape = (
+         (sVec.shape[0], expand_shape, stage)
+         if const_expr(is_colvec)
+         else (expand_shape, sVec.shape[0], stage)
+     )
+     stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
+     sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
+     tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma))
+     return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
+
+
+ def mma_partition_A_vec(
+     sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
+ ) -> cute.Tensor:
+     assert cute.rank(sVec) == 2
+     assert sVec.stride[0] == 1
+     stage = sVec.shape[1]
+     shape = (
+         (sVec.shape[0], expand_shape, stage)
+         if const_expr(is_colvec)
+         else (expand_shape, sVec.shape[0], stage)
+     )
+     stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1])
+     sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
+     tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma))
+     return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None]
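
Most of the helpers in layout_utils.py build new views by rewriting (shape, stride) pairs rather than moving data; expand(), for example, inserts a broadcast dimension with stride 0, and transpose_view() swaps the first two modes. A plain-PyTorch analogy of the expand() idea (illustrative only; the actual helpers operate on cute.Tensor layouts, not torch tensors):

    import torch

    def expand_view(a: torch.Tensor, dim: int, size: int) -> torch.Tensor:
        # Same idea as layout_utils.expand: insert a length-`size` dim with stride 0, no copy.
        shape = (*a.shape[:dim], size, *a.shape[dim:])
        stride = (*a.stride()[:dim], 0, *a.stride()[dim:])
        return a.as_strided(shape, stride)

    x = torch.arange(6.0).reshape(2, 3)   # shape (2, 3), stride (3, 1)
    y = expand_view(x, 1, 4)              # shape (2, 4, 3), stride (3, 0, 1); all 4 copies alias x
    assert y.shape == (2, 4, 3) and y.stride() == (3, 0, 1)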