quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/tile_scheduler.py
CHANGED
|
@@ -6,7 +6,7 @@ from enum import IntEnum
|
|
|
6
6
|
|
|
7
7
|
import cutlass
|
|
8
8
|
import cutlass.cute as cute
|
|
9
|
-
from cutlass import Int32, Boolean, const_expr
|
|
9
|
+
from cutlass import Int32, Float32, Boolean, const_expr
|
|
10
10
|
|
|
11
11
|
import quack.utils as utils
|
|
12
12
|
from quack.fast_math import FastDivmod
|
|
@@ -135,7 +135,7 @@ class TileScheduler:
|
|
|
135
135
|
ip=None,
|
|
136
136
|
):
|
|
137
137
|
self._current_work_linear_idx = current_work_linear_idx
|
|
138
|
-
self._num_tiles_executed = num_tiles_executed
|
|
138
|
+
self.num_tiles_executed = num_tiles_executed
|
|
139
139
|
self._tile_count = tile_count
|
|
140
140
|
self._scheduler_pipeline = scheduler_pipeline
|
|
141
141
|
self._pipeline_state = pipeline_state
|
|
@@ -251,7 +251,7 @@ class TileScheduler:
|
|
|
251
251
|
)
|
|
252
252
|
tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
|
|
253
253
|
if const_expr(not params.is_persistent):
|
|
254
|
-
is_valid = self._num_tiles_executed == 0
|
|
254
|
+
is_valid = self.num_tiles_executed == 0
|
|
255
255
|
else:
|
|
256
256
|
is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl)
|
|
257
257
|
return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
|
|
@@ -276,38 +276,6 @@ class TileScheduler:
|
|
|
276
276
|
current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
|
|
277
277
|
self._current_work_linear_idx = current_work_linear_idx
|
|
278
278
|
|
|
279
|
-
# We have to split broadcast_next_work and advance_to_next_work into two functions
|
|
280
|
-
# due to a bug in cute-dsl 4.2: https://github.com/NVIDIA/cutlass/issues/2647
|
|
281
|
-
@cute.jit
|
|
282
|
-
def broadcast_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
|
|
283
|
-
"""is_scheduler_warp should only be true for one warp in the whole cluster"""
|
|
284
|
-
params = self.params
|
|
285
|
-
if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
|
|
286
|
-
current_work_linear_idx = self._current_work_linear_idx
|
|
287
|
-
if is_scheduler_warp:
|
|
288
|
-
self._scheduler_pipeline.producer_acquire(self._pipeline_state)
|
|
289
|
-
lane_idx = cute.arch.lane_idx()
|
|
290
|
-
if lane_idx < cute.size(params.cluster_shape_mn):
|
|
291
|
-
# cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
|
|
292
|
-
if const_expr(cute.size(params.cluster_shape_mn) == 1):
|
|
293
|
-
self._tile_count[self._pipeline_state.index] = current_work_linear_idx
|
|
294
|
-
self._scheduler_pipeline.producer_commit(self._pipeline_state)
|
|
295
|
-
else:
|
|
296
|
-
peer_cta_rank_in_cluster = lane_idx
|
|
297
|
-
mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
|
|
298
|
-
self._pipeline_state
|
|
299
|
-
)
|
|
300
|
-
cute.arch.mbarrier_arrive_and_expect_tx(
|
|
301
|
-
mbar_ptr, 4, peer_cta_rank_in_cluster
|
|
302
|
-
)
|
|
303
|
-
utils.store_shared_remote(
|
|
304
|
-
val=current_work_linear_idx,
|
|
305
|
-
smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
|
|
306
|
-
mbar_ptr=mbar_ptr,
|
|
307
|
-
peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
|
|
308
|
-
)
|
|
309
|
-
# cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
|
|
310
|
-
|
|
311
279
|
@cute.jit
|
|
312
280
|
def advance_to_next_work(
|
|
313
281
|
self,
|
|
@@ -319,6 +287,7 @@ class TileScheduler:
|
|
|
319
287
|
):
|
|
320
288
|
tidx = cute.arch.thread_idx()[0]
|
|
321
289
|
bidx = cute.arch.block_idx()[0]
|
|
290
|
+
bidz = cute.arch.block_idx()[2]
|
|
322
291
|
params = self.params
|
|
323
292
|
if const_expr(params.is_persistent):
|
|
324
293
|
num_persistent_clusters = cute.arch.grid_dim()[2]
|
|
@@ -328,34 +297,60 @@ class TileScheduler:
|
|
|
328
297
|
if const_expr(advance_count > 1):
|
|
329
298
|
self._pipeline_state.advance_iters(advance_count - 1)
|
|
330
299
|
current_work_linear_idx = self._current_work_linear_idx
|
|
331
|
-
if
|
|
332
|
-
|
|
300
|
+
if is_scheduler_warp:
|
|
301
|
+
self._scheduler_pipeline.producer_acquire(self._pipeline_state)
|
|
302
|
+
lane_idx = cute.arch.lane_idx()
|
|
303
|
+
if lane_idx < cute.size(params.cluster_shape_mn):
|
|
304
|
+
# cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after empty wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
|
|
305
|
+
if const_expr(cute.size(params.cluster_shape_mn) == 1):
|
|
306
|
+
self._tile_count[self._pipeline_state.index] = current_work_linear_idx
|
|
307
|
+
self._scheduler_pipeline.producer_commit(self._pipeline_state)
|
|
308
|
+
else:
|
|
309
|
+
peer_cta_rank_in_cluster = lane_idx
|
|
310
|
+
mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
|
|
311
|
+
self._pipeline_state
|
|
312
|
+
)
|
|
313
|
+
cute.arch.mbarrier_arrive_and_expect_tx(
|
|
314
|
+
mbar_ptr, 4, peer_cta_rank_in_cluster
|
|
315
|
+
)
|
|
316
|
+
utils.store_shared_remote(
|
|
317
|
+
val=current_work_linear_idx,
|
|
318
|
+
smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
|
|
319
|
+
mbar_ptr=mbar_ptr,
|
|
320
|
+
peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
|
|
321
|
+
)
|
|
322
|
+
# cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after full arrive", bidx, bidz, tidx)
|
|
323
|
+
else:
|
|
324
|
+
# if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
|
|
333
325
|
self._scheduler_pipeline.consumer_wait(self._pipeline_state)
|
|
334
|
-
# if tidx %
|
|
326
|
+
# if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
|
|
335
327
|
current_work_linear_idx = self._tile_count[self._pipeline_state.index]
|
|
336
|
-
# if tidx %
|
|
328
|
+
# if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after smem read, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
|
|
329
|
+
# Need this fence since the STAS from the producer is using the async proxy.
|
|
330
|
+
# Without this, we get race condition / deadlock.
|
|
331
|
+
if const_expr(cute.size(params.cluster_shape_mn) > 1):
|
|
332
|
+
cute.arch.fence_proxy(
|
|
333
|
+
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
334
|
+
)
|
|
337
335
|
cute.arch.sync_warp()
|
|
338
336
|
with cute.arch.elect_one():
|
|
339
|
-
# if tidx %
|
|
337
|
+
# if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before empty arrive", bidx, bidz, tidx)
|
|
340
338
|
self._scheduler_pipeline.consumer_release(self._pipeline_state)
|
|
341
|
-
# if tidx
|
|
339
|
+
# if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
|
|
340
|
+
# if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
|
|
342
341
|
self._current_work_linear_idx = current_work_linear_idx
|
|
343
342
|
self._pipeline_state.advance()
|
|
344
|
-
self._num_tiles_executed += Int32(advance_count)
|
|
343
|
+
self.num_tiles_executed += Int32(advance_count)
|
|
345
344
|
|
|
346
345
|
def producer_tail(self):
|
|
347
346
|
if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None):
|
|
348
347
|
self._scheduler_pipeline.producer_tail(self._pipeline_state)
|
|
349
348
|
|
|
350
|
-
@property
|
|
351
|
-
def num_tiles_executed(self) -> Int32:
|
|
352
|
-
return self._num_tiles_executed
|
|
353
|
-
|
|
354
349
|
def __extract_mlir_values__(self):
|
|
355
350
|
values, self._values_pos = [], []
|
|
356
351
|
for obj in [
|
|
357
352
|
self._current_work_linear_idx,
|
|
358
|
-
self._num_tiles_executed,
|
|
353
|
+
self.num_tiles_executed,
|
|
359
354
|
self._tile_count,
|
|
360
355
|
self._scheduler_pipeline,
|
|
361
356
|
self._pipeline_state,
|
|
@@ -371,7 +366,7 @@ class TileScheduler:
|
|
|
371
366
|
for obj, n_items in zip(
|
|
372
367
|
[
|
|
373
368
|
self._current_work_linear_idx,
|
|
374
|
-
self._num_tiles_executed,
|
|
369
|
+
self.num_tiles_executed,
|
|
375
370
|
self._tile_count,
|
|
376
371
|
self._scheduler_pipeline,
|
|
377
372
|
self._pipeline_state,
|
|
@@ -390,7 +385,7 @@ def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
|
|
|
390
385
|
Convert a triangular index to 2D coordinates.
|
|
391
386
|
This is used to convert the linear index to 2D coordinates for triangular matrices.
|
|
392
387
|
"""
|
|
393
|
-
row = utils.ceil((
|
|
388
|
+
row = utils.ceil((utils.sqrt(2 * idx + 2.25) - 0.5)) - 1
|
|
394
389
|
col = idx - (row * (row + 1)) // 2
|
|
395
390
|
return row, col
|
|
396
391
|
|
|
@@ -402,7 +397,7 @@ class TriangularTileScheduler(TileScheduler):
|
|
|
402
397
|
class Params(ParamsBase):
|
|
403
398
|
problem_shape_ncluster_mnl: cute.Shape
|
|
404
399
|
num_clusters_per_problem_divmod: FastDivmod
|
|
405
|
-
group_size_inv_f32:
|
|
400
|
+
group_size_inv_f32: Float32
|
|
406
401
|
num_groups_regular: Int32
|
|
407
402
|
group_size_divmod: FastDivmod
|
|
408
403
|
group_size_tail_divmod: FastDivmod
|
|
@@ -433,7 +428,7 @@ class TriangularTileScheduler(TileScheduler):
|
|
|
433
428
|
return TriangularTileScheduler.Params(
|
|
434
429
|
problem_shape_ncluster_mnl,
|
|
435
430
|
FastDivmod.create(num_clusters_per_problem),
|
|
436
|
-
|
|
431
|
+
Float32(1.0 / group_size),
|
|
437
432
|
num_groups_regular,
|
|
438
433
|
FastDivmod.create(group_size),
|
|
439
434
|
# Don't divide by 0
|
|
@@ -524,8 +519,7 @@ class TriangularTileScheduler(TileScheduler):
|
|
|
524
519
|
group_size = params.group_size_divmod.divisor
|
|
525
520
|
group_id = (
|
|
526
521
|
utils.ceil(
|
|
527
|
-
(
|
|
528
|
-
* params.group_size_inv_f32
|
|
522
|
+
(utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32
|
|
529
523
|
)
|
|
530
524
|
- 1
|
|
531
525
|
)
|
|
@@ -562,7 +556,7 @@ class TriangularTileScheduler(TileScheduler):
|
|
|
562
556
|
pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
|
|
563
557
|
tile_coord_mnkl = (pid_m, pid_n, None, bidz)
|
|
564
558
|
if const_expr(not params.is_persistent):
|
|
565
|
-
is_valid = self._num_tiles_executed == 0
|
|
559
|
+
is_valid = self.num_tiles_executed == 0
|
|
566
560
|
else:
|
|
567
561
|
is_valid = (
|
|
568
562
|
self._current_work_linear_idx
|
|
@@ -681,7 +675,7 @@ class VarlenMTileScheduler(TileScheduler):
|
|
|
681
675
|
ip=None,
|
|
682
676
|
):
|
|
683
677
|
self._current_work_linear_idx = current_work_linear_idx
|
|
684
|
-
self._num_tiles_executed = num_tiles_executed
|
|
678
|
+
self.num_tiles_executed = num_tiles_executed
|
|
685
679
|
self._current_batch_idx = current_batch_idx
|
|
686
680
|
self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
|
|
687
681
|
self._tile_count = tile_count
|
|
@@ -878,25 +872,25 @@ class VarlenMTileScheduler(TileScheduler):
|
|
|
878
872
|
pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
|
|
879
873
|
tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
|
|
880
874
|
if const_expr(not params.is_persistent):
|
|
881
|
-
is_valid = self.
|
|
875
|
+
is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch
|
|
882
876
|
else:
|
|
883
877
|
is_valid = batch_idx < num_batch
|
|
884
878
|
return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
|
|
885
879
|
|
|
886
880
|
@cute.jit
|
|
887
|
-
def fetch_next_work(self, is_scheduler_warp: bool | Boolean, *, loc=None, ip=None):
|
|
881
|
+
def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
|
|
888
882
|
"""is_scheduler_warp should only be true for one warp in the whole cluster"""
|
|
889
883
|
if const_expr(self.params.tile_count_semaphore is not None):
|
|
890
884
|
params = self.params
|
|
891
885
|
current_work_linear_idx = self._current_work_linear_idx
|
|
892
886
|
if is_scheduler_warp:
|
|
893
887
|
if cute.arch.lane_idx() == 0:
|
|
894
|
-
# cute.printf("before atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
|
|
888
|
+
# cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
|
|
895
889
|
num_persistent_clusters = cute.arch.grid_dim()[2]
|
|
896
890
|
current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
|
|
897
891
|
1, params.tile_count_semaphore
|
|
898
892
|
)
|
|
899
|
-
# cute.printf("after atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
|
|
893
|
+
# cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
|
|
900
894
|
# lane 0 already has the right tile_idx, just need to broadcast
|
|
901
895
|
current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
|
|
902
896
|
self._current_work_linear_idx = current_work_linear_idx
|
|
@@ -905,7 +899,7 @@ class VarlenMTileScheduler(TileScheduler):
|
|
|
905
899
|
values, self._values_pos = [], []
|
|
906
900
|
for obj in [
|
|
907
901
|
self._current_work_linear_idx,
|
|
908
|
-
self._num_tiles_executed,
|
|
902
|
+
self.num_tiles_executed,
|
|
909
903
|
self._current_batch_idx,
|
|
910
904
|
self._num_work_idx_before_cur_batch,
|
|
911
905
|
self._tile_count,
|
|
@@ -923,7 +917,7 @@ class VarlenMTileScheduler(TileScheduler):
|
|
|
923
917
|
for obj, n_items in zip(
|
|
924
918
|
[
|
|
925
919
|
self._current_work_linear_idx,
|
|
926
|
-
self._num_tiles_executed,
|
|
920
|
+
self.num_tiles_executed,
|
|
927
921
|
self._current_batch_idx,
|
|
928
922
|
self._num_work_idx_before_cur_batch,
|
|
929
923
|
self._tile_count,
|