quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/gemm_act.py
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
1
|
+
# Copyright (c) 2025, Wentao Guo, Tri Dao.
|
|
2
|
+
from typing import Tuple, Optional, Callable
|
|
3
|
+
from functools import partial
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
|
|
8
|
+
import cutlass
|
|
9
|
+
import cutlass.cute as cute
|
|
10
|
+
import cutlass.utils.hopper_helpers as sm90_utils_og
|
|
11
|
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
|
12
|
+
from cutlass import Int32, Float32, Boolean, const_expr
|
|
13
|
+
from cutlass.cutlass_dsl import if_generate
|
|
14
|
+
import cutlass.torch as cutlass_torch
|
|
15
|
+
from cutlass.cute.runtime import from_dlpack
|
|
16
|
+
|
|
17
|
+
from quack.cute_dsl_utils import ArgumentsBase, ParamsBase
|
|
18
|
+
from quack.varlen_utils import VarlenManager
|
|
19
|
+
from quack.gemm_sm90 import GemmSm90
|
|
20
|
+
from quack.gemm_sm100 import GemmSm100
|
|
21
|
+
from quack.gemm_default_epi import GemmDefaultEpiMixin
|
|
22
|
+
from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
|
23
|
+
from quack.gemm_wrapper_utils import GemmWrapperBase
|
|
24
|
+
import quack.sm90_utils as sm90_utils
|
|
25
|
+
import quack.copy_utils as copy_utils
|
|
26
|
+
import quack.activation
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class GemmActMixin(GemmDefaultEpiMixin):
    """Epilogue mixin that additionally writes an activated copy of the GEMM output.

    On top of the default epilogue (``GemmDefaultEpiMixin``, which receives alpha/beta,
    the C subtile and the optional row/column broadcast vectors), every epilogue subtile
    is passed through ``act_fn`` (when one is given), converted to the dtype of
    ``mPostAct`` and written to ``mPostAct`` through its own TMA store, next to the
    regular D store.
    """

    # Number of extra tensormaps (TMA descriptors) used by this epilogue: one, for
    # PostAct (see epi_get_tma_atoms and the single-element unpack of tma_desc_epi_ptrs).
    num_epi_tensormaps: int = 1

    @dataclass
    class EpilogueArguments(ArgumentsBase):
        """Epilogue arguments built by the caller; see ``epi_to_underlying_arguments``."""

        # Destination tensor for the post-activation output.
        mPostAct: cute.Tensor
        # Elementwise activation (compile-time constant). When None, PostAct is the
        # epilogue result of D converted to mPostAct's element type.
        act_fn: cutlass.Constexpr[Optional[Callable]] = None
        # Forwarded unchanged to EpilogueParams; consumed by the default epilogue visit.
        alpha: Optional[Float32 | cute.Tensor] = None
        beta: Optional[Float32 | cute.Tensor] = None
        # Optional broadcast vectors. Their shared-memory buffers are sized by the N tile
        # (row vector) and the M tile (column vector) in epi_get_smem_struct.
        mRowVecBroadcast: Optional[cute.Tensor] = None
        mColVecBroadcast: Optional[cute.Tensor] = None

    @dataclass
    class EpilogueParams(ParamsBase):
        """Parameters produced by ``epi_to_underlying_arguments`` for the epilogue methods."""

        # TMA store atom for PostAct.
        tma_atom_postact: cute.CopyAtom
        # TMA tensor returned together with tma_atom_postact.
        mPostAct_mnl: cute.Tensor
        # Shared-memory layout for PostAct subtiles, staged over self.epi_stage buffers.
        epi_postact_smem_layout_staged: cute.ComposedLayout
        # Epilogue tile used for PostAct (taken from self.epi_tile).
        epi_tile_postact: cute.Tile
        act_fn: cutlass.Constexpr[Optional[Callable]] = None
        alpha: Optional[Float32 | cute.Tensor] = None
        beta: Optional[Float32 | cute.Tensor] = None
        mRowVecBroadcast: Optional[cute.Tensor] = None
        mColVecBroadcast: Optional[cute.Tensor] = None

    def epi_to_underlying_arguments(
        self, args: EpilogueArguments, *, loc=None, ip=None
    ) -> EpilogueParams:
        """Convert ``EpilogueArguments`` into ``EpilogueParams``.

        Side effects: records ``postact_dtype``, ``postact_layout`` and
        ``cta_tile_shape_postact_mn`` on ``self``; these are read later by
        ``epi_get_smem_struct``, ``epilogue`` and ``epi_visit_subtile``.
        """
        self.postact_dtype = args.mPostAct.element_type
        self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)

        # PostAct covers the same (M, N) CTA tile as D.
        self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
        epi_tile_postact = self.epi_tile
        # sm100_utils is cutlass' blackwell_helpers; sm90_utils is quack.sm90_utils.
        utils_cls = sm100_utils if self.arch == 100 else sm90_utils
        epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi(
            self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage
        )
        tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors(
            args.mPostAct,
            epi_postact_smem_layout_staged,
            epi_tile_postact,
            op_type="store",
        )
        # Assume all strides are divisible by 32 bits except the last stride
        # Each non-static stride is annotated as divisible by (32 // element width) elements;
        # static strides are passed through untouched.
        # NOTE(review): the code skips *static* strides rather than explicitly the last one;
        # presumably the last stride is the static unit stride - confirm.
        new_stride = lambda t: tuple(
            cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s
            for s in t.stride
        )
        mRowVecBroadcast, mColVecBroadcast = [
            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
            if t is not None
            else None
            for t in (args.mRowVecBroadcast, args.mColVecBroadcast)
        ]
        return self.EpilogueParams(
            tma_atom_postact,
            tma_tensor_postact,
            epi_postact_smem_layout_staged,
            epi_tile_postact,
            args.act_fn,
            alpha=args.alpha,
            beta=args.beta,
            mRowVecBroadcast=mRowVecBroadcast,
            mColVecBroadcast=mColVecBroadcast,
        )

    def epi_get_tma_atoms(
        self, params: EpilogueParams, *, loc=None, ip=None
    ) -> list[cute.CopyAtom]:
        """Return the epilogue TMA atoms: only the PostAct store atom."""
        return [params.tma_atom_postact]

    def epi_get_tensormap_update_shapes_orders(
        self,
        params: EpilogueParams,
        cu_seqlens_m: Optional[cute.Tensor],
        batch_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> tuple[list[Int32], list[int]]:
        """Return one (shape, order) entry for the PostAct tensormap.

        The shape entry is ``cu_seqlens_m[batch_idx + 1]`` when ``cu_seqlens_m`` is given,
        else None. The order is 0 when PostAct is m-major and 1 otherwise.
        NOTE(review): presumably used by the base class to update the descriptor per batch
        in the varlen_m path - confirm.
        """
        shapes = [cu_seqlens_m[batch_idx + 1] if cu_seqlens_m is not None else None]
        orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1]
        return shapes, orders

    @staticmethod
    def epi_smem_bytes_per_stage(
        args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile
    ) -> int:
        """Shared-memory bytes per epilogue stage.

        One ``epi_tile`` of PostAct elements plus the default mixin's requirement
        (the row/column broadcast vectors).
        """
        postact_dtype = args.mPostAct.element_type
        postact_bytes_per_stage = cute.size(cute.shape(epi_tile)) * (postact_dtype.width // 8)
        rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(
            args, cta_tile_shape_mnk, epi_tile
        )
        return postact_bytes_per_stage + rowvec_colvec_bytes

    def epi_get_smem_struct(self, params: EpilogueParams):
        """Build the epilogue shared-storage struct type.

        Fields: the row vector (N-tile elements, or 0 when absent), the column vector
        (M-tile elements, or 0 when absent) and the staged PostAct buffer.
        """
        row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1]
        col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0]
        # Float32 is only a placeholder element type when the vector is absent (size 0).
        row_vec_dtype = (
            params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32
        )
        col_vec_dtype = (
            params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32
        )

        # Vectors are 16-byte aligned; the PostAct staging buffer uses self.buffer_align_bytes.
        @cute.struct
        class EpiSharedStorage:
            sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16]
            sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16]
            sPostAct: cute.struct.Align[
                cute.struct.MemRange[
                    self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged)
                ],
                self.buffer_align_bytes,
            ]

        return EpiSharedStorage

    def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
        """Return ``(sRowVec, sColVec, sPostAct)`` views into the epilogue shared storage."""
        sRowVec, sColVec = super().epi_get_smem_tensors(params, storage)
        # The staged layout is a ComposedLayout: .outer is the layout, .inner the swizzle.
        sPostAct = storage.epi.sPostAct.get_tensor(
            params.epi_postact_smem_layout_staged.outer,
            swizzle=params.epi_postact_smem_layout_staged.inner,
        )
        return (sRowVec, sColVec, sPostAct)

    @cute.jit
    def epilogue(
        self,
        params: EpilogueParams,
        epi_smem_tensors: Tuple[cute.Tensor, ...],
        tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
        epi_pipeline: cutlass.pipeline.PipelineAsync,
        epi_store_pipeline: cutlass.pipeline.PipelineAsync,
        epi_read_state: cutlass.pipeline.PipelineState,
        epi_producer_state: cutlass.pipeline.PipelineState,
        epi_tile: cute.Tile,
        load_acc_subtile: Callable,
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor],
        tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
        tiled_copy_r2s: cute.TiledCopy,
        tRS_sD: cute.Tensor,
        tiled_copy_s2r: Optional[cute.TiledCopy],
        tSR_rC: Optional[cute.Tensor],
        tSR_sC: Optional[cute.Tensor],
        copy_D: Optional[Callable],
        copy_C: Optional[Callable],
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        epilogue_barrier: cutlass.pipeline.NamedBarrier,
        tile_scheduler,
        tidx: Int32,
        is_tma_warp: Boolean,
    ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
        """Run the epilogue for one CTA tile, storing D (if requested) and PostAct.

        The CTA tile is split into ``epi_tile`` subtiles, visited N-first. For each subtile:
        - load the accumulator into ``tRS_rD``;
        - if C is present, consume the staged C subtile and prefetch the one
          ``epi_c_stage`` steps ahead;
        - compute PostAct with ``epi_visit_subtile``;
        - copy D and PostAct registers to shared memory;
        - issue the TMA stores.

        Only ``is_tma_warp`` issues TMA loads/stores, but every thread advances the
        pipeline states and takes part in ``epilogue_barrier``.

        Returns:
            The advanced ``(epi_read_state, epi_producer_state)``.
        """
        has_C = const_expr(tRS_rC is not None)
        has_D = const_expr(copy_D is not None)

        tma_atom_postact = params.tma_atom_postact
        mPostAct_mnl = params.mPostAct_mnl
        sRowVec, sColVec, sPostAct = epi_smem_tensors
        # Pick the register->smem store op for PostAct per architecture; on Sm100 it is
        # derived from the tmem-load tiled copy.
        get_smem_store_op = (
            partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
            if self.arch == 100
            else sm90_utils_og.sm90_get_smem_store_op
        )
        copy_atom_postact_r2s = get_smem_store_op(
            self.postact_layout, self.postact_dtype, self.acc_dtype
        )
        # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
        # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
        # Reuse D's r2s thread partitioning so each thread's PostAct fragment lines up with tRS_rD.
        tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
        tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
        (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
        batch_idx = tile_coord_mnkl[3]
        # copy_postact(src_idx=smem buffer, dst_idx=subtile coordinate) issues the TMA store.
        copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
            tma_atom_postact,
            varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
            self.cta_tile_shape_postact_mn,
            params.epi_tile_postact,
            sPostAct,
            tile_coord_mnkl,
            tma_desc_ptr=tma_desc_postact_ptr,
        )

        # We iterate over epi tiles in the N dimension first before the M dimension
        epi_tile_shape = cute.zipped_divide(
            cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
        ).shape[1]
        # Stride (num_n_subtiles, 1): consecutive epi_idx values walk along N.
        epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
        epi_tile_num = cute.size(epi_tile_shape)
        # Subtiles already processed by this CTA; used to rotate the smem staging buffers.
        num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

        epi_tensors = self.epi_begin(
            params,
            epi_smem_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            epilogue_barrier,
            tidx,
        )

        # Prefetch the first min(epi_tile_num, epi_c_stage) C subtiles.
        if const_expr(copy_C is not None):
            for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()

        # Publish one smem staging buffer (D and PostAct) to global memory via TMA.
        def tma_store_fn(src_idx, dst_idx):
            # Fence and barrier to make sure shared memory store is visible to TMA store
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
            )
            epilogue_barrier.arrive_and_wait()
            # Copy from shared memory to global memory
            if is_tma_warp:
                if const_expr(has_D):
                    copy_D(src_idx=src_idx, dst_idx=dst_idx)
                copy_postact(src_idx=src_idx, dst_idx=dst_idx)
            # Can't use if statement here, epi_store_pipeline object isn't captured somehow
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
            epilogue_barrier.arrive_and_wait()

        # When True, the store for subtile i-1 is issued after subtile i has been computed
        # (and the last one after the loop), presumably to overlap compute with the
        # fence/barrier/store latency.
        delay_tma_store = True

        src_idx_prev, dst_idx_prev = None, None
        for epi_idx in cutlass.range_constexpr(epi_tile_num):
            # The global memory coordinate for the current epi tile
            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
            # Copy from acc to D registers
            load_acc_subtile(tRS_rD, epi_idx)
            epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
            if const_expr(has_C):
                epi_pipeline.consumer_wait(epi_read_state)
                cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
                # Fence to make sure shared memory read is visible to TMA load
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                )
                cute.arch.sync_warp()
                with cute.arch.elect_one():
                    epi_pipeline.consumer_release(epi_read_state)
                epi_read_state.advance()
            # Refill the C stage just released with the subtile epi_c_stage steps ahead.
            if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()
            tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
            # Smem staging buffer for this subtile, rotating across all subtiles of this CTA.
            epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
            if const_expr(delay_tma_store):
                if const_expr(epi_idx > 0):
                    tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)
                src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
            # Copy from D registers to shared memory
            if const_expr(has_D):
                copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
            # PostAct is always written, independent of whether D is stored.
            cute.copy(
                tiled_copy_postact_r2s,
                tiled_copy_postact_r2s.retile(tRS_rPostAct),
                tRS_sPostAct[None, None, None, epi_buffer],
            )
            if const_expr(not delay_tma_store):
                tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord)

        # Flush the store for the final subtile that was deferred inside the loop.
        if const_expr(delay_tma_store):
            tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev)

        self.epi_end(
            params,
            epi_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            tidx,
        )

        return epi_read_state, epi_producer_state

    @cute.jit
    def epi_visit_subtile(
        self,
        params: EpilogueParams,
        epi_loop_tensors: Tuple[cute.Tensor, ...],
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor] = None,
    ) -> Optional[cute.Tensor]:
        """Apply the default epilogue visit to ``tRS_rD``, then compute the PostAct fragment.

        The default visit's return value is discarded here.
        NOTE(review): presumably it updates ``tRS_rD`` in place (alpha/beta, C, bias),
        so D and PostAct both see its result - confirm.

        Returns:
            A new register fragment with ``act_fn(tRS_rD)`` (or ``tRS_rD`` unchanged when
            ``act_fn`` is None), converted to ``self.postact_dtype``.
        """
        GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC)
        # Apply activation function if provided
        # If we don't have .shape here, the compiler generates local stores and loads
        if const_expr(params.act_fn is not None):
            tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
            if const_expr(self.arch < 100):
                # Pre-Sm100: act_fn is called on one scalar at a time.
                for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
                    tRS_rPostAct[i] = params.act_fn(tRS_rD[i])
            else:
                # Sm100+: act_fn takes and returns a pair of elements.
                # NOTE(review): if cute.size(tRS_rPostAct) were odd, the last element would
                # never be written; presumably fragment sizes are always even - confirm.
                for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
                    tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn(
                        (tRS_rD[2 * i], tRS_rD[2 * i + 1])
                    )
        else:
            tRS_rPostAct = tRS_rD
        # Type conversion
        tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
        tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
        return tRS_rPostAct_out
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
class GemmActSm90(GemmActMixin, GemmSm90):
    """SM90 (Hopper) GEMM whose epilogue also stores the activated output PostAct."""
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
class GemmActSm100(GemmActMixin, GemmSm100):
    """SM100 (Blackwell) GEMM whose epilogue also stores the activated output PostAct."""
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
# Maps the ``activation`` string accepted by ``gemm_act`` to an elementwise function
# from ``quack.activation``; the ``None`` key means "no activation".
act_fn_map = {None: None}
act_fn_map.update(
    relu=quack.activation.relu,
    relu_sq=quack.activation.relu_sq,
    gelu_tanh_approx=quack.activation.gelu_tanh_approx,
)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def gemm_act(
    A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
    B: Tensor,  # (l, n, k)
    D: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
    PostAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
    tile_count_semaphore: Optional[Tensor],  # (1,)
    activation: Optional[str],
    tile_M: int,
    tile_N: int,
    cluster_M: int,
    cluster_N: int,
    pingpong: bool = False,
    persistent: bool = True,
    max_swizzle_size: int = 8,
    rowvec_bias: Optional[Tensor] = None,  # (l, n)
    colvec_bias: Optional[Tensor] = None,  # (l, m), or (total_m,) if varlen_m
    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
    A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
) -> None:
    """Launch a GEMM whose epilogue writes D (optional) and PostAct = activation(D).

    The kernel class is chosen from the device's compute capability (SM90 or SM100).
    Compiled kernels are cached in ``gemm_act.compile_cache``. The cache key covers tensor
    dtypes/majors, the activation, tile/cluster shapes, scheduling options, bias dtypes,
    and whether varlen_m / gather_A are used.

    Args:
        activation: One of the keys of ``act_fn_map`` (``None`` for no activation).
        rowvec_bias / colvec_bias: Optional broadcast vectors passed to the epilogue.
        cu_seqlens_m: Enables varlen_m; requires ``persistent`` and k-major A /
            n-major D and PostAct.
        A_idx: Enables gather_A; requires ``cu_seqlens_m`` and ``cluster_N == 1``.

    Raises:
        AssertionError: On invalid argument combinations or unsupported devices.
        TypeError: When the dtype/major combination is not supported by the kernel.
    """
    varlen_m = cu_seqlens_m is not None
    gather_A = A_idx is not None
    if varlen_m:
        assert persistent, "varlen_m requires persistent=True"
        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
        if D is not None:
            assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
        assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
    if gather_A:
        assert varlen_m, "gather_A requires varlen (cu_seqlens_m must be specified)"
        assert cluster_N == 1, "gather_A requires cluster_N=1"
    assert activation in act_fn_map, f"Unsupported activation {activation}"

    _L, _M, _K, _N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
        A, B, D, C, additional_tensors={"PostAct": PostAct}, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx
    )
    GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=varlen_m)
    GemmWrapperBase.extract_dtypes(tensor_infos)
    major_configs = dict(
        A=("m", "k", "l"),
        B=("n", "k", "l"),
        D=("m", "n", "l"),
        C=("m", "n", "l"),
        PostAct=("m", "n", "l"),
    )
    GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)

    device_capacity = get_device_capacity(A.device)
    major_cc = device_capacity[0]
    assert major_cc in [9, 10], "Only SM90 and SM100 are supported"
    GemmCls = GemmActSm90 if major_cc <= 9 else GemmActSm100

    acc_dtype = Float32
    tile_shape_mn = (tile_M, tile_N)
    cluster_shape_mnk = (cluster_M, cluster_N, 1)
    dtypes_supported = GemmCls.is_valid_dtypes(
        tensor_infos["A"].dtype,
        tensor_infos["B"].dtype,
        acc_dtype,
        tensor_infos["D"].dtype,
        tensor_infos["A"].major,
        tensor_infos["B"].major,
    )
    if not dtypes_supported:
        raise TypeError("Skipping due to unsupported combination of types and majors")

    max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
    GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
    act_fn = act_fn_map[activation]

    def to_cute_bias(bias, leading_dim):
        # Wrap an optional torch bias as a cute tensor with a dynamic layout.
        if bias is None:
            return None
        return from_dlpack(bias.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=leading_dim
        )

    epi_args = GemmCls.EpilogueArguments(
        tensor_infos["PostAct"].cute_tensor,
        act_fn,
        mRowVecBroadcast=to_cute_bias(rowvec_bias, 1),
        # A varlen colvec is 1-D (total_m,), so its only dim is the leading one.
        mColVecBroadcast=to_cute_bias(colvec_bias, 0 if varlen_m else 1),
    )
    scheduler_args = GemmWrapperBase.create_scheduler_args(
        max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
    )

    # Create varlen arguments if needed (assumes persistent=True when varlen_m)
    varlen_args = GemmWrapperBase.create_varlen_args(
        cu_seqlens_m,
        None,  # cu_seqlens_k
        A_idx,
        max_active_clusters,
        cluster_shape_mnk,
        tensor_infos,
        GemmCls.num_epi_tensormaps,
        pingpong,
    )

    current_stream = cutlass_torch.current_stream()
    compile_key = GemmWrapperBase.get_compile_key(
        tensor_infos,
        activation,
        tile_shape_mn,
        cluster_shape_mnk,
        pingpong,
        persistent,
        tile_count_semaphore is not None,
        device_capacity,
        max_swizzle_size,
        None if rowvec_bias is None else rowvec_bias.dtype,
        None if colvec_bias is None else colvec_bias.dtype,
        varlen_m,
        gather_A,
        key_tensor_names=("A", "B", "D", "PostAct", "C"),
    )
    # The same argument list is used to compile (on a cache miss) and to launch.
    kernel_args = (
        tensor_infos["A"].cute_tensor,
        tensor_infos["B"].cute_tensor,
        tensor_infos["D"].cute_tensor,
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
        varlen_args,
        current_stream,
    )
    cache = gemm_act.compile_cache
    if compile_key not in cache:
        if major_cc == 9:
            # Only the SM90 kernel takes pingpong / persistence options.
            GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
        gemm_obj = GemmCls(
            acc_dtype,
            tensor_infos["A"].dtype,
            tile_shape_mn,
            cluster_shape_mnk,
            gather_A=gather_A,
        )
        cache[compile_key] = cute.compile(gemm_obj, *kernel_args)
    cache[compile_key](*kernel_args)


