quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff reflects the publicly released contents of the two package versions as they appear in their public registry and is provided for informational purposes only.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/tile_scheduler.py CHANGED
@@ -6,7 +6,7 @@ from enum import IntEnum
 
 import cutlass
 import cutlass.cute as cute
-from cutlass import Int32, Boolean, const_expr
+from cutlass import Int32, Float32, Boolean, const_expr
 
 import quack.utils as utils
 from quack.fast_math import FastDivmod
@@ -135,7 +135,7 @@ class TileScheduler:
         ip=None,
     ):
         self._current_work_linear_idx = current_work_linear_idx
-        self._num_tiles_executed = num_tiles_executed
+        self.num_tiles_executed = num_tiles_executed
         self._tile_count = tile_count
         self._scheduler_pipeline = scheduler_pipeline
         self._pipeline_state = pipeline_state
@@ -251,7 +251,7 @@ class TileScheduler:
         )
         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
         if const_expr(not params.is_persistent):
-            is_valid = self._num_tiles_executed == 0
+            is_valid = self.num_tiles_executed == 0
         else:
             is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl)
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
@@ -276,38 +276,6 @@ class TileScheduler:
                 current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
             self._current_work_linear_idx = current_work_linear_idx
 
-    # We have to split broadcast_next_work and advance_to_next_work into two functions
-    # due to a bug in cute-dsl 4.2: https://github.com/NVIDIA/cutlass/issues/2647
-    @cute.jit
-    def broadcast_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
-        """is_scheduler_warp should only be true for one warp in the whole cluster"""
-        params = self.params
-        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
-            current_work_linear_idx = self._current_work_linear_idx
-            if is_scheduler_warp:
-                self._scheduler_pipeline.producer_acquire(self._pipeline_state)
-                lane_idx = cute.arch.lane_idx()
-                if lane_idx < cute.size(params.cluster_shape_mn):
-                    # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
-                    if const_expr(cute.size(params.cluster_shape_mn) == 1):
-                        self._tile_count[self._pipeline_state.index] = current_work_linear_idx
-                        self._scheduler_pipeline.producer_commit(self._pipeline_state)
-                    else:
-                        peer_cta_rank_in_cluster = lane_idx
-                        mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
-                            self._pipeline_state
-                        )
-                        cute.arch.mbarrier_arrive_and_expect_tx(
-                            mbar_ptr, 4, peer_cta_rank_in_cluster
-                        )
-                        utils.store_shared_remote(
-                            val=current_work_linear_idx,
-                            smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
-                            mbar_ptr=mbar_ptr,
-                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
-                        )
-                    # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
-
     @cute.jit
     def advance_to_next_work(
         self,
@@ -319,6 +287,7 @@ class TileScheduler:
     ):
         tidx = cute.arch.thread_idx()[0]
         bidx = cute.arch.block_idx()[0]
+        bidz = cute.arch.block_idx()[2]
        params = self.params
        if const_expr(params.is_persistent):
            num_persistent_clusters = cute.arch.grid_dim()[2]
@@ -328,34 +297,60 @@ class TileScheduler:
             if const_expr(advance_count > 1):
                 self._pipeline_state.advance_iters(advance_count - 1)
             current_work_linear_idx = self._current_work_linear_idx
-            if not is_scheduler_warp:
-                # if tidx % 64 == 0: cute.printf("bidx = {},tidx = {}, before full wait, idx = {}", bidx, tidx, current_work_linear_idx)
+            if is_scheduler_warp:
+                self._scheduler_pipeline.producer_acquire(self._pipeline_state)
+                lane_idx = cute.arch.lane_idx()
+                if lane_idx < cute.size(params.cluster_shape_mn):
+                    # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after empty wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
+                    if const_expr(cute.size(params.cluster_shape_mn) == 1):
+                        self._tile_count[self._pipeline_state.index] = current_work_linear_idx
+                        self._scheduler_pipeline.producer_commit(self._pipeline_state)
+                    else:
+                        peer_cta_rank_in_cluster = lane_idx
+                        mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
+                            self._pipeline_state
+                        )
+                        cute.arch.mbarrier_arrive_and_expect_tx(
+                            mbar_ptr, 4, peer_cta_rank_in_cluster
+                        )
+                        utils.store_shared_remote(
+                            val=current_work_linear_idx,
+                            smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
+                            mbar_ptr=mbar_ptr,
+                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
+                        )
+                    # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after full arrive", bidx, bidz, tidx)
+            else:
+                # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
                 self._scheduler_pipeline.consumer_wait(self._pipeline_state)
-                # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after full wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
                 current_work_linear_idx = self._tile_count[self._pipeline_state.index]
-                # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after smem read, idx = {}", bidx, tidx, current_work_linear_idx)
+                # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after smem read, idx = {}", bidx, bidz, tidx, current_work_linear_idx)
+                # Need this fence since the STAS from the producer is using the async proxy.
+                # Without this, we get race condition / deadlock.
+                if const_expr(cute.size(params.cluster_shape_mn) > 1):
+                    cute.arch.fence_proxy(
+                        cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                    )
                 cute.arch.sync_warp()
                 with cute.arch.elect_one():
-                    # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, before empty arrive", bidx, tidx)
+                    # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before empty arrive", bidx, bidz, tidx)
                     self._scheduler_pipeline.consumer_release(self._pipeline_state)
-                    # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after empty arrive", bidx, tidx)
+                # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
+                # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx)
             self._current_work_linear_idx = current_work_linear_idx
             self._pipeline_state.advance()
-        self._num_tiles_executed += Int32(advance_count)
+        self.num_tiles_executed += Int32(advance_count)
 
     def producer_tail(self):
         if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None):
             self._scheduler_pipeline.producer_tail(self._pipeline_state)
 
-    @property
-    def num_tiles_executed(self) -> Int32:
-        return self._num_tiles_executed
-
     def __extract_mlir_values__(self):
         values, self._values_pos = [], []
         for obj in [
             self._current_work_linear_idx,
-            self._num_tiles_executed,
+            self.num_tiles_executed,
             self._tile_count,
             self._scheduler_pipeline,
             self._pipeline_state,
@@ -371,7 +366,7 @@ class TileScheduler:
         for obj, n_items in zip(
             [
                 self._current_work_linear_idx,
-                self._num_tiles_executed,
+                self.num_tiles_executed,
                 self._tile_count,
                 self._scheduler_pipeline,
                 self._pipeline_state,
@@ -390,7 +385,7 @@ def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
     Convert a triangular index to 2D coordinates.
    This is used to convert the linear index to 2D coordinates for triangular matrices.
    """
-    row = utils.ceil((cute.math.sqrt(2 * idx + 2.25, fastmath=True) - 0.5)) - 1
+    row = utils.ceil((utils.sqrt(2 * idx + 2.25) - 0.5)) - 1
    col = idx - (row * (row + 1)) // 2
    return row, col
 
@@ -402,7 +397,7 @@ class TriangularTileScheduler(TileScheduler):
     class Params(ParamsBase):
         problem_shape_ncluster_mnl: cute.Shape
         num_clusters_per_problem_divmod: FastDivmod
-        group_size_inv_f32: cutlass.Float32
+        group_size_inv_f32: Float32
         num_groups_regular: Int32
         group_size_divmod: FastDivmod
         group_size_tail_divmod: FastDivmod
@@ -433,7 +428,7 @@ class TriangularTileScheduler(TileScheduler):
         return TriangularTileScheduler.Params(
             problem_shape_ncluster_mnl,
             FastDivmod.create(num_clusters_per_problem),
-            cutlass.Float32(1.0 / group_size),
+            Float32(1.0 / group_size),
             num_groups_regular,
             FastDivmod.create(group_size),
             # Don't divide by 0
@@ -524,8 +519,7 @@ class TriangularTileScheduler(TileScheduler):
         group_size = params.group_size_divmod.divisor
         group_id = (
             utils.ceil(
-                (cute.math.sqrt(2 * cluster_id_in_problem + 2.25, fastmath=True) - 0.5)
-                * params.group_size_inv_f32
+                (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32
             )
             - 1
         )
@@ -562,7 +556,7 @@ class TriangularTileScheduler(TileScheduler):
         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
         tile_coord_mnkl = (pid_m, pid_n, None, bidz)
         if const_expr(not params.is_persistent):
-            is_valid = self._num_tiles_executed == 0
+            is_valid = self.num_tiles_executed == 0
         else:
             is_valid = (
                 self._current_work_linear_idx
@@ -681,7 +675,7 @@ class VarlenMTileScheduler(TileScheduler):
         ip=None,
     ):
         self._current_work_linear_idx = current_work_linear_idx
-        self._num_tiles_executed = num_tiles_executed
+        self.num_tiles_executed = num_tiles_executed
         self._current_batch_idx = current_batch_idx
         self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
         self._tile_count = tile_count
@@ -878,25 +872,25 @@ class VarlenMTileScheduler(TileScheduler):
         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
         tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
         if const_expr(not params.is_persistent):
-            is_valid = self._num_tiles_executed == 0 and batch_idx < num_batch
+            is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch
         else:
             is_valid = batch_idx < num_batch
         return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
 
     @cute.jit
-    def fetch_next_work(self, is_scheduler_warp: bool | Boolean, *, loc=None, ip=None):
+    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
         """is_scheduler_warp should only be true for one warp in the whole cluster"""
         if const_expr(self.params.tile_count_semaphore is not None):
             params = self.params
             current_work_linear_idx = self._current_work_linear_idx
             if is_scheduler_warp:
                 if cute.arch.lane_idx() == 0:
-                    # cute.printf("before atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
+                    # cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
                     num_persistent_clusters = cute.arch.grid_dim()[2]
                     current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32(
                         1, params.tile_count_semaphore
                     )
-                    # cute.printf("after atomicadd, tidx = {}, idx = {}", cute.arch.thread_idx()[0], current_work_linear_idx)
+                    # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx)
                 # lane 0 already has the right tile_idx, just need to broadcast
                 current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
             self._current_work_linear_idx = current_work_linear_idx
@@ -905,7 +899,7 @@ class VarlenMTileScheduler(TileScheduler):
         values, self._values_pos = [], []
         for obj in [
             self._current_work_linear_idx,
-            self._num_tiles_executed,
+            self.num_tiles_executed,
             self._current_batch_idx,
             self._num_work_idx_before_cur_batch,
             self._tile_count,
@@ -923,7 +917,7 @@ class VarlenMTileScheduler(TileScheduler):
         for obj, n_items in zip(
             [
                 self._current_work_linear_idx,
-                self._num_tiles_executed,
+                self.num_tiles_executed,
                 self._current_batch_idx,
                 self._num_work_idx_before_cur_batch,
                 self._tile_count,
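The square-root change above does not alter the triangular index math used by `triangular_idx_to_coord` and `TriangularTileScheduler`: row `r` of a lower-triangular tile grid holds `r + 1` tiles, so a linear index `idx` falls in the largest row with `row * (row + 1) / 2 <= idx`, which is exactly what `ceil(sqrt(2 * idx + 2.25) - 0.5) - 1` evaluates to. A minimal plain-Python sketch of that mapping, for illustration only (it uses `math.sqrt`/`math.ceil` in place of the package's `utils.sqrt`/`utils.ceil` and is not code from the wheel):

```python
import math

def triangular_idx_to_coord(idx: int) -> tuple[int, int]:
    # Row r starts at linear index r*(r+1)//2, so pick the largest r with r*(r+1)/2 <= idx.
    row = math.ceil(math.sqrt(2 * idx + 2.25) - 0.5) - 1
    col = idx - (row * (row + 1)) // 2
    return row, col

# Round-trip check: enumerating the lower triangle row by row recovers (row, col) in order.
idx = 0
for row in range(8):
    for col in range(row + 1):
        assert triangular_idx_to_coord(idx) == (row, col)
        idx += 1
```

Only the helper used to evaluate the square root changes in this release (`utils.sqrt` in place of `cute.math.sqrt(..., fastmath=True)`); the row/column arithmetic itself is untouched.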