quack-kernels 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +16 -25
- quack/autotuner.py +64 -5
- quack/cross_entropy.py +6 -10
- quack/cute_dsl_utils.py +6 -7
- quack/dense_gemm_sm90.py +582 -287
- quack/gemm_act_sm90.py +70 -29
- quack/gemm_dact_sm90.py +43 -10
- quack/gemm_interface.py +453 -130
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +443 -419
- quack/gemm_wrapper_utils.py +179 -22
- quack/layernorm.py +1 -1
- quack/reduce.py +6 -7
- quack/rmsnorm.py +126 -158
- quack/softmax.py +1 -1
- quack/tile_scheduler.py +37 -49
- quack/utils.py +61 -71
- quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/METADATA +3 -3
- quack_kernels-0.2.2.dist-info/RECORD +37 -0
- quack_kernels-0.2.0.dist-info/RECORD +0 -37
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt +0 -0
quack/gemm_act_sm90.py
CHANGED
@@ -19,6 +19,8 @@ import quack.activation
 
 
 class GemmActSm90(GemmSm90):
+    num_epi_tensormaps: int = 1
+
     @dataclass
     class EpilogueArguments(ArgumentsBase):
         mPostAct: cute.Tensor
@@ -41,7 +43,7 @@ class GemmActSm90(GemmSm90):
         self.postact_dtype = args.mPostAct.element_type
         self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct)
 
-        self.
+        self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2]
         self.epi_tile_postact = self.epi_tile
         postact_major_mode_size = (
             self.epi_tile_postact[1]
@@ -63,7 +65,7 @@ class GemmActSm90(GemmSm90):
             args.mPostAct,
             epi_postact_smem_layout_staged,
             self.epi_tile_postact,
-
+            op_type="store",
         )
         return GemmActSm90.EpilogueParams(
             tma_atom_postact,
@@ -74,10 +76,28 @@ class GemmActSm90(GemmSm90):
             args.beta,
         )
 
+    def epi_get_tma_atoms(
+        self, params: EpilogueParams, *, loc=None, ip=None
+    ) -> list[cute.CopyAtom]:
+        return [params.tma_atom_postact]
+
+    def epi_get_tensormap_update_shapes_orders(
+        self,
+        params: EpilogueParams,
+        cu_seqlens_m: cute.Tensor,
+        batch_idx: Int32,
+        *,
+        loc=None,
+        ip=None,
+    ) -> tuple[list[Int32], list[int]]:
+        shapes = [cu_seqlens_m[batch_idx + 1]]
+        orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1]
+        return shapes, orders
+
     @staticmethod
     def epi_smem_bytes_per_stage(
         args: EpilogueArguments,
-
+        cta_tile_shape_mnk: Tuple[int, int, int],
         epi_tile: Tuple[int, int],
     ) -> int:
         postact_dtype = args.mPostAct.element_type
@@ -108,7 +128,9 @@ class GemmActSm90(GemmSm90):
         self,
         params: EpilogueParams,
         epi_smem_tensors: Tuple[cute.Tensor, ...],
+        tma_desc_epi_ptrs: list[Optional[cute.Pointer]],
         epi_pipeline: cutlass.pipeline.PipelineAsync,
+        epi_store_pipeline: cutlass.pipeline.PipelineAsync,
         epi_read_state: cutlass.pipeline.PipelineState,
         epi_producer_state: cutlass.pipeline.PipelineState,
         tiled_mma: cute.TiledMma,
@@ -133,7 +155,6 @@ class GemmActSm90(GemmSm90):
     ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
         has_C = const_expr(tRS_rC is not None)
         has_D = const_expr(copy_D is not None)
-        assert cu_seqlens_m is None, "GemmActSm90 doesn't support varlen_m for now"
 
         tma_atom_postact = params.tma_atom_postact
         mPostAct_mnl = params.mPostAct_mnl
@@ -148,16 +169,17 @@ class GemmActSm90(GemmSm90):
         bSG_sPostAct, bSG_gPostAct = self.epilog_gmem_copy_and_partition(
             tma_atom_postact,
             mPostAct_mnl,
-            self.
+            self.cta_tile_shape_postact_mn,
             self.epi_tile_postact,
             sPostAct,
             tile_coord_mnkl,
             cu_seqlens_m,
         )
+        (tma_desc_postact_ptr,) = tma_desc_epi_ptrs
 
         # We iterate over epi tiles in the N dimension first before the M dimension
         epi_tile_shape = cute.zipped_divide(
-            cute.make_layout(self.
+            cute.make_layout(self.cta_tile_shape_mnk[:2]), self.epi_tile
         ).shape[1]
         epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
         epi_tile_num = cute.size(epi_tile_shape)
@@ -214,9 +236,10 @@ class GemmActSm90(GemmSm90):
                 tma_atom_postact,
                 bSG_sPostAct[None, epi_buffer],
                 bSG_gPostAct[None, gmem_coord],
+                tma_desc_ptr=tma_desc_postact_ptr,
             )
-
-
+            epi_store_pipeline.producer_commit()
+            epi_store_pipeline.producer_acquire()
         epilogue_barrier.arrive_and_wait()
 
         return epi_read_state, epi_producer_state
@@ -261,11 +284,12 @@ act_fn_map = {
 
 
 def gemm_act_sm90(
-    A: Tensor,  # (l, m, k)
+    A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
     B: Tensor,  # (l, n, k)
-    D: Optional[Tensor],  # (l, m, n)
-    C: Optional[Tensor],  # (l, m, n)
-    PostAct: Tensor,  # (l, m, n)
+    D: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
+    C: Optional[Tensor],  # (l, m, n) or (total_m, n) if varlen_m
+    PostAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+    tile_count_semaphore: Optional[Tensor],  # (1,)
    activation: Optional[str],
    tile_M: int,
    tile_N: int,
@@ -273,15 +297,25 @@ def gemm_act_sm90(
     cluster_N: int,
     pingpong: bool = False,
     persistent: bool = True,
-
-
+    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
+    A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
 ) -> None:
-
+    if cu_seqlens_m is not None:
+        assert persistent, "varlen_m requires persistent=True"
+        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
+        if D is not None:
+            assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
+        assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
+    gather_A = A_idx is not None
+    if gather_A:
+        assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
+        assert cluster_N == 1, "gather_A requires cluster_N=1"
     assert activation in act_fn_map, f"Unsupported activation {activation}"
+
     L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
-        A, B, D, C, additional_tensors={"PostAct": PostAct}
+        A, B, D, C, additional_tensors={"PostAct": PostAct}, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx
     )
-    GemmWrapperBase.permute_tensors(tensor_infos)
+    GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
     GemmWrapperBase.extract_dtypes(tensor_infos)
     major_configs = {
         "A": ("m", "k", "l"),
@@ -308,15 +342,23 @@ def gemm_act_sm90(
     max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
     GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
     act_fn = act_fn_map[activation]
-    epi_args = GemmActSm90.EpilogueArguments(
-        tensor_infos["PostAct"].cute_tensor,
-        act_fn,
-        alpha=Float32(alpha) if alpha != 1.0 else None,
-        beta=Float32(beta) if beta != 1.0 else None,
-    )
+    epi_args = GemmActSm90.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
     scheduler_args = GemmWrapperBase.create_scheduler_args(
         max_active_clusters, tile_count_semaphore
     )
+
+    # Create varlen arguments if needed (assumes persistent=True when varlen_m)
+    varlen_args = GemmWrapperBase.create_varlen_args(
+        cu_seqlens_m,
+        None,  # cu_seqlens_k
+        A_idx,
+        max_active_clusters,
+        cluster_shape_mnk,
+        tensor_infos,
+        GemmActSm90.num_epi_tensormaps,
+        pingpong,
+    )
+
     current_stream = cutlass_torch.current_stream()
     compile_key = GemmWrapperBase.get_compile_key(
         tensor_infos,
@@ -326,8 +368,8 @@ def gemm_act_sm90(
         pingpong,
         persistent,
         tile_count_semaphore is not None,
-
-
+        cu_seqlens_m is not None,
+        A_idx is not None,
         key_tensor_names=("A", "B", "D", "PostAct", "C"),
     )
     cache = gemm_act_sm90.compile_cache
@@ -339,6 +381,7 @@ def gemm_act_sm90(
            cluster_shape_mnk,
            pingpong=pingpong,
            is_persistent=persistent,
+            gather_A=gather_A,
        )
        cache[compile_key] = cute.compile(
            gemm,
@@ -348,8 +391,7 @@ def gemm_act_sm90(
            tensor_infos["C"].cute_tensor,
            epi_args,
            scheduler_args,
-
-            None,  # mAIdx
+            varlen_args,
            current_stream,
        )
    cache[compile_key](
@@ -359,8 +401,7 @@ def gemm_act_sm90(
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
-
-        None,
+        varlen_args,
        current_stream,
    )
 
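The new cu_seqlens_m and A_idx arguments change how callers prepare their tensors. Below is a minimal usage sketch of the varlen_m path, pieced together from the signature in the diff above; the parameters this diff elides (tile_N, cluster_M) are assumed to sit between tile_M and cluster_N, and "relu" is an assumed key of act_fn_map, not confirmed here.

# Hedged sketch of calling gemm_act_sm90 on a variable-length batch.
# Assumptions (not confirmed by the diff): tile_N/cluster_M parameter
# placement, the "relu" activation key, and int32 cu_seqlens_m.
import torch
from quack.gemm_act_sm90 import gemm_act_sm90

L, K, N = 3, 512, 1024
seqlens = torch.tensor([128, 64, 256], dtype=torch.int32, device="cuda")
# (l+1,) cumulative sum of m values, starting at 0, per the comment in the diff
cu_seqlens_m = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0).int()])
total_m = int(seqlens.sum())

A = torch.randn(total_m, K, device="cuda", dtype=torch.bfloat16)        # k-major
B = torch.randn(L, N, K, device="cuda", dtype=torch.bfloat16)
D = torch.empty(total_m, N, device="cuda", dtype=torch.bfloat16)        # n-major
PostAct = torch.empty(total_m, N, device="cuda", dtype=torch.bfloat16)  # n-major
sem = torch.zeros(1, dtype=torch.int32, device="cuda")

gemm_act_sm90(
    A, B, D, None, PostAct, sem,
    activation="relu",          # assumed act_fn_map key
    tile_M=128, tile_N=192, cluster_M=1, cluster_N=1,
    persistent=True,            # the varlen_m path asserts persistent=True
    cu_seqlens_m=cu_seqlens_m,  # enables variable-length m
)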
quack/gemm_dact_sm90.py
CHANGED
@@ -52,11 +52,11 @@ dact_fn_map = {
 
 
 def gemm_dact_sm90(
-    A: Tensor,  # (l, m, k)
+    A: Tensor,  # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
     B: Tensor,  # (l, n, k)
-    Out: Tensor,  # (l, m, n)
-    PreAct: Tensor,  # (l, m, n)
-    PostAct: Tensor,  # (l, m, n)
+    Out: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+    PreAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
+    PostAct: Tensor,  # (l, m, n) or (total_m, n) if varlen_m
     tile_count_semaphore: Optional[Tensor],  # (1,)
     activation: Optional[str],
     tile_M: int,
@@ -65,12 +65,31 @@ def gemm_dact_sm90(
     cluster_N: int,
     pingpong: bool = True,
     persistent: bool = True,
+    cu_seqlens_m: Optional[Tensor] = None,  # (l+1,) cumulative sum of m values for variable length
+    A_idx: Optional[Tensor] = None,  # (total_m,) if gather_A with varlen_m
 ) -> None:
+    if cu_seqlens_m is not None:
+        assert persistent, "varlen_m requires persistent=True"
+        assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
+        assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
+        assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
+        assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
+    gather_A = A_idx is not None
+    if gather_A:
+        assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
+        assert cluster_N == 1, "gather_A requires cluster_N=1"
     assert activation in dact_fn_map, f"Unsupported activation {activation}"
+
     L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
-        A,
+        A,
+        B,
+        Out,
+        PreAct,
+        additional_tensors={"PostAct": PostAct},
+        cu_seqlens_m=cu_seqlens_m,
+        A_idx=A_idx,
     )
-    GemmWrapperBase.permute_tensors(tensor_infos)
+    GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
     GemmWrapperBase.extract_dtypes(tensor_infos)
     major_configs = {
         "A": ("m", "k", "l"),
@@ -101,6 +120,19 @@ def gemm_dact_sm90(
     scheduler_args = GemmWrapperBase.create_scheduler_args(
         max_active_clusters, tile_count_semaphore
     )
+
+    # Create varlen arguments if needed (assumes persistent=True when varlen_m)
+    varlen_args = GemmWrapperBase.create_varlen_args(
+        cu_seqlens_m,
+        None,  # cu_seqlens_k
+        A_idx,
+        max_active_clusters,
+        cluster_shape_mnk,
+        tensor_infos,
+        GemmDActSm90.num_epi_tensormaps,
+        pingpong,
+    )
+
     current_stream = cutlass_torch.current_stream()
     compile_key = GemmWrapperBase.get_compile_key(
         tensor_infos,
@@ -110,6 +142,8 @@ def gemm_dact_sm90(
         pingpong,
         persistent,
         tile_count_semaphore is not None,
+        cu_seqlens_m is not None,
+        A_idx is not None,
         key_tensor_names=("A", "B", "D", "PostAct", "C"),
     )
     cache = gemm_dact_sm90.compile_cache
@@ -121,6 +155,7 @@ def gemm_dact_sm90(
            cluster_shape_mnk,
            pingpong=pingpong,
            is_persistent=persistent,
+            gather_A=gather_A,
        )
        cache[compile_key] = cute.compile(
            gemm,
@@ -130,8 +165,7 @@ def gemm_dact_sm90(
            tensor_infos["C"].cute_tensor,
            epi_args,
            scheduler_args,
-
-            None,  # mAIdx
+            varlen_args,
            current_stream,
        )
    cache[compile_key](
@@ -141,8 +175,7 @@ def gemm_dact_sm90(
        tensor_infos["C"].cute_tensor,
        epi_args,
        scheduler_args,
-
-        None,
+        varlen_args,
        current_stream,
    )
 
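The same varlen_m pattern applies here, plus the gather_A path: when A_idx is given, rows of A are selected by index rather than read contiguously, matching the "(whatever, k)" shape comment. A hedged sketch follows; the activation key, the int32 A_idx dtype, and the elided tile_N/cluster_M parameters are assumptions, not confirmed by this diff.

# Hedged sketch of the gather_A + varlen_m path of gemm_dact_sm90:
# A_idx picks total_m rows out of a larger k-major A buffer.
import torch
from quack.gemm_dact_sm90 import gemm_dact_sm90

L, K, N = 2, 256, 512
seqlens = torch.tensor([96, 160], dtype=torch.int32, device="cuda")
cu_seqlens_m = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0).int()])
total_m = int(seqlens.sum())

A = torch.randn(4096, K, device="cuda", dtype=torch.bfloat16)  # row pool, k-major
A_idx = torch.randint(4096, (total_m,), dtype=torch.int32, device="cuda")
B = torch.randn(L, N, K, device="cuda", dtype=torch.bfloat16)
Out = torch.empty(total_m, N, device="cuda", dtype=torch.bfloat16)
PreAct = torch.randn(total_m, N, device="cuda", dtype=torch.bfloat16)
PostAct = torch.empty(total_m, N, device="cuda", dtype=torch.bfloat16)

gemm_dact_sm90(
    A, B, Out, PreAct, PostAct,
    tile_count_semaphore=torch.zeros(1, dtype=torch.int32, device="cuda"),
    activation="relu",          # assumed dact_fn_map key
    tile_M=128, tile_N=192, cluster_M=1,
    cluster_N=1,                # gather_A asserts cluster_N == 1
    cu_seqlens_m=cu_seqlens_m,  # gather_A requires varlen
    A_idx=A_idx,
)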