quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/gemm_symmetric.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
from typing import Tuple, Optional, Callable
|
|
2
|
+
from functools import partial
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from quack.gemm_act import GemmActMixin, act_fn_map, gemm_act
|
|
5
|
+
from quack.gemm_sm90 import GemmSm90
|
|
6
|
+
from quack.gemm_sm100 import GemmSm100
|
|
7
|
+
from quack.tile_scheduler import TriangularTileScheduler
|
|
8
|
+
from quack.gemm_wrapper_utils import GemmWrapperBase
|
|
9
|
+
from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
|
|
10
|
+
from quack.varlen_utils import VarlenManager
|
|
11
|
+
import quack.copy_utils as copy_utils
|
|
12
|
+
import cutlass
|
|
13
|
+
import cutlass.cute as cute
|
|
14
|
+
import cutlass.torch as cutlass_torch
|
|
15
|
+
from cutlass.cute.runtime import make_ptr
|
|
16
|
+
from cutlass import Int32, Float32, Boolean, const_expr
|
|
17
|
+
import cutlass.utils.hopper_helpers as sm90_utils_og
|
|
18
|
+
import cutlass.utils.blackwell_helpers as sm100_utils
|
|
19
|
+
from cutlass.cutlass_dsl import if_generate
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GemmSymmetricMixin(GemmActMixin, GemmSm90):
    """Scheduler and epilogue overrides for the symmetric GEMM kernels.

    Builds on GemmActMixin's "PostAct" epilogue output. The tile scheduler is
    TriangularTileScheduler. The epilogue writes each subtile to D and to the
    PostAct tensor. gemm_symmetric passes ``D.mT`` as PostAct, so that second
    write lands on the mirrored tile. The PostAct write is skipped when the
    tile's M and N coordinates, taken at cluster granularity, are equal.
    """

    def get_scheduler_class(self, varlen_m: bool = False):
        """Return TriangularTileScheduler; ``varlen_m`` is accepted but ignored."""
        return TriangularTileScheduler

    @cute.jit
    def epilogue(
        self,
        params: GemmActMixin.EpilogueParams,
        epi_smem_tensors: Tuple[cute.Tensor, ...],
        tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
        epi_pipeline: cutlass.pipeline.PipelineAsync,
        epi_store_pipeline: cutlass.pipeline.PipelineAsync,
        epi_read_state: cutlass.pipeline.PipelineState,
        epi_producer_state: cutlass.pipeline.PipelineState,
        epi_tile: cute.Tile,
        load_acc_subtile: Callable,
        tRS_rD: cute.Tensor,
        tRS_rC: Optional[cute.Tensor],
        tiled_copy_t2r: Optional[cute.TiledCopy],  # Only for Sm100
        tiled_copy_r2s: cute.TiledCopy,
        tRS_sD: cute.Tensor,
        tiled_copy_s2r: Optional[cute.TiledCopy],
        tSR_rC: Optional[cute.Tensor],
        tSR_sC: Optional[cute.Tensor],
        copy_D: Optional[Callable],
        copy_C: Optional[Callable],
        tile_coord_mnkl: cute.Coord,
        varlen_manager: VarlenManager,
        epilogue_barrier: cutlass.pipeline.NamedBarrier,
        tile_scheduler,
        tidx: Int32,
        is_tma_warp: Boolean,
    ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
        """Run the epilogue for one output tile, writing D and the mirrored PostAct.

        Steps:
          1. Build the register->smem copy and the smem->gmem TMA copy for
             PostAct (batch-offset via ``varlen_manager``).
          2. If ``copy_C`` is given, prefetch the first
             ``min(epi_tile_num, self.epi_c_stage)`` C subtiles.
          3. For each epilogue subtile:
             - load the accumulator into ``tRS_rD``;
             - optionally consume C from smem and prefetch a later C subtile;
             - compute PostAct with ``epi_visit_subtile``;
             - copy D and PostAct into smem buffer ``epi_buffer``;
             - issue the TMA store through ``tma_store_fn``.
          4. Flush the last (delayed) store and call ``epi_end``.

        Returns the advanced ``(epi_read_state, epi_producer_state)``.
        """
        has_C = const_expr(tRS_rC is not None)
        has_D = const_expr(copy_D is not None)

        tma_atom_postact = params.tma_atom_postact
        mPostAct_mnl = params.mPostAct_mnl
        # sRowVec / sColVec are unpacked but not used by this override.
        sRowVec, sColVec, sPostAct = epi_smem_tensors
        # Pick the smem store op for the current architecture; Sm100 needs the
        # tmem-load tiled copy.
        get_smem_store_op = (
            partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r)
            if self.arch == 100
            else sm90_utils_og.sm90_get_smem_store_op
        )
        copy_atom_postact_r2s = get_smem_store_op(
            self.postact_layout, self.postact_dtype, self.acc_dtype
        )
        # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma)
        # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom)
        # PostAct reuses the D register->smem tiling (tiled_copy_r2s) as its source layout.
        tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s)
        tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct)
        # Exactly one epilogue tensormap descriptor (for PostAct) is expected here.
        (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
        batch_idx = tile_coord_mnkl[3]
        copy_postact, _, _ = self.epilog_gmem_copy_and_partition(
            tma_atom_postact,
            varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx),
            self.cta_tile_shape_postact_mn,
            params.epi_tile_postact,
            sPostAct,
            tile_coord_mnkl,
            tma_desc_ptr=tma_desc_postact_ptr,
        )

        # We iterate over epi tiles in the N dimension first before the M dimension
        epi_tile_shape = cute.zipped_divide(
            cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
        ).shape[1]
        # Stride (n_subtiles, 1): consecutive epi_idx values step along N first.
        epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
        epi_tile_num = cute.size(epi_tile_shape)
        # Subtiles already emitted by earlier tiles of this CTA; the smem buffer
        # index continues round-robin from there.
        num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num

        epi_tensors = self.epi_begin(
            params,
            epi_smem_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            epilogue_barrier,
            tidx,
        )

        # Prefetch the first epi_c_stage subtiles of C (TMA warp issues; every
        # thread advances its copy of the producer state).
        if const_expr(copy_C is not None):
            for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()

        def tma_store_fn(src_idx, dst_idx, tile_coord_mnkl):
            # Store smem buffer src_idx to gmem subtile dst_idx for D and, off the
            # (cluster-granular) diagonal, for the mirrored PostAct.
            pid_m = tile_coord_mnkl[0]
            pid_n = tile_coord_mnkl[1]
            # Fence and barrier to make sure shared memory store is visible to TMA store
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
            )
            epilogue_barrier.arrive_and_wait()
            # Copy from shared memory to global memory
            if is_tma_warp:
                # Tile coordinates compared at cluster granularity.
                square_tile_m = pid_m // self.cluster_shape_mnk[0]
                square_tile_n = pid_n // self.cluster_shape_mnk[1]
                if const_expr(has_D):
                    copy_D(src_idx=src_idx, dst_idx=dst_idx)
                if square_tile_m != square_tile_n:  # don't write twice to the same tile
                    copy_postact(src_idx=src_idx, dst_idx=dst_idx)
            # Can't use if statement here, epi_store_pipeline object isn't captured somehow
            # NOTE(review): commit-then-acquire presumably waits until an smem
            # buffer is free for reuse — confirm against the pipeline semantics.
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit())
            if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire())
            epilogue_barrier.arrive_and_wait()

        # When True, subtile i's TMA store is issued during iteration i+1 (before
        # subtile i+1 is written to smem), and the last one after the loop.
        delay_tma_store = True

        src_idx_prev, dst_idx_prev = None, None
        for epi_idx in cutlass.range_constexpr(epi_tile_num):
            # The global memory coordinate for the current epi tile
            gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
            # Copy from acc to D registers
            load_acc_subtile(tRS_rD, epi_idx)
            epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
            if const_expr(has_C):
                epi_pipeline.consumer_wait(epi_read_state)
                cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
                # Fence to make sure shared memory read is visible to TMA load
                cute.arch.fence_proxy(
                    cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
                )
                cute.arch.sync_warp()
                # One thread per warp releases the C buffer back to the producer.
                with cute.arch.elect_one():
                    epi_pipeline.consumer_release(epi_read_state)
                epi_read_state.advance()
            # Keep the C prefetch epi_c_stage subtiles ahead of consumption.
            if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
                gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
                if is_tma_warp:
                    epi_pipeline.producer_acquire(epi_producer_state)
                    copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
                    epi_pipeline.producer_commit(epi_producer_state)
                epi_producer_state.advance()
            tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
            # Round-robin smem buffer, continuing across tiles via num_prev_subtiles.
            epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
            if const_expr(delay_tma_store):
                if const_expr(epi_idx > 0):
                    tma_store_fn(
                        src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
                    )
                src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord
            # Copy from D registers to shared memory
            if const_expr(has_D):
                copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
            cute.copy(
                tiled_copy_postact_r2s,
                tiled_copy_postact_r2s.retile(tRS_rPostAct),
                tRS_sPostAct[None, None, None, epi_buffer],
            )
            if const_expr(not delay_tma_store):
                tma_store_fn(
                    src_idx=epi_buffer, dst_idx=gmem_coord, tile_coord_mnkl=tile_coord_mnkl
                )

        # Flush the store for the final subtile, which the loop deferred.
        if const_expr(delay_tma_store):
            tma_store_fn(
                src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl
            )

        self.epi_end(
            params,
            epi_tensors,
            epi_tile,
            tiled_copy_t2r,
            tiled_copy_r2s,
            tile_coord_mnkl,
            varlen_manager,
            tidx,
        )

        return epi_read_state, epi_producer_state
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class GemmSymmetricSm90(GemmSymmetricMixin, GemmSm90):
    """SM90 symmetric GEMM: GemmSymmetricMixin's scheduler/epilogue over GemmSm90."""
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100):
    """SM100 symmetric GEMM: GemmSymmetricMixin's scheduler/epilogue over GemmSm100."""
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def gemm_symmetric(
|
|
211
|
+
A: Tensor, # (l, m, k)
|
|
212
|
+
B: Tensor, # (l, m, k)
|
|
213
|
+
D: Optional[Tensor], # (l, m, m)
|
|
214
|
+
C: Optional[Tensor], # (l, m, m)
|
|
215
|
+
tile_count_semaphore: Optional[Tensor], # (1,)
|
|
216
|
+
tile_M: int,
|
|
217
|
+
tile_N: int,
|
|
218
|
+
cluster_M: int,
|
|
219
|
+
cluster_N: int,
|
|
220
|
+
pingpong: bool = False,
|
|
221
|
+
persistent: bool = True,
|
|
222
|
+
max_swizzle_size: int = 8,
|
|
223
|
+
alpha: float | Tensor = 1.0,
|
|
224
|
+
beta: float | Tensor = 1.0,
|
|
225
|
+
) -> None:
|
|
226
|
+
# Tranpose D so the "activation" is a write to the mirrored tile
|
|
227
|
+
PostAct = D.mT
|
|
228
|
+
|
|
229
|
+
L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
|
|
230
|
+
A, B, D, C, additional_tensors={"PostAct": PostAct}
|
|
231
|
+
)
|
|
232
|
+
assert M == N, "M and N must be the same; symmetric gemm only supports square matrices"
|
|
233
|
+
GemmWrapperBase.permute_tensors(tensor_infos)
|
|
234
|
+
GemmWrapperBase.extract_dtypes(tensor_infos)
|
|
235
|
+
major_configs = {
|
|
236
|
+
"A": ("m", "k", "l"),
|
|
237
|
+
"B": ("n", "k", "l"),
|
|
238
|
+
"D": ("m", "n", "l"),
|
|
239
|
+
"C": ("m", "n", "l"),
|
|
240
|
+
"PostAct": ("m", "n", "l"),
|
|
241
|
+
}
|
|
242
|
+
GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
|
|
243
|
+
|
|
244
|
+
device_capacity = get_device_capacity(A.device)
|
|
245
|
+
assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
|
|
246
|
+
GemmCls = GemmSymmetricSm90 if device_capacity[0] == 9 else GemmSymmetricSm100
|
|
247
|
+
|
|
248
|
+
acc_dtype = Float32
|
|
249
|
+
tile_shape_mn = (tile_M, tile_N)
|
|
250
|
+
cluster_shape_mnk = (cluster_M, cluster_N, 1)
|
|
251
|
+
if not GemmCls.is_valid_dtypes(
|
|
252
|
+
tensor_infos["A"].dtype,
|
|
253
|
+
tensor_infos["B"].dtype,
|
|
254
|
+
acc_dtype,
|
|
255
|
+
tensor_infos["D"].dtype,
|
|
256
|
+
tensor_infos["A"].major,
|
|
257
|
+
tensor_infos["B"].major,
|
|
258
|
+
):
|
|
259
|
+
raise TypeError("Skipping due to unsupported combination of types and majors")
|
|
260
|
+
|
|
261
|
+
max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
|
|
262
|
+
GemmWrapperBase.create_cute_tensors({k: v for k, v in tensor_infos.items()}, major_configs)
|
|
263
|
+
|
|
264
|
+
def scalar_arg(scalar: float | Tensor):
|
|
265
|
+
if isinstance(scalar, float):
|
|
266
|
+
return Float32(scalar) if scalar != 1.0 else None
|
|
267
|
+
else:
|
|
268
|
+
assert isinstance(scalar, Tensor)
|
|
269
|
+
return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
|
|
270
|
+
|
|
271
|
+
activation = None # Equivalent to identity
|
|
272
|
+
act_fn = act_fn_map[activation]
|
|
273
|
+
epi_args = GemmCls.EpilogueArguments(
|
|
274
|
+
tensor_infos["PostAct"].cute_tensor, act_fn, scalar_arg(alpha), scalar_arg(beta)
|
|
275
|
+
)
|
|
276
|
+
scheduler_args = GemmWrapperBase.create_scheduler_args(
|
|
277
|
+
max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
|
|
278
|
+
)
|
|
279
|
+
varlen_args = None
|
|
280
|
+
|
|
281
|
+
current_stream = cutlass_torch.current_stream()
|
|
282
|
+
compile_key = GemmWrapperBase.get_compile_key(
|
|
283
|
+
tensor_infos,
|
|
284
|
+
activation,
|
|
285
|
+
tile_shape_mn,
|
|
286
|
+
cluster_shape_mnk,
|
|
287
|
+
pingpong,
|
|
288
|
+
persistent,
|
|
289
|
+
tile_count_semaphore is not None,
|
|
290
|
+
device_capacity,
|
|
291
|
+
max_swizzle_size,
|
|
292
|
+
2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
|
|
293
|
+
2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
|
|
294
|
+
key_tensor_names=("A", "B", "D", "PostAct", "C"),
|
|
295
|
+
)
|
|
296
|
+
cache = gemm_act.compile_cache
|
|
297
|
+
if compile_key not in cache:
|
|
298
|
+
if device_capacity[0] == 9:
|
|
299
|
+
GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
|
|
300
|
+
gemm_obj = GemmCls(
|
|
301
|
+
acc_dtype,
|
|
302
|
+
tensor_infos["A"].dtype,
|
|
303
|
+
tile_shape_mn,
|
|
304
|
+
cluster_shape_mnk,
|
|
305
|
+
gather_A=False,
|
|
306
|
+
)
|
|
307
|
+
cache[compile_key] = cute.compile(
|
|
308
|
+
gemm_obj,
|
|
309
|
+
tensor_infos["A"].cute_tensor,
|
|
310
|
+
tensor_infos["B"].cute_tensor,
|
|
311
|
+
tensor_infos["D"].cute_tensor,
|
|
312
|
+
tensor_infos["C"].cute_tensor,
|
|
313
|
+
epi_args,
|
|
314
|
+
scheduler_args,
|
|
315
|
+
varlen_args,
|
|
316
|
+
current_stream,
|
|
317
|
+
)
|
|
318
|
+
cache[compile_key](
|
|
319
|
+
tensor_infos["A"].cute_tensor,
|
|
320
|
+
tensor_infos["B"].cute_tensor,
|
|
321
|
+
tensor_infos["D"].cute_tensor,
|
|
322
|
+
tensor_infos["C"].cute_tensor,
|
|
323
|
+
epi_args,
|
|
324
|
+
scheduler_args,
|
|
325
|
+
varlen_args,
|
|
326
|
+
current_stream,
|
|
327
|
+
)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
gemm_act.compile_cache = {}
|
quack/gemm_wrapper_utils.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
from typing import Optional, Tuple, Dict, Any
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
|
|
5
|
+
import torch
|
|
5
6
|
from torch import Tensor
|
|
6
7
|
|
|
7
8
|
import cutlass.cute as cute
|
|
@@ -9,7 +10,8 @@ from cutlass import Int32
|
|
|
9
10
|
from cutlass.cute.runtime import from_dlpack, make_ptr
|
|
10
11
|
|
|
11
12
|
from quack.cute_dsl_utils import torch2cute_dtype_map
|
|
12
|
-
from quack.
|
|
13
|
+
from quack.varlen_utils import VarlenArguments
|
|
14
|
+
from quack.tile_scheduler import TileSchedulerOptions
|
|
13
15
|
|
|
14
16
|
|
|
15
17
|
@dataclass
|
|
@@ -22,8 +24,8 @@ class GemmTensorInfo:
|
|
|
22
24
|
|
|
23
25
|
class GemmWrapperBase:
|
|
24
26
|
@staticmethod
|
|
25
|
-
def
|
|
26
|
-
assert tensor.dim() ==
|
|
27
|
+
def validate_tensor(tensor: Tensor, name: str, ndim: int) -> None:
    """Assert that ``tensor`` is an ``ndim``-dimensional CUDA tensor with a CuTe dtype mapping."""
    is_expected_kind = tensor.dim() == ndim and tensor.is_cuda
    assert is_expected_kind, f"{name} must be a {ndim}D CUDA tensor"
    assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"
