quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
- quack_kernels-0.2.4.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/gemm_symmetric.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
from typing import Tuple, Optional, Callable
|
|
2
|
+
from functools import partial
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from quack.gemm_act import GemmActMixin, act_fn_map, gemm_act
|
|
5
|
+
from quack.gemm_sm90 import GemmSm90
|
|
6
|
+
from quack.gemm_sm100 import GemmSm100
|
|
7
|
+
from quack.tile_scheduler import TriangularTileScheduler
|
|
8
|
+
from quack.gemm_wrapper_utils import GemmWrapperBase
|
|
9
|
+
from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
|
10
|
+
from quack.varlen_utils import VarlenManager
|
|
11
|
+
import quack.copy_utils as copy_utils
|
|
12
|
+
import cutlass
|
|
13
|
+
import cutlass.cute as cute
|
|
14
|
+
import cutlass.torch as cutlass_torch
|
|
15
|
+
from cutlass.cute.runtime import make_ptr
|
|
16
|
+
from cutlass import Int32, Float32, Boolean, const_expr
|
|
17
|
+
import cutlass.utils.hopper_helpers as sm90_utils_og
|
|
18
|
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
|
19
|
+
from cutlass.cutlass_dsl import if_generate
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GemmSymmetricMixin(GemmActMixin, GemmSm90):
    """Scheduler/epilogue overrides used by symmetric GEMM.

    - ``get_scheduler_class`` always returns ``TriangularTileScheduler``.
    - ``epilogue`` stores every output subtile twice: to D (``copy_D``) and to
      PostAct (``copy_postact``). The ``gemm_symmetric`` wrapper in this file
      passes ``PostAct = D.mT``, so the PostAct store lands in the transposed
      (mirrored) position of D.
    - Tiles whose cluster-level m and n indices are equal skip the PostAct
      store ("don't write twice to the same tile").
    NOTE(review): presumably the triangular scheduler only visits one triangle
    of tiles, which is what makes the mirrored write complete the output —
    confirm in quack.tile_scheduler.
    """

    def get_scheduler_class(self, varlen_m: bool = False):
        # Always the triangular scheduler; ``varlen_m`` is accepted for
        # signature compatibility with the base class but ignored here.
        return TriangularTileScheduler

    @cute.jit
    def epilogue(
        self,
        params: GemmActMixin.EpilogueParams,
        epi_smem_tensors: Tuple[cute.Tensor, ...],
        tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
        epi_pipeline: cutlass.pipeline.PipelineAsync,
        epi_store_pipeline: cutlass.pipeline.PipelineAsync,
        epi_read_state: cutlass.pipeline.PipelineState,
        epi_producer_state: cutlass.pipeline.PipelineState,
        epi_tile: cute.Tile,
        load_acc_subtile: Callable,
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor],
        tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
        tiled_copy_r2s: cute.TiledCopy,
        tRS_sD: cute.Tensor,
        tiled_copy_s2r: Optional[cute.TiledCopy],
        tSR_rC: Optional[cute.Tensor],
        tSR_sC: Optional[cute.Tensor],
        copy_D: Optional[Callable],
        copy_C: Optional[Callable],
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        epilogue_barrier: cutlass.pipeline.NamedBarrier,
        tile_scheduler,
        tidx: Int32,
        is_tma_warp: Boolean,
    ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
        """Epilogue that writes each subtile to D and to the PostAct tensor.

        Flow per output tile:
          1. Build the register->smem copy and the gmem (TMA) copy for PostAct.
          2. If ``copy_C`` is given, prefetch up to ``self.epi_c_stage`` C
             subtiles through ``epi_pipeline`` (TMA warp issues the loads).
          3. For each epi subtile (N index varies fastest): load the
             accumulator subtile, optionally read C from smem, prefetch the next
             C subtile, run ``epi_visit_subtile`` to get the PostAct registers,
             write D and PostAct registers to the rotating smem buffer, and
             store to gmem via ``tma_store_fn``. With ``delay_tma_store`` the
             store for subtile i-1 is issued during iteration i.
          4. Call ``epi_end``.

        Returns:
            The advanced ``(epi_read_state, epi_producer_state)``.
        """
        has_C = const_expr(tRS_rC is not None)
        has_D = const_expr(copy_D is not None)

        tma_atom_postact = params.tma_atom_postact
        mPostAct_mnl = params.mPostAct_mnl
        # Only sPostAct is used in this method; the unpack also requires exactly
        # three smem tensors.
        sRowVec, sColVec, sPostAct = epi_smem_tensors
        # Pick the architecture-specific smem store op (Sm100 needs the TMEM load
        # copy to derive it).
        get_smem_store_op = (
            partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
            if self.arch == 100
            else sm90_utils_og.sm90_get_smem_store_op
        )
        copy_atom_postact_r2s = get_smem_store_op(
            self.postact_layout, self.postact_dtype, self.acc_dtype
        )
        # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
        # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
        # Reuse the D register->smem tiling so PostAct registers line up with tRS_rD.
        tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
        tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
        (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
        batch_idx = tile_coord_mnkl[3]
        copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
            tma_atom_postact,
            varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
            self.cta_tile_shape_postact_mn,
            params.epi_tile_postact,
            sPostAct,
            tile_coord_mnkl,
            tma_desc_ptr=tma_desc_postact_ptr,
        )

        # We iterate over epi tiles in the N dimension first before the M dimension
        epi_tile_shape = cute.zipped_divide(
            cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
        ).shape[1]
        # Row-major over (num_epi_m, num_epi_n): consecutive epi_idx step along N.
        epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
        epi_tile_num = cute.size(epi_tile_shape)
        # Subtiles already emitted by earlier tiles; offsets the rotating smem
        # buffer index (epi_buffer) so it continues across tiles.
        num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

        epi_tensors = self.epi_begin(
            params,
            epi_smem_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            epilogue_barrier,
            tidx,
        )

        # Prefetch the first min(epi_tile_num, epi_c_stage) C subtiles.
        if const_expr(copy_C is not None):
            for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                # Every thread advances its copy of the producer state.
                epi_producer_state.advance()

        def tma_store_fn(src_idx, dst_idx, tile_coord_mnkl):
            # Store smem buffer `src_idx` to gmem subtile `dst_idx` for D and
            # (except on "square" tiles) for PostAct.
            pid_m = tile_coord_mnkl[0]
            pid_n = tile_coord_mnkl[1]
            # Fence and barrier to make sure shared memory store is visible to TMA store
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
            )
            epilogue_barrier.arrive_and_wait()
            # Copy from shared memory to global memory
            if is_tma_warp:
                # Tile indices divided by the cluster shape; equal values are
                # treated as the diagonal. NOTE(review): confirm this matches
                # how TriangularTileScheduler assigns clusters.
                square_tile_m = pid_m // self.cluster_shape_mnk[0]
                square_tile_n = pid_n // self.cluster_shape_mnk[1]
                if const_expr(has_D):
                    copy_D(src_idx=src_idx, dst_idx=dst_idx)
                if square_tile_m != square_tile_n:  # don't write twice to the same tile
                    copy_postact(src_idx=src_idx, dst_idx=dst_idx)
            # Can't use if statement here, epi_store_pipeline object isn't captured somehow
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
            epilogue_barrier.arrive_and_wait()

        # When True, the store of subtile i-1 is issued in iteration i (after
        # subtile i's register work) and the last one after the loop.
        delay_tma_store = True

        src_idx_prev, dst_idx_prev = None, None
        for epi_idx in cutlass.range_constexpr(epi_tile_num):
            # The global memory coordinate for the current epi tile
            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
            # Copy from acc to D registers
            load_acc_subtile(tRS_rD, epi_idx)
            epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
            if const_expr(has_C):
                epi_pipeline.consumer_wait(epi_read_state)
                cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
                # Fence to make sure shared memory read is visible to TMA load
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                )
                cute.arch.sync_warp()
                with cute.arch.elect_one():
                    epi_pipeline.consumer_release(epi_read_state)
                epi_read_state.advance()
            # Keep the C prefetch epi_c_stage subtiles ahead of consumption.
            if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()
            tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
            # Rotating smem stage for this subtile's D/PostAct staging.
            epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
            if const_expr(delay_tma_store):
                if const_expr(epi_idx > 0):
                    tma_store_fn(
                        src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
                    )
                src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
            # Copy from D registers to shared memory
            if const_expr(has_D):
                copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
            cute.copy(
                tiled_copy_postact_r2s,
                tiled_copy_postact_r2s.retile(tRS_rPostAct),
                tRS_sPostAct[None, None, None, epi_buffer],
            )
            if const_expr(not delay_tma_store):
                tma_store_fn(
                    src_idx=epi_buffer, dst_idx=gmem_coord, tile_coord_mnkl=tile_coord_mnkl
                )

        # Flush the store for the last subtile that the delayed scheme deferred.
        if const_expr(delay_tma_store):
            tma_store_fn(
                src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
            )

        self.epi_end(
            params,
            epi_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            tidx,
        )

        return epi_read_state, epi_producer_state
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class GemmSymmetricSm90(GemmSymmetricMixin, GemmSm90):
    """SM90 symmetric GEMM: GemmSymmetricMixin's scheduler/epilogue on GemmSm90."""
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100):
    """SM100 symmetric GEMM: GemmSymmetricMixin's scheduler/epilogue on GemmSm100."""
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def gemm_symmetric(
    A: Tensor,  # (l, m, k)
    B: Tensor,  # (l, m, k)
    D: Optional[Tensor],  # (l, m, m)
    C: Optional[Tensor],  # (l, m, m)
    tile_count_semaphore: Optional[Tensor],  # (1,)
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = False,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
) -> None:
    """Launch a symmetric GEMM (square output, M == N) into D.

    The kernel uses ``TriangularTileScheduler`` and writes each computed tile
    both to D and, through ``PostAct = D.mT``, to the mirrored tile. The result
    is therefore presumably only meaningful when A @ B^T (and C) are symmetric,
    e.g. B is A. NOTE(review): confirm against the scheduler.

    Args:
        A: (l, m, k) operand.
        B: (l, m, k) operand; the GEMM's N extent must equal M (asserted).
        D: (l, m, m) output. Required: PostAct is derived from it.
        C: optional (l, m, m) epilogue input.
        tile_count_semaphore: optional (1,) tensor forwarded to the scheduler.
        tile_M, tile_N, cluster_M, cluster_N: CTA tile and cluster shape.
        pingpong, persistent: forwarded to the SM90 kernel constructor.
            ``persistent=False`` also sets max_active_clusters to 0.
        max_swizzle_size: forwarded to the tile scheduler arguments.
        alpha, beta: Python numbers or float32 scalar tensors passed to the
            epilogue. A Python value equal to 1.0 is passed as None.

    Raises:
        ValueError: if D is None.
        TypeError: for unsupported dtype/major combinations, or a non-float32
            alpha/beta tensor.
        AssertionError: if M != N or the device is not SM90/SM100.
    """
    if D is None:
        raise ValueError("gemm_symmetric requires an output tensor D")
    # Transpose D so the "activation" is a write to the mirrored tile
    PostAct = D.mT

    L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
        A, B, D, C, additional_tensors={"PostAct": PostAct}
    )
    assert M == N, "M and N must be the same; symmetric gemm only supports square matrices"
    GemmWrapperBase.permute_tensors(tensor_infos)
    GemmWrapperBase.extract_dtypes(tensor_infos)
    major_configs = {
        "A": ("m", "k", "l"),
        "B": ("n", "k", "l"),
        "D": ("m", "n", "l"),
        "C": ("m", "n", "l"),
        "PostAct": ("m", "n", "l"),
    }
    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

    device_capacity = get_device_capacity(A.device)
    assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmSymmetricSm90 if device_capacity[0] == 9 else GemmSymmetricSm100

    acc_dtype = Float32
    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    if not GemmCls.is_valid_dtypes(
        tensor_infos["A"].dtype,
        tensor_infos["B"].dtype,
        acc_dtype,
        tensor_infos["D"].dtype,
        tensor_infos["A"].major,
        tensor_infos["B"].major,
    ):
        raise TypeError("Skipping due to unsupported combination of types and majors")

    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
    # Shallow copy, as before, so create_cute_tensors cannot rebind our dict's keys.
    GemmWrapperBase.create_cute_tensors(dict(tensor_infos), major_configs)

    def scalar_arg(scalar: float | Tensor):
        # Tensors are passed by device pointer and read as fp32. Python numbers
        # (int or float) become Float32, or None when equal to 1.0, matching
        # the 0/1/2 encoding used in the compile key below.
        if isinstance(scalar, Tensor):
            import torch

            if scalar.dtype != torch.float32:
                raise TypeError(f"alpha/beta tensors must be float32, got {scalar.dtype}")
            return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
        scalar = float(scalar)
        return Float32(scalar) if scalar != 1.0 else None

    activation = None  # Equivalent to identity
    act_fn = act_fn_map[activation]
    epi_args = GemmCls.EpilogueArguments(
        tensor_infos["PostAct"].cute_tensor, act_fn, scalar_arg(alpha), scalar_arg(beta)
    )
    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
    )
    varlen_args = None

    current_stream = cutlass_torch.current_stream()
    compile_key = GemmWrapperBase.get_compile_key(
        tensor_infos,
        activation,
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        device_capacity,
        max_swizzle_size,
        2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
        2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
        key_tensor_names=("A", "B", "D", "PostAct", "C"),
    )
    # Own cache: sharing gemm_act's cache could hand back a non-symmetric kernel
    # on a key collision, and resetting it at import wiped gemm_act's kernels.
    cache = gemm_symmetric.compile_cache
    if compile_key not in cache:
        gemm_ctor = GemmCls
        if device_capacity[0] == 9:
            gemm_ctor = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
        gemm_obj = gemm_ctor(
            acc_dtype,
            tensor_infos["A"].dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            gather_A=False,
        )
        cache[compile_key] = cute.compile(
            gemm_obj,
            tensor_infos["A"].cute_tensor,
            tensor_infos["B"].cute_tensor,
            tensor_infos["D"].cute_tensor,
            tensor_infos["C"].cute_tensor,
            epi_args,
            scheduler_args,
            varlen_args,
            current_stream,
        )
    cache[compile_key](
        tensor_infos["A"].cute_tensor,
        tensor_infos["B"].cute_tensor,
        tensor_infos["D"].cute_tensor,
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
        varlen_args,
        current_stream,
    )


