quack-kernels 0.2.2-py3-none-any.whl → 0.2.4-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
- quack_kernels-0.2.4.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/tile_scheduler.py CHANGED

@@ -6,7 +6,7 @@ from enum import IntEnum
 
 import cutlass
 import cutlass.cute as cute
-from cutlass import Int32, Boolean, const_expr
+from cutlass import Int32, Float32, Boolean, const_expr
 
 import quack.utils as utils
 from quack.fast_math import FastDivmod
@@ -287,6 +287,7 @@ class TileScheduler:
     ):
         tidx = cute.arch.thread_idx()[0]
         bidx = cute.arch.block_idx()[0]
+        bidz = cute.arch.block_idx()[2]
         params = self.params
         if const_expr(params.is_persistent):
             num_persistent_clusters = cute.arch.grid_dim()[2]
@@ -300,7 +301,7 @@
                 self._scheduler_pipeline.producer_acquire(self._pipeline_state)
                 lane_idx = cute.arch.lane_idx()
                 if lane_idx < cute.size(params.cluster_shape_mn):
-                    # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                    # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after empty wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
                     if const_expr(cute.size(params.cluster_shape_mn) == 1):
                         self._tile_count[self._pipeline_state.index] = current_work_linear_idx
                         self._scheduler_pipeline.producer_commit(self._pipeline_state)
@@ -318,18 +319,25 @@
                             mbar_ptr=mbar_ptr,
                             peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
                         )
-                    # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
+                    # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after full arrive", bidx, bidz, tidx)
             else:
-                # if tidx %
+                # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
                 self._scheduler_pipeline.consumer_wait(self._pipeline_state)
-                # if tidx %
+                # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
                 current_work_linear_idx = self._tile_count[self._pipeline_state.index]
-                # if tidx %
+                # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after smem read, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
+                # Need this fence since the STAS from the producer is using the async proxy.
+                # Without this, we get race condition / deadlock.
+                if const_expr(cute.size(params.cluster_shape_mn) > 1):
+                    cute.arch.fence_proxy(
+                        cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                    )
                 cute.arch.sync_warp()
                 with cute.arch.elect_one():
-                    # if tidx %
+                    # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before empty arrive", bidx, bidz, tidx)
                     self._scheduler_pipeline.consumer_release(self._pipeline_state)
-                    # if tidx
+                    # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
+                    # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
             self._current_work_linear_idx = current_work_linear_idx
             self._pipeline_state.advance()
             self.num_tiles_executed += Int32(advance_count)
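The substantive change in this hunk is the new proxy fence on the consumer path: the producer warp publishes the next tile index into each CTA's shared memory with STAS (`st.async`), which executes on the async proxy, while the consumer's plain shared-memory read goes through the generic proxy, so without `fence_proxy(async_shared)` the two are unordered; per the in-code comment, omitting it caused races and deadlocks. The full/empty handshake around that read is easier to see off-device. Below is a minimal CPU analogy of one scheduler pipeline stage using `threading` (an illustration of the protocol only, not the kernel's mechanism; `BroadcastSlot` is a made-up name): `produce` waits for the slot to be empty before publishing, mirroring producer_acquire/commit, and each consumer's release is what eventually marks the slot empty again, mirroring consumer_wait/release.

```python
import threading

class BroadcastSlot:
    """One pipeline stage: the producer publishes a value once the slot is
    empty; every consumer observes it exactly once; the last release marks
    the slot empty again (the "empty arrive" in the diff above)."""

    def __init__(self, num_consumers: int):
        self.cv = threading.Condition()
        self.num_consumers = num_consumers
        self.value = None
        self.generation = 0  # how many values have been published so far
        self.remaining = 0   # consumers yet to release the current value

    def produce(self, value) -> None:
        # ~ producer_acquire (wait empty) + producer_commit (arrive full)
        with self.cv:
            self.cv.wait_for(lambda: self.remaining == 0)
            self.value = value
            self.remaining = self.num_consumers
            self.generation += 1
            self.cv.notify_all()

    def consume(self, seen: int):
        # ~ consumer_wait (wait full) + consumer_release (arrive empty);
        # `seen` plays the role of the consumer's own _pipeline_state.
        with self.cv:
            self.cv.wait_for(lambda: self.generation > seen)
            value = self.value
            self.remaining -= 1
            self.cv.notify_all()
            return value

slot = BroadcastSlot(num_consumers=3)
results = [[] for _ in range(3)]

def worker(i: int) -> None:
    for seen in range(5):
        results[i].append(slot.consume(seen))

threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for idx in range(5):
    slot.produce(idx)  # the scheduler warp plays the producer role
for t in threads:
    t.join()
assert all(r == [0, 1, 2, 3, 4] for r in results)
```

On the GPU the same roles are played by mbarriers; the proxy fence has no analogue here because a single CPU memory system stands in for both proxies.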
@@ -377,7 +385,7 @@ def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
     Convert a triangular index to 2D coordinates.
     This is used to convert the linear index to 2D coordinates for triangular matrices.
     """
-    row = utils.ceil((
+    row = utils.ceil((utils.sqrt(2 * idx + 2.25) - 0.5)) - 1
     col = idx - (row * (row + 1)) // 2
     return row, col
 
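The completed line is the closed-form inverse of the lower-triangle enumeration `idx = row * (row + 1) // 2 + col`: `row` is the largest `r` with `r * (r + 1) / 2 <= idx`, and solving that quadratic gives `row = ceil(sqrt(2 * idx + 2.25) - 0.5) - 1`, which agrees on integer inputs with the textbook `floor((sqrt(8 * idx + 1) - 1) / 2)`. A pure-Python reference of the same closed form, checked exhaustively against the defining enumeration (`triangular_idx_to_coord_ref` is an illustrative name, not part of the package):

```python
import math

def triangular_idx_to_coord_ref(idx: int) -> tuple[int, int]:
    # row is the largest r with r*(r+1)/2 <= idx; the closed form below
    # mirrors the kernel line: ceil(sqrt(2*idx + 2.25) - 0.5) - 1.
    row = math.ceil(math.sqrt(2 * idx + 2.25) - 0.5) - 1
    col = idx - row * (row + 1) // 2
    return row, col

# Exhaustive check against the defining enumeration of the lower triangle:
# idx runs row-major over pairs (r, c) with c <= r.
idx = 0
for r in range(500):
    for c in range(r + 1):
        assert triangular_idx_to_coord_ref(idx) == (r, c)
        idx += 1
```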
@@ -389,7 +397,7 @@ class TriangularTileScheduler(TileScheduler):
     class Params(ParamsBase):
         problem_shape_ncluster_mnl: cute.Shape
         num_clusters_per_problem_divmod: FastDivmod
-        group_size_inv_f32:
+        group_size_inv_f32: Float32
         num_groups_regular: Int32
         group_size_divmod: FastDivmod
         group_size_tail_divmod: FastDivmod
@@ -420,7 +428,7 @@
         return TriangularTileScheduler.Params(
             problem_shape_ncluster_mnl,
             FastDivmod.create(num_clusters_per_problem),
-
+            Float32(1.0 / group_size),
             num_groups_regular,
             FastDivmod.create(group_size),
             # Don't divide by 0
@@ -511,8 +519,7 @@
         group_size = params.group_size_divmod.divisor
         group_id = (
             utils.ceil(
-                (
-                * params.group_size_inv_f32
+                (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32
             )
             - 1
         )
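This expression groups triangle rows into bands of `group_size` rows. Writing `s = sqrt(2 * cluster_id_in_problem + 2.25) - 0.5` (so `ceil(s) - 1` is the row, as above), it relies on the identity `ceil(s / g) - 1 == (ceil(s) - 1) // g` for real `s > 0`, which lets the per-tile integer division be replaced by one multiply with the `Float32(1.0 / group_size)` reciprocal precomputed in `Params`. A quick double-precision comparison of the two forms (function names are illustrative; the device code runs in `Float32`, so it has less rounding headroom than this check):

```python
import math

def group_id_exact(cluster_idx: int, group_size: int) -> int:
    # Row of the triangular enumeration, then an integer divide per tile.
    row = math.ceil(math.sqrt(2 * cluster_idx + 2.25) - 0.5) - 1
    return row // group_size

def group_id_reciprocal(cluster_idx: int, group_size: int) -> int:
    # The kernel's form: fold a precomputed 1/group_size into the ceil,
    # using ceil(s / g) - 1 == (ceil(s) - 1) // g for real s > 0.
    s = math.sqrt(2 * cluster_idx + 2.25) - 0.5
    return math.ceil(s * (1.0 / group_size)) - 1

mismatches = [
    (c, g)
    for g in (1, 2, 3, 5, 8)
    for c in range(10_000)
    if group_id_reciprocal(c, g) != group_id_exact(c, g)
]
print(mismatches)  # [] at these sizes; float rounding is the only caveat
```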
@@ -871,19 +878,19 @@ class VarlenMTileScheduler(TileScheduler):
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
 
     @cute.jit
-    def fetch_next_work(self, is_scheduler_warp: bool | Boolean, *, loc=None, ip=None):
+    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
         """is_scheduler_warp should only be true for one warp in the whole cluster"""
         if const_expr(self.params.tile_count_semaphore is not None):
             params = self.params
             current_work_linear_idx = self._current_work_linear_idx
             if is_scheduler_warp:
                 if cute.arch.lane_idx() == 0:
-                    # cute.printf("before atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
+                    # cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
                     num_persistent_clusters = cute.arch.grid_dim()[2]
                     current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
                         1, params.tile_count_semaphore
                     )
-                    # cute.printf("after atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
+                    # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
                 # lane 0 already has the right tile_idx, just need to broadcast
                 current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
                 self._current_work_linear_idx = current_work_linear_idx
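These changes round out the dynamic work-stealing path of the persistent scheduler: each cluster's first tile is its own linear index (its `blockIdx.z`), and every later fetch claims `num_persistent_clusters` plus the old value of the atomically incremented global counter (`tile_count_semaphore`); lane 0 performs the atomic and `shuffle_sync` broadcasts the result to the warp. The counter therefore enumerates exactly the tiles beyond the static assignment. A single-threaded toy model of that index arithmetic (an assumption drawn from this hunk, not kernel code; `claimed_indices` is a made-up name):

```python
def claimed_indices(num_persistent_clusters: int, num_fetches: int) -> list[int]:
    tile_count_semaphore = 0  # stands in for the device-side global counter
    claimed = list(range(num_persistent_clusters))  # initial tile per cluster (bidz)
    for _ in range(num_fetches):
        # atomic_add_i32(1, semaphore) returns the pre-increment value
        claimed.append(num_persistent_clusters + tile_count_semaphore)
        tile_count_semaphore += 1
    return claimed

# 4 persistent clusters plus 6 dynamic fetches cover tiles 0..9 exactly once,
# with no coordination beyond the single atomic counter.
assert claimed_indices(4, 6) == list(range(10))
```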