|
|
28
30
|
|
|
29
31
|
@staticmethod
|
|
@@ -47,7 +49,7 @@ class GemmWrapperBase:
|
|
|
47
49
|
) -> Optional[cute.Tensor]:
|
|
48
50
|
if tensor is None:
|
|
49
51
|
return None
|
|
50
|
-
# Tensor is already permuted to (dims[0], dims[1], dims[2])
|
|
52
|
+
# Tensor is already permuted to (dims[0], dims[1], dims[2]) or (dim[0], dim[1])
|
|
51
53
|
# If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0
|
|
52
54
|
leading_dim = 1 if major == dims[1] else 0
|
|
53
55
|
return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
|
|
@@ -61,43 +63,131 @@ class GemmWrapperBase:
|
|
|
61
63
|
D: Optional[Tensor] = None,
|
|
62
64
|
C: Optional[Tensor] = None,
|
|
63
65
|
additional_tensors: Optional[Dict[str, Tensor]] = None,
|
|
66
|
+
cu_seqlens_m: Optional[Tensor] = None,
|
|
67
|
+
cu_seqlens_k: Optional[Tensor] = None,
|
|
68
|
+
A_idx: Optional[Tensor] = None,
|
|
64
69
|
) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
_, N, _ = B.shape
|
|
70
|
+
assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
|
|
71
|
+
"Only one of cu_seqlens_m and cu_seqlens_k can be specified"
|
|
72
|
+
)
|
|
69
73
|
assert B.dtype == A.dtype, "A and B must have the same dtype"
|
|
70
|
-
|
|
74
|
+
|
|
75
|
+
# Validate A_idx if provided (for gather_A case)
|
|
76
|
+
gather_A = A_idx is not None
|
|
77
|
+
if gather_A:
|
|
78
|
+
assert cu_seqlens_m is not None or cu_seqlens_k is not None, (
|
|
79
|
+
"gather_A requires either varlen_m or varlen_k"
|
|
80
|
+
)
|
|
81
|
+
assert A_idx.dtype == torch.int32, f"A_idx must be int32, got {A_idx.dtype}"
|
|
82
|
+
assert A_idx.dim() == 1, f"A_idx must be 1D, got {A_idx.dim()}D"
|
|
83
|
+
|
|
84
|
+
# Determine mode and extract dimensions
|
|
85
|
+
if cu_seqlens_m is not None:
|
|
86
|
+
# varlen_m: A is (total_m, k) or (whatever, k) if gather_A, B is (l, n, k), D/C are (total_m, n)
|
|
87
|
+
assert A.dim() == 2, f"A must be 2D when using varlen_m, got {A.dim()}D"
|
|
88
|
+
assert B.dim() == 3, f"B must be 3D with varlen_m, got {B.dim()}D"
|
|
89
|
+
|
|
90
|
+
if gather_A:
|
|
91
|
+
# When gather_A, A can have any number of rows, we use A_idx.shape[0] as total_M
|
|
92
|
+
total_M = A_idx.shape[0]
|
|
93
|
+
_, K = A.shape
|
|
94
|
+
else:
|
|
95
|
+
total_M, K = A.shape
|
|
96
|
+
|
|
97
|
+
L, N, K_B = B.shape
|
|
98
|
+
assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
|
|
99
|
+
assert cu_seqlens_m.shape == (L + 1,), (
|
|
100
|
+
f"cu_seqlens_m must have shape ({L + 1},), got {cu_seqlens_m.shape}"
|
|
101
|
+
)
|
|
102
|
+
M = total_M
|
|
103
|
+
dc_shape = (total_M, N)
|
|
104
|
+
dc_ndim = 2
|
|
105
|
+
elif cu_seqlens_k is not None:
|
|
106
|
+
# varlen_k: A is (m, total_k) or (m, whatever) if gather_A, B is (n, total_k), D/C are (l, m, n)
|
|
107
|
+
assert A.dim() == 2, f"A must be 2D when using varlen_k, got {A.dim()}D"
|
|
108
|
+
assert B.dim() == 2, f"B must be 2D with varlen_k, got {B.dim()}D"
|
|
109
|
+
|
|
110
|
+
if gather_A:
|
|
111
|
+
# When gather_A with varlen_k, A can have any number of columns, we use A_idx.shape[0] as total_K
|
|
112
|
+
M, _ = A.shape
|
|
113
|
+
total_K = A_idx.shape[0]
|
|
114
|
+
else:
|
|
115
|
+
M, total_K = A.shape
|
|
116
|
+
|
|
117
|
+
N, K_B = B.shape
|
|
118
|
+
assert total_K == K_B, f"K dimension mismatch: expected {total_K}, B has {K_B}"
|
|
119
|
+
L = cu_seqlens_k.shape[0] - 1
|
|
120
|
+
assert cu_seqlens_k.shape == (L + 1,), (
|
|
121
|
+
f"cu_seqlens_k must have shape ({L + 1},), got {cu_seqlens_k.shape}"
|
|
122
|
+
)
|
|
123
|
+
K = total_K
|
|
124
|
+
dc_shape = (L, M, N)
|
|
125
|
+
dc_ndim = 3
|
|
126
|
+
else:
|
|
127
|
+
# Normal case - all tensors must be 3D
|
|
128
|
+
GemmWrapperBase.validate_tensor(A, "A", 3)
|
|
129
|
+
GemmWrapperBase.validate_tensor(B, "B", 3)
|
|
130
|
+
L, M, K = A.shape
|
|
131
|
+
_, N, K_B = B.shape
|
|
132
|
+
assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
|
|
133
|
+
GemmWrapperBase.validate_shape(B, (L, N, K), "B")
|
|
134
|
+
dc_shape = (L, M, N)
|
|
135
|
+
dc_ndim = 3
|
|
136
|
+
|
|
137
|
+
# Validate D and C shapes uniformly
|
|
138
|
+
for tensor, name in [(D, "D"), (C, "C")]:
|
|
139
|
+
if tensor is not None:
|
|
140
|
+
assert tensor.dim() == dc_ndim, (
|
|
141
|
+
f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
|
|
142
|
+
)
|
|
143
|
+
assert tensor.shape == dc_shape, (
|
|
144
|
+
f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
|
|
145
|
+
)
|
|
146
|
+
|
|
71
147
|
tensors = {
|
|
72
148
|
"A": GemmTensorInfo(A),
|
|
73
149
|
"B": GemmTensorInfo(B),
|
|
74
150
|
"D": GemmTensorInfo(D),
|
|
75
151
|
"C": GemmTensorInfo(C),
|
|
76
152
|
}
|
|
77
|
-
|
|
78
|
-
GemmWrapperBase.validate_tensor_3d(D, "D")
|
|
79
|
-
GemmWrapperBase.validate_shape(D, (L, M, N), "D")
|
|
80
|
-
if C is not None:
|
|
81
|
-
GemmWrapperBase.validate_tensor_3d(C, "C")
|
|
82
|
-
GemmWrapperBase.validate_shape(C, (L, M, N), "C")
|
|
153
|
+
|
|
83
154
|
if additional_tensors:
|
|
84
155
|
for name, tensor in additional_tensors.items():
|
|
85
156
|
if tensor is not None:
|
|
86
|
-
|
|
87
|
-
|
|
157
|
+
assert tensor.dim() == dc_ndim, (
|
|
158
|
+
f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
|
|
159
|
+
)
|
|
160
|
+
assert tensor.shape == dc_shape, (
|
|
161
|
+
f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
|
|
162
|
+
)
|
|
88
163
|
tensors[name] = GemmTensorInfo(tensor)
|
|
89
164
|
|
|
90
165
|
return L, M, K, N, tensors
|
|
91
166
|
|
|
92
167
|
@staticmethod
|
|
93
|
-
def permute_tensors(
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
168
|
+
def permute_tensors(
    tensors: Dict[str, GemmTensorInfo], varlen_m: bool = False, varlen_k: bool = False
) -> None:
    """Move the leading (batch) mode of 3-D tensors to the end, in place.

    Each selected ``info.tensor`` of shape (L, x, y) is replaced by its
    ``permute(1, 2, 0)`` view of shape (x, y, L). Selection:
      * varlen_m: only "B";
      * varlen_k (and not varlen_m): only "D" and "C";
      * otherwise: every entry.
    Entries whose tensor is None or not 3-D are left alone.
    """
    if varlen_m:
        selected = ("B",)
    elif varlen_k:
        selected = ("D", "C")
    else:
        selected = None  # no restriction: permute every 3-D tensor

    for key, info in tensors.items():
        current = info.tensor
        if current is None or current.ndim != 3:
            continue
        if selected is not None and key not in selected:
            continue
        info.tensor = current.permute(1, 2, 0)
|
|
97
187
|
|
|
98
188
|
@staticmethod
|
|
99
189
|
def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
    """Set ``info.dtype`` to the CuTe dtype of each present tensor, in place.

    Entries whose ``tensor`` is None keep their existing ``dtype``. A torch dtype
    missing from ``torch2cute_dtype_map`` raises KeyError.
    """
    # Only the infos are needed; iterating .items() left the key unused.
    for info in tensors.values():
        if info.tensor is not None:
            info.dtype = torch2cute_dtype_map[info.tensor.dtype]
|
|
103
193
|
|
|
@@ -121,7 +211,10 @@ class GemmWrapperBase:
|
|
|
121
211
|
|
|
122
212
|
@staticmethod
|
|
123
213
|
def create_scheduler_args(
|
|
124
|
-
max_active_clusters: int,
|
|
214
|
+
max_active_clusters: int,
|
|
215
|
+
tile_count_semaphore: Optional[Tensor] = None,
|
|
216
|
+
batch_idx_permute: Optional[Tensor] = None,
|
|
217
|
+
max_swizzle_size: int = 8,
|
|
125
218
|
) -> TileSchedulerOptions:
|
|
126
219
|
return TileSchedulerOptions(
|
|
127
220
|
Int32(max_active_clusters),
|
|
@@ -130,6 +223,72 @@ class GemmWrapperBase:
|
|
|
130
223
|
)
|
|
131
224
|
if tile_count_semaphore is not None
|
|
132
225
|
else None,
|
|
226
|
+
batch_idx_permute=(
|
|
227
|
+
from_dlpack(batch_idx_permute, assumed_align=4).mark_layout_dynamic(leading_dim=0)
|
|
228
|
+
)
|
|
229
|
+
if batch_idx_permute is not None
|
|
230
|
+
else None,
|
|
231
|
+
max_swizzle_size=Int32(max_swizzle_size),
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
@staticmethod
|
|
235
|
+
def create_varlen_args(
    cu_seqlens_m: Optional[Tensor],
    cu_seqlens_k: Optional[Tensor],
    A_idx: Optional[Tensor],
    max_active_clusters: int,
    cluster_shape_mnk: Tuple[int, int, int],
    tensors: Dict[str, GemmTensorInfo],
    num_epi_tensormaps: int = 0,
    pingpong: bool = False,
) -> Optional[Any]:
    """Build VarlenArguments for varlen_m / varlen_k launches, or None otherwise.

    A tensormap scratch buffer is allocated on the cu_seqlens device with shape
    (num_blocks, num_tensormaps, 16) int64 (128 bytes per tensormap), where
    num_blocks = max_active_clusters * cluster_M * cluster_N.
      * varlen_m: num_epi_tensormaps, plus one for D when D is present, each
        doubled when pingpong.
      * varlen_k: 2 tensormaps, or 1 when A_idx is given.
    If the count is 0, no buffer is allocated and ``mTensormaps`` is None.
    """
    if cu_seqlens_m is None and cu_seqlens_k is None:
        return None

    def _as_cute_1d(t):
        # Optional 1-D tensor -> dynamic-layout CuTe tensor (None passes through).
        if t is None:
            return None
        return from_dlpack(t, assumed_align=4).mark_layout_dynamic(leading_dim=0)

    # When varlen_m, we assume persistent=True: the grid is sized from
    # max_active_clusters and the (M x N) cluster shape.
    ctas_per_cluster = cluster_shape_mnk[0] * cluster_shape_mnk[1]
    grid_ctas = max_active_clusters * ctas_per_cluster

    copies = 2 if pingpong else 1
    if cu_seqlens_m is not None:
        # varlen_m: epilogue tensormaps, plus one for D if present
        tensormap_count = num_epi_tensormaps * copies
        if tensors["D"].tensor is not None:
            tensormap_count += copies
    else:
        # varlen_k: tensormaps for A & B, one fewer when A_idx is supplied
        tensormap_count = 1 if A_idx is not None else 2

    tensormaps_cute = None
    if tensormap_count > 0:
        words_per_tensormap = 128 // 8  # 128-byte tensormap as 16 int64 words
        reference = cu_seqlens_m if cu_seqlens_m is not None else cu_seqlens_k
        scratch = torch.empty(
            (grid_ctas, tensormap_count, words_per_tensormap),
            dtype=torch.int64,
            device=reference.device,
        )
        tensormaps_cute = from_dlpack(scratch, assumed_align=128).mark_compact_shape_dynamic(
            mode=0, stride_order=(0, 1, 2)
        )

    return VarlenArguments(
        mCuSeqlensM=_as_cute_1d(cu_seqlens_m),
        mCuSeqlensK=_as_cute_1d(cu_seqlens_k),
        mTensormaps=tensormaps_cute,
        mAIdx=_as_cute_1d(A_idx),
    )
|
|
134
293
|
|
|
135
294
|
@staticmethod
|