# Compiled kernels, keyed by GemmWrapperBase.get_compile_key(...).
gemm_act.compile_cache = {}
|
quack/gemm_config.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
# Copyright (C) 2025, Fri Dao.
|
|
2
2
|
import itertools
|
|
3
|
-
from typing import Optional, List
|
|
3
|
+
from typing import Optional, List, Literal
|
|
4
|
+
from functools import partial
|
|
4
5
|
from dataclasses import dataclass
|
|
5
6
|
|
|
6
7
|
|
|
@@ -13,57 +14,82 @@ class GemmConfig:
|
|
|
13
14
|
cluster_n: int = 1
|
|
14
15
|
swap_ab: bool = False
|
|
15
16
|
# raster_order: int = 1
|
|
16
|
-
|
|
17
|
+
max_swizzle_size: int = 8
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
def get_all_configs(
    device_capacity: Literal[9, 10] = 9,
    epilogue: Optional[str] = None,
    tune_coop: bool = True,
) -> List[GemmConfig]:
    """Enumerate the GEMM configurations to autotune over.

    Args:
        device_capacity: Major compute capability. 9 selects the SM90 search space and
            10 selects the SM100 search space.
        epilogue: Optional epilogue name that restricts the search space.
            ``"gated"`` and ``"lse"`` disable swap_ab. On SM90, ``"gated"`` also keeps
            only tiles with ``tile_n % 32 == 0``.
        tune_coop: SM90 only. When False, cooperative (non-pingpong) tiles are skipped.
            Ignored on SM100, which has no pingpong variant.

    Returns:
        The cartesian product of tile shapes, cluster shapes and swap_ab choices
        (and, on SM100, ``max_swizzle_size`` values) as ``GemmConfig`` objects.

    Raises:
        AssertionError / ValueError: If ``device_capacity`` is not 9 or 10. The explicit
            ValueError covers ``python -O``, where the assert is stripped and the
            function would otherwise silently return None.
    """
    assert device_capacity in [9, 10]
    if device_capacity == 9:
        return _get_sm90_configs(epilogue, tune_coop)
    elif device_capacity == 10:
        return _get_sm100_configs(epilogue)
    raise ValueError(f"Unsupported device_capacity {device_capacity}, expected 9 or 10")


