quack-kernels 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/softmax.py CHANGED
@@ -9,7 +9,9 @@ import cutlass.cute as cute
 from cutlass.cute.runtime import from_dlpack

 import quack.utils as utils
-from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+from quack.reduce import row_reduce, online_softmax_reduce
+from quack.reduction_base import ReductionBase
+from quack.cute_dsl_utils import torch2cute_dtype_map


 class Softmax(ReductionBase):
@@ -147,7 +149,7 @@ class Softmax(ReductionBase):
         x = tXrX.load().to(cute.Float32)
         threads_per_row = tv_layout.shape[0][0]
         if cutlass.const_expr(not self.online_softmax):
-            max_x = utils.row_reduce(
+            max_x = row_reduce(
                 x,
                 cute.ReductionOp.MAX,
                 threads_per_row,
@@ -158,7 +160,7 @@ class Softmax(ReductionBase):
             )
             log2_e = math.log2(math.e)
             exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-            denom = utils.row_reduce(
+            denom = row_reduce(
                 exp_x,
                 cute.ReductionOp.ADD,
                 threads_per_row,
@@ -167,7 +169,7 @@ class Softmax(ReductionBase):
                 init_val=0.0,
             )
         else:
-            max_x, denom, exp_x = utils.online_softmax_reduce(
+            max_x, denom, exp_x = online_softmax_reduce(
                 x,
                 threads_per_row,
                 reduction_buffer[None, None, 0],
@@ -186,7 +188,8 @@ class Softmax(ReductionBase):
         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)


-def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
+@torch.library.custom_op("quack::_softmax_fwd", mutates_args={"out"})
+def _softmax_fwd(x: torch.Tensor, out: torch.Tensor) -> None:
     """Softmax forward pass.
     Args:
         x: Input tensor of shape (M, N)
@@ -196,8 +199,7 @@ def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
     assert x.dim() == 2, "Input must be 2D"
     assert x.is_cuda, "Tensor must be on CUDA device"
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
-    M, N = x.shape
-    out = torch.empty_like(x)
+    N = x.size(1)
     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda tensor: (
         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -213,12 +215,17 @@ def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
             softmax_op, x_tensor, out_tensor, current_stream
         )
     _softmax_fwd.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
-    return out


 _softmax_fwd.compile_cache = {}


+def softmax_fwd(x: torch.Tensor) -> torch.Tensor:
+    out = torch.empty_like(x)
+    _softmax_fwd(x, out)
+    return out
+
+
 class SoftmaxBackward(ReductionBase):
     def __init__(self, dtype: Type[cutlass.Numeric], N: int):
         # 1 stage for computing dot product
@@ -372,7 +379,7 @@ class SoftmaxBackward(ReductionBase):

         # Compute dot product: dot = Σⱼ dy_j × y_j
         threads_per_row = tv_layout.shape[0][0]
-        dot = utils.row_reduce(
+        dot = row_reduce(
             dy * y,
             cute.ReductionOp.ADD,
             threads_per_row,
@@ -394,7 +401,8 @@ class SoftmaxBackward(ReductionBase):
         cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)


-def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+@torch.library.custom_op("quack::_softmax_backward", mutates_args={"dx"})
+def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor) -> None:
     """Softmax backward pass.
     Args:
         dy: Upstream gradients tensor of shape (M, N)
@@ -409,8 +417,7 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     assert dy.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     assert y.dtype == dy.dtype, "dy and y must have same dtype"

-    M, N = dy.shape
-    dx = torch.empty_like(dy)
+    N = dy.size(1)
     dtype = torch2cute_dtype_map[dy.dtype]
     convert_from_dlpack = lambda tensor: (
         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -427,23 +434,28 @@ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             softmax_backward_op, dy_tensor, y_tensor, dx_tensor, current_stream
         )
     _softmax_backward.compile_cache[compile_key](dy_tensor, y_tensor, dx_tensor, current_stream)
-    return dx


 _softmax_backward.compile_cache = {}


+def softmax_bwd(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    dx = torch.empty_like(dy)
+    _softmax_backward(dy, y, dx)
+    return dx
+
+
 class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
-        y = _softmax_fwd(x)
+        y = softmax_fwd(x)
         ctx.save_for_backward(y)
         return y

     @staticmethod
     def backward(ctx, dy):
         (y,) = ctx.saved_tensors
-        dx = _softmax_backward(dy, y)
+        dx = softmax_bwd(dy, y)
         return dx
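Both softmax entry points now follow the same pattern: the kernel-facing function is registered as a PyTorch custom op that writes into a caller-allocated output (`mutates_args` declares exactly what it mutates), while a thin `softmax_fwd` / `softmax_bwd` wrapper owns the allocation and is what `SoftmaxFunction` calls. A minimal, self-contained sketch of that out-parameter pattern, assuming PyTorch 2.4+ for `torch.library.custom_op`; the `example::` namespace and the plain `torch.softmax` stand-in for the CUTE kernel are hypothetical:

```python
import torch


@torch.library.custom_op("example::softmax_fwd_out", mutates_args={"out"})
def softmax_fwd_out(x: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for the compiled CUTE kernel launch: only writes into `out`.
    out.copy_(torch.softmax(x, dim=-1))


def softmax_fwd(x: torch.Tensor) -> torch.Tensor:
    # The wrapper owns the allocation and returns the result, mirroring the
    # split between quack::_softmax_fwd and softmax_fwd in the diff above.
    out = torch.empty_like(x)
    softmax_fwd_out(x, out)
    return out


x = torch.randn(4, 8)
assert torch.allclose(softmax_fwd(x), torch.softmax(x, dim=-1))
```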
 
@@ -51,7 +51,7 @@ from quack.tile_scheduler import (
     RasterOrderOption,
     TriangularTileScheduler,
 )
-from quack.reduction_base import torch2cute_dtype_map
+from quack.cute_dsl_utils import torch2cute_dtype_map

 # return PipelineStateWAdvance instead of PipelineState
 from quack.pipeline import make_pipeline_state
@@ -907,8 +907,11 @@ class HopperSymmetricGemmKernel:

         acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1]))
         acc = cute.make_fragment(acc_shape, self.acc_dtype)
-        if const_expr(self.fp8_slow_accum):
-            acc_slow = cute.make_fragment(acc_shape, self.acc_dtype)
+        acc_slow = (
+            cute.make_fragment(acc_shape, self.acc_dtype)
+            if const_expr(self.fp8_slow_accum)
+            else None
+        )

         if const_expr(self.pingpong):
             if warp_group_idx == 0:
@@ -99,6 +99,7 @@ class TensorMapManagerSm90(TensorMapManager):
             for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
                 cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
         else:
+            assert len(shapes) == len(orders) == len(tensormap_gmem_ptr)
             for gmem_ptr, shape, order in zip(tensormap_gmem_ptr, shapes, orders):
                 gmem_ptr_i64 = gmem_ptr.toint().ir_value()
                 llvm.inline_asm(
quack/tile_scheduler.py CHANGED
@@ -1,7 +1,7 @@
 # Copyright (c) 2025, Tri Dao.

 from typing import Tuple, Optional
-from dataclasses import dataclass, fields
+from dataclasses import dataclass
 from enum import IntEnum

 import cutlass
@@ -11,30 +11,7 @@ from cutlass import Int32, Boolean, const_expr
 import quack.utils as utils
 from quack.fast_math import FastDivmod
 from quack.pipeline import PipelineStateWAdvance
-
-
-@dataclass
-class ParamsBase:
-    def __extract_mlir_values__(self):
-        all_fields = [getattr(self, field.name) for field in fields(self)]
-        non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)]
-        values, self._values_pos = [], []
-        for obj in non_constexpr_fields:
-            obj_values = cutlass.extract_mlir_values(obj)
-            values += obj_values
-            self._values_pos.append(len(obj_values))
-        return values
-
-    def __new_from_mlir_values__(self, values):
-        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
-        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)}
-        non_constexpr_fields = {
-            n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr)
-        }
-        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
-            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
-            values = values[n_items:]
-        return self.__class__(**non_constexpr_fields, **constexpr_fields)
+from quack.cute_dsl_utils import ArgumentsBase, ParamsBase


 class RasterOrderOption(IntEnum):
@@ -66,13 +43,24 @@ def get_raster_order_from_option(
     return raster_order


+# Grouping arguments together that should be passed to __call__
+@dataclass
+class TileSchedulerOptions(ArgumentsBase):
+    max_active_clusters: Int32
+    raster_order: cutlass.Constexpr[RasterOrderOption] = RasterOrderOption.Heuristic
+    max_swizzle_size: Int32 = Int32(8)
+    tile_count_semaphore: Optional[cute.Pointer] = None
+    batch_idx_permute: Optional[cute.Tensor] = None
+
+
 @dataclass
-class TileSchedulerArguments(ParamsBase):
+class TileSchedulerArguments(ArgumentsBase):
     problem_shape_ntile_mnl: cute.Shape
-    raster_order: RasterOrderOption
+    raster_order: cutlass.Constexpr[RasterOrderOption]
     group_size: Int32
     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
     tile_count_semaphore: Optional[cute.Pointer] = None
+    batch_idx_permute: Optional[cute.Tensor] = None
     is_persistent: cutlass.Constexpr[bool] = False


@@ -87,6 +75,7 @@ class TileScheduler:
         group_size_tail_divmod: FastDivmod
         num_clusters_in_group_divmod: FastDivmod
         tile_count_semaphore: Optional[cute.Pointer]
+        batch_idx_permute: Optional[cute.Tensor]
         cluster_shape_mn: cutlass.Constexpr[cute.Shape]
         is_persistent: cutlass.Constexpr[bool]

@@ -128,6 +117,7 @@ class TileScheduler:
             FastDivmod.create(group_size_tail if group_size_tail > 0 else 1),
             FastDivmod.create(num_clusters_in_group),
             args.tile_count_semaphore if const_expr(args.is_persistent) else None,
+            args.batch_idx_permute,
             cluster_shape_mn,
             args.is_persistent,
         )
@@ -256,7 +246,10 @@ class TileScheduler:
         bidx_in_cluster = cute.arch.block_in_cluster_idx()
         pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0]
         pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
-        tile_coord_mnkl = (pid_m, pid_n, None, bidz)
+        batch_idx = (
+            bidz if const_expr(params.batch_idx_permute is None) else params.batch_idx_permute[bidz]
+        )
+        tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
         if const_expr(not params.is_persistent):
             is_valid = self._num_tiles_executed == 0
         else:
@@ -267,10 +260,10 @@ class TileScheduler:
         return self.get_current_work(loc=loc, ip=ip)

     @cute.jit
-    def fetch_next_work(self, is_scheduler_warp: bool | Boolean, *, loc=None, ip=None):
+    def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
         """is_scheduler_warp should only be true for one warp in the whole cluster"""
-        if const_expr(self.params.tile_count_semaphore is not None):
-            params = self.params
+        params = self.params
+        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
             current_work_linear_idx = self._current_work_linear_idx
             if is_scheduler_warp:
                 if cute.arch.lane_idx() == 0:
@@ -283,6 +276,38 @@ class TileScheduler:
                 current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
             self._current_work_linear_idx = current_work_linear_idx

+    # We have to split broadcast_next_work and advance_to_next_work into two functions
+    # due to a bug in cute-dsl 4.2: https://github.com/NVIDIA/cutlass/issues/2647
+    @cute.jit
+    def broadcast_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
+        """is_scheduler_warp should only be true for one warp in the whole cluster"""
+        params = self.params
+        if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
+            current_work_linear_idx = self._current_work_linear_idx
+            if is_scheduler_warp:
+                self._scheduler_pipeline.producer_acquire(self._pipeline_state)
+                lane_idx = cute.arch.lane_idx()
+                if lane_idx < cute.size(params.cluster_shape_mn):
+                    # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
+                    if const_expr(cute.size(params.cluster_shape_mn) == 1):
+                        self._tile_count[self._pipeline_state.index] = current_work_linear_idx
+                        self._scheduler_pipeline.producer_commit(self._pipeline_state)
+                    else:
+                        peer_cta_rank_in_cluster = lane_idx
+                        mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
+                            self._pipeline_state
+                        )
+                        cute.arch.mbarrier_arrive_and_expect_tx(
+                            mbar_ptr, 4, peer_cta_rank_in_cluster
+                        )
+                        utils.store_shared_remote(
+                            val=current_work_linear_idx,
+                            smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
+                            mbar_ptr=mbar_ptr,
+                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
+                        )
+                # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
+

     @cute.jit
     def advance_to_next_work(
@@ -300,32 +325,10 @@ class TileScheduler:
         if const_expr(params.tile_count_semaphore is None):  # Static persistent
             self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters)
         else:  # Dynamic persistent
-            self._pipeline_state.advance_iters(advance_count - 1)
+            if const_expr(advance_count > 1):
+                self._pipeline_state.advance_iters(advance_count - 1)
             current_work_linear_idx = self._current_work_linear_idx
-            if is_scheduler_warp:
-                self._scheduler_pipeline.producer_acquire(self._pipeline_state)
-                lane_idx = cute.arch.lane_idx()
-                if lane_idx < cute.size(params.cluster_shape_mn):
-                    # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
-                    if const_expr(cute.size(params.cluster_shape_mn) == 1):
-                        self._tile_count[self._pipeline_state.index] = current_work_linear_idx
-                        self._scheduler_pipeline.producer_commit(self._pipeline_state)
-                    else:
-                        peer_cta_rank_in_cluster = lane_idx
-                        mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
-                            self._pipeline_state
-                        )
-                        cute.arch.mbarrier_arrive_and_expect_tx(
-                            mbar_ptr, 4, peer_cta_rank_in_cluster
-                        )
-                        utils.store_shared_remote(
-                            val=current_work_linear_idx,
-                            smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
-                            mbar_ptr=mbar_ptr,
-                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
-                        )
-                # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
-            else:
+            if not is_scheduler_warp:
                 # if tidx % 64 == 0: cute.printf("bidx = {},tidx = {}, before full wait, idx = {}", bidx, tidx, current_work_linear_idx)
                 self._scheduler_pipeline.consumer_wait(self._pipeline_state)
                 # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after full wait, idx = {}", bidx, tidx, current_work_linear_idx)
@@ -580,7 +583,7 @@ class VarlenMTileSchedulerArguments(ParamsBase):
     cu_seqlens_m: cute.Tensor
     raster_order: cutlass.Constexpr[RasterOrderOption]
     group_size: Int32
-    tile_shape_mnk: cutlass.Constexpr[cute.Shape]
+    tile_shape_mn: cutlass.Constexpr[cute.Shape]
     cluster_shape_mnk: cutlass.Constexpr[cute.Shape]
     tile_count_semaphore: Optional[cute.Pointer] = None
     is_persistent: cutlass.Constexpr[bool] = False
@@ -609,7 +612,6 @@ class VarlenMTileScheduler(TileScheduler):
     ) -> "VarlenMTileScheduler.Params":
         assert args.cluster_shape_mnk[2] == 1
         cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1]))
-        tile_shape_mn = const_expr(cute.select(args.tile_shape_mnk, mode=[0, 1]))
         # problem_shape_ntile_mnl[0] will be None for VarlenM
         problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1])
         problem_shape_ncluster_mn = (
@@ -657,7 +659,7 @@ class VarlenMTileScheduler(TileScheduler):
             FastDivmod.create(num_clusters_in_group)
             if num_clusters_in_group is not None
             else None,
-            tile_shape_mn,
+            args.tile_shape_mn,
             args.tile_count_semaphore if const_expr(args.is_persistent) else None,
             cluster_shape_mn,
             args.is_persistent,
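Besides the dynamic-persistent rework (the `fetch_next_work` / `broadcast_next_work` split forced by the cute-dsl 4.2 bug linked above), the scheduler gains an optional `batch_idx_permute` tensor that is threaded from the arguments into `Params` and consulted when building `tile_coord_mnkl`. A rough host-side sketch of the lookup it performs; the function name and the use of a CPU `torch.Tensor` in place of a device-side `cute.Tensor` are illustrative only:

```python
from typing import Optional

import torch


def tile_batch_index(bidz: int, batch_idx_permute: Optional[torch.Tensor] = None) -> int:
    # Without a permutation, the L (batch) coordinate is the raw block index;
    # with one, block bidz works on batch batch_idx_permute[bidz].
    if batch_idx_permute is None:
        return bidz
    return int(batch_idx_permute[bidz])


perm = torch.tensor([2, 0, 1], dtype=torch.int32)
print([tile_batch_index(b) for b in range(3)])        # [0, 1, 2]
print([tile_batch_index(b, perm) for b in range(3)])  # [2, 0, 1]
```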
quack/topk.py CHANGED
@@ -12,7 +12,7 @@ from cutlass.cute.runtime import from_dlpack
 from cutlass import const_expr

 import quack.utils as utils
-from quack.reduction_base import torch2cute_dtype_map
+from quack.cute_dsl_utils import torch2cute_dtype_map
 from quack.sort.bitonic_sort import bitonic_topk


@@ -133,6 +133,7 @@ class TopK:

         threads_per_row = tv_layout.shape[0][0]
         topk_vals = bitonic_topk(tXrX_f32, self.k, warp_width=threads_per_row)
+
         # Extract indices and clean values
         topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32)
         topk_indices = cute.make_fragment(self.k, cutlass.Int32)
@@ -166,7 +167,8 @@ class TopK:
             cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i])


-def _topk_fwd(x: torch.Tensor, k: int):
+@torch.library.custom_op("quack::_topk_fwd", mutates_args={"values", "indices"})
+def _topk_fwd(x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor) -> None:
     """Top-k forward pass.
     Args:
         x: Input tensor of shape (M, N)
@@ -179,9 +181,7 @@ def _topk_fwd(x: torch.Tensor, k: int):
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     assert k > 0 and k <= x.shape[1], "k must be positive and <= N"

-    M, N = x.shape
-    values = torch.empty((M, k), dtype=x.dtype, device=x.device)
-    indices = torch.empty((M, k), dtype=torch.int32, device=x.device)
+    N = x.size(1)

     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda tensor: (
@@ -202,8 +202,6 @@ def _topk_fwd(x: torch.Tensor, k: int):
     )
     _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)

-    return values, indices
-

 _topk_fwd.compile_cache = {}

@@ -218,4 +216,12 @@ def topk(x: torch.Tensor, k: int):
     Returns:
         Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k))
     """
-    return _topk_fwd(x, k)
+
+    M = x.size(0)
+
+    values = torch.empty((M, k), dtype=x.dtype, device=x.device)
+    indices = torch.empty((M, k), dtype=torch.int32, device=x.device)
+
+    _topk_fwd(x, k, values, indices)
+
+    return values, indices
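For callers, the only visible change is that the `(M, k)` outputs are now allocated in the public `topk` wrapper and passed into the mutating custom op. A short usage sketch; shapes and dtypes follow the docstring above, and the CUDA/bfloat16 choices are just one combination the asserts in `_topk_fwd` allow:

```python
import torch

from quack.topk import topk

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
values, indices = topk(x, k=8)

# Values keep the input dtype; indices are int32, as allocated by the wrapper.
assert values.shape == (32, 8) and values.dtype == torch.bfloat16
assert indices.shape == (32, 8) and indices.dtype == torch.int32
```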