gemm_symmetric.compile_cache = {}
|
quack/gemm_wrapper_utils.py
CHANGED
|
@@ -11,7 +11,7 @@ from cutlass.cute.runtime import from_dlpack, make_ptr
|
|
|
11
11
|
|
|
12
12
|
from quack.cute_dsl_utils import torch2cute_dtype_map
|
|
13
13
|
from quack.varlen_utils import VarlenArguments
|
|
14
|
-
from quack.
|
|
14
|
+
from quack.tile_scheduler import TileSchedulerOptions
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
@dataclass
|
|
@@ -214,6 +214,7 @@ class GemmWrapperBase:
|
|
|
214
214
|
max_active_clusters: int,
|
|
215
215
|
tile_count_semaphore: Optional[Tensor] = None,
|
|
216
216
|
batch_idx_permute: Optional[Tensor] = None,
|
|
217
|
+
max_swizzle_size: int = 8,
|
|
217
218
|
) -> TileSchedulerOptions:
|
|
218
219
|
return TileSchedulerOptions(
|
|
219
220
|
Int32(max_active_clusters),
|
|
@@ -227,6 +228,7 @@ class GemmWrapperBase:
|
|
|
227
228
|
)
|
|
228
229
|
if batch_idx_permute is not None
|
|
229
230
|
else None,
|
|
231
|
+
max_swizzle_size=Int32(max_swizzle_size),
|
|
230
232
|
)
|
|
231
233
|
|
|
232
234
|
@staticmethod
|
quack/layout_utils.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
import cutlass
|
|
5
|
+
import cutlass.cute as cute
|
|
6
|
+
|
|
7
|
+
from cutlass import Int32, const_expr
|
|
8
|
+
|
|
9
|
+
from quack.utils import prmt
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def transpose_view(a: cute.Tensor) -> cute.Tensor:
    """Return a view of ``a`` with its first two modes swapped (intended for smem tensors).

    Trailing modes keep their position; only the composition layout changes.
    """
    rank = cute.rank(a)
    swapped_shape = (a.shape[1], a.shape[0], *a.shape[2:])
    # Mode 1 becomes the fastest-varying, then mode 0, then the rest in order.
    swapped_order = (1, 0) + tuple(range(2, rank))
    view_layout = cute.make_ordered_layout(swapped_shape, order=swapped_order)
    return cute.composition(a, view_layout)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
    """Return a tensor over the same iterator keeping only the listed layout modes."""
    picked_layout = cute.select(a.layout, mode)
    return cute.make_tensor(a.iterator, picked_layout)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor:
    """Insert a broadcast mode of extent ``size`` (stride 0) at position ``dim``.

    No data is copied: the result shares ``a``'s iterator, and every index along
    the new mode maps to the same element.
    """
    shape_head, shape_tail = a.shape[:dim], a.shape[dim:]
    stride_head, stride_tail = a.layout.stride[:dim], a.layout.stride[dim:]
    broadcast_layout = cute.make_layout(
        (*shape_head, size, *shape_tail),
        stride=(*stride_head, 0, *stride_tail),
    )
    return cute.make_tensor(a.iterator, broadcast_layout)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@cute.jit