def _get_sm90_configs(epilogue: Optional[str], tune_coop: bool) -> List[GemmConfig]:
    """SM90 search space: cooperative/pingpong tiles x 2-CTA clusters x swap_ab."""
    tile_n_vals = [128, 144, 160, 176, 192, 208]
    # (192, 256) is excluded: it was hitting "IOT instruction (core dumped)" in the bwd.
    tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [(128, 224), (128, 256)]
    tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
    if epilogue in ["gated"]:
        tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
        tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
    elif epilogue in ["lse"]:
        tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
    # Each entry is (tile_m, tile_n, pingpong).
    tile_mn_vals = []
    if tune_coop:
        tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
    tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
    cluster = [(1, 2), (2, 1)]
    swap_ab_vals = [False] if epilogue in ["lse", "gated"] else [False, True]
    return [
        GemmConfig(
            tile_m=tile_m,
            tile_n=tile_n,
            pingpong=pingpong,
            cluster_m=cluster_m,
            cluster_n=cluster_n,
            swap_ab=swap_ab,
        )
        for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
            tile_mn_vals, cluster, swap_ab_vals
        )
    ]


def _get_sm100_configs(epilogue: Optional[str]) -> List[GemmConfig]:
    """SM100 search space: (tile, cluster) pairs x swap_ab x max_swizzle_size."""
    tile_n_vals = [128, 160, 192, 224, 256]
    tile_mn_cluster_vals = (
        [(128, tile_n, (1, 2)) for tile_n in tile_n_vals]
        + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals]
        + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals]
    )
    swap_ab_vals = [False] if epilogue in ["lse", "gated"] else [False, True]
    max_swizzle_size_vals = [4, 8, 16]
    GemmConfigCls = partial(GemmConfig, pingpong=False)  # There's no pingpong on Sm100
    return [
        GemmConfigCls(
            tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms
        )
        for (m, n, (cm, cn)), sab, ms in itertools.product(
            tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals
        )
    ]
|