quack-kernels 0.2.2-py3-none-any.whl → 0.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.4.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/tile_scheduler.py CHANGED
@@ -6,7 +6,7 @@ from enum import IntEnum
 
 import cutlass
 import cutlass.cute as cute
-from cutlass import Int32, Boolean, const_expr
+from cutlass import Int32, Float32, Boolean, const_expr
 
 import quack.utils as utils
 from quack.fast_math import FastDivmod
@@ -287,6 +287,7 @@ class TileScheduler:
     ):
         tidx = cute.arch.thread_idx()[0]
         bidx = cute.arch.block_idx()[0]
+        bidz = cute.arch.block_idx()[2]
         params = self.params
         if const_expr(params.is_persistent):
             num_persistent_clusters = cute.arch.grid_dim()[2]
@@ -300,7 +301,7 @@ class TileScheduler:
                 self._scheduler_pipeline.producer_acquire(self._pipeline_state)
                 lane_idx = cute.arch.lane_idx()
                 if lane_idx < cute.size(params.cluster_shape_mn):
-                    # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                    # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after empty wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
                     if const_expr(cute.size(params.cluster_shape_mn) == 1):
                         self._tile_count[self._pipeline_state.index] = current_work_linear_idx
                         self._scheduler_pipeline.producer_commit(self._pipeline_state)
@@ -318,18 +319,25 @@ class TileScheduler:
                             mbar_ptr=mbar_ptr,
                             peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
                         )
-                    # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
+                    # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after full arrive", bidx, bidz, tidx)
             else:
-                # if tidx % 64 == 0: cute.printf("bidx = {},tidx = {}, before full wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
                 self._scheduler_pipeline.consumer_wait(self._pipeline_state)
-                # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after full wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
                 current_work_linear_idx = self._tile_count[self._pipeline_state.index]
-                # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after smem read, idx = {}", bidx, tidx, current_work_linear_idx)
+                # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after smem read, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
+                # Need this fence since the STAS from the producer is using the async proxy.
+                # Without this, we get race condition / deadlock.
+                if const_expr(cute.size(params.cluster_shape_mn) > 1):
+                    cute.arch.fence_proxy(
+                        cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                    )
                 cute.arch.sync_warp()
                 with cute.arch.elect_one():
-                    # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, before empty arrive", bidx, tidx)
+                    # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before empty arrive", bidx, bidz, tidx)
                     self._scheduler_pipeline.consumer_release(self._pipeline_state)
-                    # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after empty arrive", bidx, tidx)
+                    # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
+                    # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
             self._current_work_linear_idx = current_work_linear_idx
             self._pipeline_state.advance()
             self.num_tiles_executed += Int32(advance_count)
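
The substantive change in this hunk is the new fence_proxy call on the consumer path. Per the in-code comment, the producer publishes the next tile index into shared memory with STAS (an async-proxy store), so in clusters larger than one CTA the consumer must fence the async proxy before releasing the pipeline slot; without it the exchange can race or deadlock. A minimal sketch of the same ordering, using only the cute.arch calls that appear above (pipeline, smem_slot, state, and cluster_shape_mn are hypothetical stand-ins for the scheduler's members):

# Consumer side of a tile-index handoff published via the async proxy.
pipeline.consumer_wait(state)                 # wait until the slot is marked full
value = smem_slot[state.index]                # read the index the producer stored
if const_expr(cute.size(cluster_shape_mn) > 1):
    # Producer wrote the slot with STAS (async proxy); fence before the
    # slot is recycled, mirroring the diff above.
    cute.arch.fence_proxy(
        cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
    )
cute.arch.sync_warp()                         # keep the warp converged
with cute.arch.elect_one():
    pipeline.consumer_release(state)          # mark the slot empty for the producer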
@@ -377,7 +385,7 @@ def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
     Convert a triangular index to 2D coordinates.
     This is used to convert the linear index to 2D coordinates for triangular matrices.
     """
-    row = utils.ceil((cute.math.sqrt(2 * idx + 2.25, fastmath=True) - 0.5)) - 1
+    row = utils.ceil((utils.sqrt(2 * idx + 2.25) - 0.5)) - 1
     col = idx - (row * (row + 1)) // 2
     return row, col
 
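This hunk swaps cute.math.sqrt(..., fastmath=True) for a utils.sqrt helper; the closed form itself is unchanged. It inverts the row-major enumeration of a lower triangle, (0,0), (1,0), (1,1), (2,0), ..., where idx = row*(row+1)/2 + col with 0 <= col <= row. A standalone pure-Python check of the formula (not part of the package):

import math

def triangular_idx_to_coord(idx: int) -> tuple[int, int]:
    # Same closed form as the kernel code: invert idx = row*(row+1)//2 + col.
    row = math.ceil(math.sqrt(2 * idx + 2.25) - 0.5) - 1
    col = idx - (row * (row + 1)) // 2
    return row, col

# Verify against direct enumeration of the lower triangle.
idx = 0
for row in range(64):
    for col in range(row + 1):
        assert triangular_idx_to_coord(idx) == (row, col)
        idx += 1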
@@ -389,7 +397,7 @@ class TriangularTileScheduler(TileScheduler):
     class Params(ParamsBase):
         problem_shape_ncluster_mnl: cute.Shape
         num_clusters_per_problem_divmod: FastDivmod
-        group_size_inv_f32: cutlass.Float32
+        group_size_inv_f32: Float32
         num_groups_regular: Int32
         group_size_divmod: FastDivmod
         group_size_tail_divmod: FastDivmod
@@ -420,7 +428,7 @@ class TriangularTileScheduler(TileScheduler):
         return TriangularTileScheduler.Params(
             problem_shape_ncluster_mnl,
             FastDivmod.create(num_clusters_per_problem),
-            cutlass.Float32(1.0 / group_size),
+            Float32(1.0 / group_size),
             num_groups_regular,
             FastDivmod.create(group_size),
             # Don't divide by 0
@@ -511,8 +519,7 @@ class TriangularTileScheduler(TileScheduler):
         group_size = params.group_size_divmod.divisor
         group_id = (
             utils.ceil(
-                (cute.math.sqrt(2 * cluster_id_in_problem + 2.25, fastmath=True) - 0.5)
-                * params.group_size_inv_f32
+                (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32
             )
             - 1
         )
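This is the same inverse-triangular expression, applied per group of rows: the divide by group_size is folded into the precomputed reciprocal group_size_inv_f32 (built as Float32(1.0 / group_size) in the previous hunk), so the per-tile path multiplies instead of divides. A direct, hypothetical pure-Python transcription of the expression:

import math

def group_id_of(cluster_id_in_problem: int, group_size: int) -> int:
    group_size_inv = 1.0 / group_size  # precomputed once, like group_size_inv_f32
    return math.ceil((math.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * group_size_inv) - 1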
@@ -871,19 +878,19 @@ class VarlenMTileScheduler(TileScheduler):
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
 
     @cute.jit
-    def fetch_next_work(self, is_scheduler_warp: bool | Boolean, *, loc=None, ip=None):
+    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
         """is_scheduler_warp should only be true for one warp in the whole cluster"""
         if const_expr(self.params.tile_count_semaphore is not None):
             params = self.params
             current_work_linear_idx = self._current_work_linear_idx
             if is_scheduler_warp:
                 if cute.arch.lane_idx() == 0:
-                    # cute.printf("before atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
+                    # cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
                     num_persistent_clusters = cute.arch.grid_dim()[2]
                     current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
                         1, params.tile_count_semaphore
                     )
-                    # cute.printf("after atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
+                    # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
                 # lane 0 already has the right tile_idx, just need to broadcast
                 current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
                 self._current_work_linear_idx = current_work_linear_idx
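
fetch_next_work is the dynamic half of the persistent scheduler: the first num_persistent_clusters tile indices are assigned statically, one per cluster, and every later index comes from a single global counter that lane 0 bumps with atomic_add_i32 and then broadcasts across the warp with shuffle_sync. A hypothetical host-side analogue of the counter logic, using only the Python standard library:

import itertools
import threading

class TileCounter:
    """Stand-in for the device-side tile_count_semaphore."""

    def __init__(self) -> None:
        self._next = itertools.count()
        self._lock = threading.Lock()

    def atomic_add(self) -> int:
        # Plays the role of atomic_add_i32(1, tile_count_semaphore).
        with self._lock:
            return next(self._next)

def fetch_next_work(counter: TileCounter, num_persistent_clusters: int) -> int:
    # Indices [0, num_persistent_clusters) are the statically assigned first
    # wave, so dynamically fetched work starts right after them.
    return num_persistent_clusters + counter.atomic_add()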