def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
    """In-place quad-level shuffle and 16-bit repack of register fragment ``t``.

    ``t`` must hold 16-bit elements (asserted), with a size that is a multiple of 4.
    It is reinterpreted as 32-bit words (two 16-bit values per word), and consecutive
    word pairs are processed together. Within each group of 4 lanes, the pair is
    exchanged with ``shuffle_sync`` and then repacked with ``prmt``.

    NOTE(review): the name suggests this reorders accumulator registers for a
    gated activation (e.g. interleaved gate/up columns); confirm the target
    layout against the caller. Uses warp shuffles, so presumably all lanes of
    the warp must call it together.
    """
    assert t.element_type.width == 16
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
    # View as 32-bit words so each shuffle moves two 16-bit values.
    t_u32 = cute.recast_tensor(t, Int32)

    quad_idx = cute.arch.lane_idx() % 4
    lane_03 = quad_idx == 0 or quad_idx == 3
    # prmt selectors (assuming PTX prmt.b32 byte-select semantics):
    # 0x5410 -> low 16 bits of 1st operand in the low half, low 16 bits of 2nd in the
    # high half; 0x7632 -> the same for the high 16 bits. Lanes 1/2 use the
    # operand-swapped variants 0x1054 / 0x3276.
    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
    # upper_map = [0, 3, 1, 2]
    # lower_map = [1, 2, 0, 3]
    # upper_idx = upper_map[quad_idx]
    # indexing isn't supported so we have to do arithmetic
    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
    lower_idx = upper_idx ^ 1

    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    # Shuffle segments are `width` (= 4) lanes, so offsets index within the quad.
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
        # Lanes 1 and 2 swap the pair before shuffling.
        upper0 = upper if lane_03 else lower
        lower0 = lower if lane_03 else upper
        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
        t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
        t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@cute.jit
def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
      T0 | T1 | T2 | T3
     a b | c d | e f | g h
    to
      T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
      a  |  b |  c |  d |  e |  f |  g |  h
    This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict.

    Operates in place on 32-bit elements (asserted), in groups of 4 registers
    per thread. Inverse of ``permute_Cregs_b32_for_ldsm``. Uses warp shuffles,
    so presumably all lanes of the warp must call it together.
    """

    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3]
    # right_map = [2, 0, 3, 1]
    # indexing isn't supported so we have to do arithmetic
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b10

    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    # Shuffle segments are `width` (= 4) lanes, so offsets index within the quad.
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        # Each group of 4 registers is two independent (left, right) pairs.
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a b | c d | e f | g h -> a b | c d | f e | h g
            left0 = left if quad_idx < 2 else right
            right0 = right if quad_idx < 2 else left
            # a b | c d | f e | h g -> a b | f d | c e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a e | f b | c g | h d
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a e | f b | c g | h d -> a e | b f | c g | d h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0
        # Interleave the two pairs' results: registers 1 and 2 trade places.
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@cute.jit
def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None:
    """Permute and shuffle within 4 threads to change the layout from
      T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3
      a  |  b |  c |  d |  e |  f |  g |  h
    to
      T0 | T1 | T2 | T3
     a b | c d | e f | g h
    This is so that we can use LDSM (instead of LDS.64) to load C registers without bank conflict.

    Operates in place on 32-bit elements (asserted), in groups of 4 registers
    per thread. Uses warp shuffles, so presumably all lanes of the warp must
    call it together.
    """

    assert t.element_type.width == 32
    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation"

    quad_idx = cute.arch.lane_idx() % 4
    # left_map = [0, 2, 1, 3]
    # right_map = [1, 3, 0, 2]
    # indexing isn't supported so we have to do arithmetic
    left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2
    right_idx = left_idx ^ 0b01

    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    # Shuffle segments are `width` (= 4) lanes, so offsets index within the quad.
    width = 4
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp

    # This is just the inverse of permute_Cregs_b32_for_stsm
    for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True):
        # Undo the final register 1 <-> 2 swap of the stsm variant first.
        t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1]
        for r in cutlass.range(2, unroll_full=True):
            left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1]
            # a e | b f | c g | d h -> a e | f b | c g | h d
            left0 = left if quad_idx % 2 == 0 else right
            right0 = right if quad_idx % 2 == 0 else left
            # a e | f b | c g | h d -> a b | f d | c e | h g
            right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp)
            # a b | f d | c e | h g -> a b | c d | f e | h g
            left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp)
            # a b | c d | f e | h g -> a b | c d | e f | g h
            t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0
            t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
@cute.jit
def concat_layout(*layouts: cute.Layout) -> cute.Layout:
    """Build one hierarchical layout whose i-th mode is ``layouts[i]``."""
    shapes = tuple(layout.shape for layout in layouts)
    strides = tuple(layout.stride for layout in layouts)
    return cute.make_layout(shapes, stride=strides)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
    """
    Regroup an MMA accumulator layout into (row, col, ...) modes.

    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
    """

    def regroup(profile):
        # Same regrouping for shapes and strides: the 2nd sub-mode of the value
        # mode joins MMA_M, the 1st (plus any extra sub-modes) joins MMA_N.
        value_mode = profile[0]
        row_mode = (value_mode[1], profile[1])
        col_mode = (value_mode[0], *value_mode[2:], profile[2])
        return (row_mode, col_mode, *profile[3:])

    col_major = cute.make_layout(acc_layout.shape)
    mn_layout = cute.make_layout(regroup(col_major.shape), stride=regroup(col_major.stride))
    return cute.composition(acc_layout, mn_layout)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
    """Re-view an accumulator tensor with its layout regrouped into (row, col, ...) modes."""
    mn_layout = convert_layout_acc_mn(acc.layout)
    return cute.make_tensor(acc.iterator, mn_layout)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@cute.jit
def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
    """Reshape an accumulator layout into the A-operand register layout of the next GEMM.

    Used for back-to-back GEMM, where acc0 is consumed as operand A of gemm 1.
    - Sm80 (mma 16x8x16): (4, MMA_M, MMA_N) -> ((4, 2), MMA_M, MMA_N / 2)
    - Sm90 FP16/BF16: ((2, 2, N / 8), MMA_M, MMA_N) -> ((2, 2, 2), MMA_M, (N / 16, MMA_N))
    TODO: Sm90 FP8
    """
    if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
        # ((2, 2, N / 8), MMA_M, MMA_N) -> ((2, 2, (2, N / 16)), MMA_M, MMA_N)
        split = cute.logical_divide(acc_layout, ((None, None, 2), None, None))
        frgA_layout = cute.make_layout(
            (
                (split.shape[0][0], split.shape[0][1], split.shape[0][2][0]),
                split.shape[1],
                (split.shape[0][2][1], split.shape[2]),
            ),
            stride=(
                (split.stride[0][0], split.stride[0][1], split.stride[0][2][0]),
                split.stride[1],
                (split.stride[0][2][1], split.stride[2]),
            ),
        )
    else:  # Sm80
        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
        split = cute.logical_divide(acc_layout, (None, None, 2))
        frgA_layout = cute.make_layout(
            (
                (split.shape[0], split.shape[2][0]),
                split.shape[1],
                split.shape[2][1],
            ),
            stride=(
                (split.stride[0], split.stride[2][0]),
                split.stride[1],
                split.stride[2][1],
            ),
        )
    return frgA_layout
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def convert_layout_zero_stride(
    input: cute.Tensor | cute.Layout, ref_layout: cute.Layout
) -> cute.Tensor | cute.Layout:
    """Regroup the flattened modes of ``input`` by ``ref_layout``'s zero strides.

    The result is rank 2:
      - mode 0 holds, in order, the flattened modes whose stride in ``ref_layout``
        is non-zero (or the trivial ``(1,):(0,)`` if there are none);
      - mode 1 holds the flattened modes whose stride in ``ref_layout`` is zero.
    ``ref_layout`` must flatten to at least as many modes as ``input``.

    Returns a Tensor over the same iterator when ``input`` is a Tensor, otherwise
    a Layout. (The parameter name ``input`` shadows the builtin but is kept for
    keyword-argument compatibility.)
    """
    is_tensor = const_expr(isinstance(input, cute.Tensor))
    layout = input.layout if is_tensor else input
    layout_flat = cute.flatten(layout)
    ref_layout_flat = cute.flatten(ref_layout)

    # Single pass: partition flattened mode indices by the reference stride.
    nonzero_modes, zero_modes = [], []
    for i in range(cute.rank(layout_flat)):
        if ref_layout_flat[i].stride == 0:
            zero_modes.append(i)
        else:
            nonzero_modes.append(i)

    if nonzero_modes:
        nonzero_shape = tuple(layout_flat[i].shape for i in nonzero_modes)
        nonzero_stride = tuple(layout_flat[i].stride for i in nonzero_modes)
    else:
        # Edge case: every mode has zero stride in ref_layout; keep a trivial mode 0.
        nonzero_shape, nonzero_stride = (1,), (0,)
    zero_shape = tuple(layout_flat[i].shape for i in zero_modes)
    zero_stride = tuple(layout_flat[i].stride for i in zero_modes)

    out_layout = cute.make_layout(
        (nonzero_shape, zero_shape), stride=(nonzero_stride, zero_stride)
    )
    return cute.make_tensor(input.iterator, out_layout) if is_tensor else out_layout
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _mma_partition_vec(
    sVec: cute.Tensor, partition_fn, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Shared body of mma_partition_C_vec / mma_partition_A_vec.

    ``sVec`` is a rank-2 tensor (vector length, stage) with unit stride in mode
    0 (both asserted). It is broadcast into a 2D (rows, cols, stage) view with a
    stride-0 mode of extent ``expand_shape``: columns for a column vector, rows
    otherwise. That view is partitioned with ``partition_fn`` (e.g.
    ``thr_mma.partition_C``) and regrouped into (row, col, ...) modes. Index 0
    of the broadcast mode is then selected.
    """
    assert cute.rank(sVec) == 2
    assert sVec.stride[0] == 1
    stage = sVec.shape[1]
    if const_expr(is_colvec):
        shape = (sVec.shape[0], expand_shape, stage)
        stride = (1, 0, sVec.stride[1])
    else:
        shape = (expand_shape, sVec.shape[0], stage)
        stride = (0, 1, sVec.stride[1])
    sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride))
    tP_sVec = make_acc_tensor_mn_view(partition_fn(sVec_mma))
    # The broadcast mode has stride 0, so index 0 is representative.
    return tP_sVec[None, 0, None] if const_expr(is_colvec) else tP_sVec[0, None, None]


def mma_partition_C_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Partition a per-stage smem row/column vector like the MMA accumulator (C)."""
    return _mma_partition_vec(sVec, thr_mma.partition_C, expand_shape, is_colvec)


def mma_partition_A_vec(
    sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool
) -> cute.Tensor:
    """Partition a per-stage smem row/column vector like the MMA A operand."""
    return _mma_partition_vec(sVec, thr_mma.partition_A, expand_shape, is_colvec)
|